Esempio n. 1
0
 /// <summary>
 /// Returns the shape of this array, queried through the native
 /// <c>MXNDArrayGetShape</c> call on <c>NativePtr</c>.
 /// </summary>
 /// <returns>
 /// The dimension values built by <c>InteropHelper.ToUInt32Array</c> from the
 /// native shape pointer and dimension count.
 /// </returns>
 public IList <mx_uint> GetShape()
 {
     // Check the native status before touching the out parameters: if the call
     // fails, outDim/outData cannot be trusted to describe a valid buffer.
     // (Presumably Logging.CHECK_EQ throws on mismatch — confirm in Logging.)
     Logging.CHECK_EQ(NativeMethods.MXNDArrayGetShape(this.NativePtr, out var outDim, out var outData),
                      NativeMethods.OK);
     return InteropHelper.ToUInt32Array(outData, outDim);
 }
Esempio n. 2
0
        /// <summary>
        /// Binds <paramref name="symbol"/> to the supplied arrays through the native
        /// <c>MXExecutorBindEX</c> call, stores the resulting native handle in <c>Handle</c>,
        /// and wraps each native output array in an <c>NDArray</c> added to <c>Outputs</c>.
        /// </summary>
        /// <param name="symbol">Symbol whose handle is bound.</param>
        /// <param name="context">Default context; its device type and id are passed to the native bind call.</param>
        /// <param name="argmentArrays">Argument arrays; their count is the length passed to the native bind call.</param>
        /// <param name="gradientArrays">Gradient arrays; must contain as many elements as <paramref name="argmentArrays"/>.</param>
        /// <param name="gradReqs">Gradient request per argument; must contain as many elements as <paramref name="argmentArrays"/>.</param>
        /// <param name="auxiliaryArrays">Auxiliary arrays passed to the native bind call.</param>
        /// <param name="groupToCtx">Mapping from group name to context; flattened into parallel key/device-type/device-id arrays.</param>
        /// <param name="sharedExec">Optional executor whose handle is passed as the shared executor; <c>null</c> passes <see cref="IntPtr.Zero"/>.</param>
        /// <exception cref="ArgumentNullException">Any argument other than <paramref name="sharedExec"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="gradientArrays"/> or <paramref name="gradReqs"/> does not have the same count as <paramref name="argmentArrays"/>.
        /// </exception>
        public Executor(Symbol symbol,
                        Context context,
                        IList <NDArray> argmentArrays,
                        IList <NDArray> gradientArrays,
                        IList <OpReqType> gradReqs,
                        IList <NDArray> auxiliaryArrays,
                        IDictionary <string, Context> groupToCtx,
                        Executor sharedExec)
        {
            if (symbol == null)
            {
                throw new ArgumentNullException(nameof(symbol));
            }
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            if (argmentArrays == null)
            {
                throw new ArgumentNullException(nameof(argmentArrays));
            }
            if (gradientArrays == null)
            {
                throw new ArgumentNullException(nameof(gradientArrays));
            }
            if (gradReqs == null)
            {
                throw new ArgumentNullException(nameof(gradReqs));
            }
            if (auxiliaryArrays == null)
            {
                throw new ArgumentNullException(nameof(auxiliaryArrays));
            }
            if (groupToCtx == null)
            {
                throw new ArgumentNullException(nameof(groupToCtx));
            }

            // Only argHandles.Length is handed to MXExecutorBindEX as the length of the
            // argument, gradient and grad-req buffers, so shorter gradient/grad-req lists
            // would let the native side read past the end of the marshalled arrays.
            if (gradientArrays.Count != argmentArrays.Count)
            {
                throw new ArgumentException(
                    $"{nameof(gradientArrays)} must contain the same number of elements as {nameof(argmentArrays)}.",
                    nameof(gradientArrays));
            }
            if (gradReqs.Count != argmentArrays.Count)
            {
                throw new ArgumentException(
                    $"{nameof(gradReqs)} must contain the same number of elements as {nameof(argmentArrays)}.",
                    nameof(gradReqs));
            }

            this.ArgmentArrays   = argmentArrays;
            this.GradientArrays  = gradientArrays;
            this.AuxiliaryArrays = auxiliaryArrays;
            this._Symbol         = symbol;

            var argHandles   = argmentArrays.Select(array => array.GetHandle()).ToArray();
            var gradHandles  = gradientArrays.Select(array => array.GetHandle()).ToArray();
            var auxHandles   = auxiliaryArrays.Select(array => array.GetHandle()).ToArray();
            var gradReqsUint = gradReqs.Select(s => (mx_uint)s).ToArray();

            // Flatten the group->context map into the three parallel arrays the native
            // API expects, enumerating key/value pairs once instead of key + indexer lookup.
            var groupCount = groupToCtx.Count;
            var mapKeys    = new string[groupCount];
            var devTypes   = new int[groupCount];
            var devIds     = new int[groupCount];

            var index = 0;
            foreach (var pair in groupToCtx)
            {
                mapKeys[index]  = pair.Key;
                devTypes[index] = (int)pair.Value.GetDeviceType();
                devIds[index]   = pair.Value.GetDeviceId();
                index++;
            }

            var sharedExecHandle = sharedExec?.Handle ?? IntPtr.Zero;

            Logging.CHECK_EQ(NativeMethods.MXExecutorBindEX(symbol.GetHandle(),
                                                            (int)context.GetDeviceType(),
                                                            context.GetDeviceId(),
                                                            (uint)groupCount,
                                                            mapKeys,
                                                            devTypes,
                                                            devIds,
                                                            (uint)argHandles.Length,
                                                            argHandles,
                                                            gradHandles,
                                                            gradReqsUint,
                                                            (uint)auxHandles.Length,
                                                            auxHandles,
                                                            sharedExecHandle,
                                                            out var handle), NativeMethods.OK);
            this.Handle = handle;

            // Use the same success constant as the bind call above.
            Logging.CHECK_EQ(NativeMethods.MXExecutorOutputs(this.Handle, out var outSize, out var outArray), NativeMethods.OK);
            var outArrayArray = InteropHelper.ToPointerArray(outArray, outSize);

            this.Outputs = new List <NDArray>();
            for (mx_uint i = 0; i < outSize; ++i)
            {
                this.Outputs.Add(new NDArray(outArrayArray[i]));
            }
        }