Ejemplo n.º 1
0
        /// <summary>
        /// Invokes this operator imperatively through <c>MXImperativeInvoke</c>, using the handles in
        /// <c>_InputNdarrays</c> as inputs and the key/value pairs in <c>_Params</c> as parameters.
        /// </summary>
        /// <param name="outputs">
        /// When non-empty, the handles of these arrays are pinned and passed to the native call as the
        /// output buffer, and the list itself is not modified. When empty, the native call supplies the
        /// output handles, which are wrapped in new <c>NDArray</c> instances and appended to this list.
        /// </param>
        public void Invoke(NDArrayList outputs)
        {
            var paramKeys   = new List <string>();
            var paramValues = new List <string>();

            foreach (var data in _Params)
            {
                paramKeys.Add(data.Key);
                paramValues.Add(data.Value);
            }

            var numInputs        = _InputNdarrays.Count;
            var numOutputs       = outputs.Length;
            var hasCallerOutputs = numOutputs > 0;

            var outputHandles   = outputs.Select(s => s.GetHandle()).ToArray();
            var outputsReceiver = IntPtr.Zero;
            // Plain struct (not GCHandle?) so that Free() acts on the stored value and IsAllocated
            // reliably reports whether a free is still needed.
            var gcHandle        = default(GCHandle);

            try
            {
                if (hasCallerOutputs)
                {
                    // Pin the handle array so the native side can write through the raw pointer.
                    gcHandle        = GCHandle.Alloc(outputHandles, GCHandleType.Pinned);
                    outputsReceiver = gcHandle.AddrOfPinnedObject();
                }

                NativeMethods.MXImperativeInvoke(_Handle, numInputs, _InputNdarrays.ToArray(), ref numOutputs,
                                                 ref outputsReceiver,
                                                 paramKeys.Count, paramKeys.ToArray(), paramValues.ToArray());

                // Caller-provided outputs were written in place; nothing to collect. Also bail out when the
                // native side reports no outputs, since Marshal.Copy rejects a null source pointer.
                if (hasCallerOutputs || numOutputs <= 0)
                {
                    return;
                }

                // The native call returned its own handle array: copy the handles out and wrap them.
                var returnedHandles = new NDArrayHandle[numOutputs];
                Marshal.Copy(outputsReceiver, returnedHandles, 0, numOutputs);

                foreach (var outputHandle in returnedHandles)
                {
                    outputs.Add(new NDArray(outputHandle));
                }
            }
            finally
            {
                // Single release point for the pinned buffer; runs on return, normal exit and exceptions.
                if (gcHandle.IsAllocated)
                {
                    gcHandle.Free();
                }
            }
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Calls <c>MXExecutorBackward</c> on this executor's handle with the handles of
        /// <paramref name="headGrads"/>.
        /// </summary>
        /// <param name="headGrads">
        /// Head gradient arrays. An empty list is forwarded as a null pointer with a count of zero.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="headGrads"/> is null.</exception>
        public void Backward(NDArrayList headGrads)
        {
            if (headGrads is null)
            {
                throw new ArgumentNullException(nameof(headGrads));
            }

            var gradHandles = headGrads.Select(grad => grad.GetHandle()).ToArray();
            var gradCount   = (uint)gradHandles.Length;

            // Native code receives null rather than an empty array when there are no head gradients.
            NativeMethods.MXExecutorBackward(Handle, gradCount, gradCount == 0 ? null : gradHandles);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Binds <paramref name="symbol"/> through <c>MXExecutorBindEX</c>, stores the resulting native
        /// handle in <c>Handle</c>, and wraps every output reported by <c>MXExecutorOutputs</c> in
        /// <c>Outputs</c>.
        /// </summary>
        /// <param name="symbol">Symbol to bind.</param>
        /// <param name="context">Default device context; its device type and id are passed to the bind call.</param>
        /// <param name="argmentArrays">Argument arrays; their count is passed to the bind call as the argument count.</param>
        /// <param name="gradientArrays">Gradient arrays; must contain at least one entry per argument array.</param>
        /// <param name="gradReqs">Gradient request per argument; must contain at least one entry per argument array.</param>
        /// <param name="auxiliaryArrays">Auxiliary state arrays.</param>
        /// <param name="groupToCtx">Map from context-group name to device context, flattened for the bind call.</param>
        /// <param name="sharedExec">Executor whose handle is passed as the shared executor; may be null (passed as <c>IntPtr.Zero</c>).</param>
        /// <exception cref="ArgumentNullException">Any argument other than <paramref name="sharedExec"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="gradientArrays"/> or <paramref name="gradReqs"/> has fewer entries than
        /// <paramref name="argmentArrays"/>.
        /// </exception>
        public Executor(Symbol symbol,
                        Context context,
                        NDArrayList argmentArrays,
                        NDArrayList gradientArrays,
                        IList <OpGradReq> gradReqs,
                        NDArrayList auxiliaryArrays,
                        IDictionary <string, Context> groupToCtx,
                        Executor sharedExec)
        {
            if (symbol == null)
            {
                throw new ArgumentNullException(nameof(symbol));
            }
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            if (argmentArrays == null)
            {
                throw new ArgumentNullException(nameof(argmentArrays));
            }
            if (gradientArrays == null)
            {
                throw new ArgumentNullException(nameof(gradientArrays));
            }
            if (gradReqs == null)
            {
                throw new ArgumentNullException(nameof(gradReqs));
            }
            if (auxiliaryArrays == null)
            {
                throw new ArgumentNullException(nameof(auxiliaryArrays));
            }
            if (groupToCtx == null)
            {
                throw new ArgumentNullException(nameof(groupToCtx));
            }

            ArgmentArrays   = argmentArrays;
            GradientArrays  = gradientArrays;
            AuxiliaryArrays = auxiliaryArrays;
            _Symbol         = symbol;

            var argHandles   = argmentArrays.Select(array => array.GetHandle()).ToArray();
            var gradHandles  = gradientArrays.Select(array => array.GetHandle()).ToArray();
            var auxHandles   = auxiliaryArrays.Select(array => array.GetHandle()).ToArray();
            var gradReqsUint = gradReqs.Select(s => (uint)s).ToArray();

            // The bind call is told there are argHandles.Length arguments and reads that many entries from
            // the gradient-handle and grad-req arrays too; shorter arrays would be read out of bounds natively.
            if (gradHandles.Length < argHandles.Length)
            {
                throw new ArgumentException(
                    $"Expected at least {argHandles.Length} gradient arrays (one per argument) but got {gradHandles.Length}.",
                    nameof(gradientArrays));
            }
            if (gradReqsUint.Length < argHandles.Length)
            {
                throw new ArgumentException(
                    $"Expected at least {argHandles.Length} gradient requests (one per argument) but got {gradReqsUint.Length}.",
                    nameof(gradReqs));
            }

            // Flatten the group-to-context map into the parallel arrays the native call expects.
            var mapCount = groupToCtx.Count;
            var mapKeys  = new string[mapCount];
            var devTypes = new int[mapCount];
            var devIds   = new int[mapCount];
            var index    = 0;

            foreach (var pair in groupToCtx)
            {
                mapKeys[index]  = pair.Key;
                devTypes[index] = (int)pair.Value.GetDeviceType();
                devIds[index]   = pair.Value.GetDeviceId();
                index++;
            }

            var sharedExecHandle = sharedExec?.Handle ?? IntPtr.Zero;

            Logging.CHECK_EQ(NativeMethods.MXExecutorBindEX(symbol.GetHandle(),
                                                            (int)context.GetDeviceType(),
                                                            context.GetDeviceId(),
                                                            (uint)mapCount,
                                                            mapKeys,
                                                            devTypes,
                                                            devIds,
                                                            (uint)argHandles.Length,
                                                            argHandles,
                                                            gradHandles,
                                                            gradReqsUint,
                                                            (uint)auxHandles.Length,
                                                            auxHandles,
                                                            sharedExecHandle,
                                                            out var handle), NativeMethods.OK);
            Handle = handle;

            Outputs = new NDArrayList();
            Logging.CHECK_EQ(NativeMethods.MXExecutorOutputs(Handle, out var outSize, out var outArray), NativeMethods.OK);
            var outArrayArray = InteropHelper.ToPointerArray(outArray, outSize);

            for (uint i = 0; i < outSize; ++i)
            {
                Outputs.Add(new NDArray(outArrayArray[i]));
            }
        }
Ejemplo n.º 4
0
 /// <summary>
 /// Gathers the native handle of each array in <paramref name="list"/>, preserving list order.
 /// </summary>
 /// <param name="list">Arrays whose handles are collected.</param>
 /// <returns>A new array with one handle per element of <paramref name="list"/>.</returns>
 public static IntPtr[] GetNDArrayHandles(NDArrayList list)
     => (from array in list select array.GetHandle()).ToArray();
Ejemplo n.º 5
0
 /// <summary>
 /// Wraps each array in <paramref name="source"/> in a new <c>NDArrayOrSymbol</c>, preserving order.
 /// </summary>
 /// <param name="source">Arrays to wrap.</param>
 /// <returns>A new array with one wrapper per element of <paramref name="source"/>.</returns>
 public static NDArrayOrSymbol[] ToNDArrayOrSymbols(this NDArrayList source)
     => (from array in source select new NDArrayOrSymbol(array)).ToArray();