예제 #1
0
        /// <summary>
        /// Binds this wrapper to the native atomic-symbol creator returned by
        /// <c>OpMap.GetSymbolCreator</c> for <paramref name="operatorName"/> and
        /// appends that operator's argument names to <c>_ArgNames</c>.
        /// </summary>
        /// <param name="operatorName">Operator name used to look up the native creator.</param>
        public Operator(string operatorName)
        {
            this._OpName = operatorName;
            this._Handle = OpMap.GetSymbolCreator(operatorName);

            // The native query reports many fields; only the argument-name table
            // (and its length) is consumed here, so the other outputs are discarded.
            var returnType = System.IntPtr.Zero;
            var status = NativeMethods.MXSymbolGetAtomicSymbolInfo(this._Handle,
                                                                   out _,
                                                                   out _,
                                                                   out var argCount,
                                                                   out var argNamePtrs,
                                                                   out _,
                                                                   out _,
                                                                   out _,
                                                                   ref returnType);
            Logging.CHECK_EQ(status, NativeMethods.OK);

            // Copy each native ANSI name into a managed string, preserving order.
            var namePointers = InteropHelper.ToPointerArray(argNamePtrs, argCount);
            var index = 0;
            while (index < argCount)
            {
                this._ArgNames.Add(Marshal.PtrToStringAnsi(namePointers[index]));
                index++;
            }
        }
예제 #2
0
        /// <summary>
        /// Enumerates every native data-iterator creator and indexes it in
        /// <c>_DataIterCreators</c> by the name reported by <c>MXDataIterGetIterInfo</c>.
        /// </summary>
        /// <remarks>
        /// Entries are inserted with <c>Dictionary.Add</c>, so a duplicate reported
        /// name would throw.
        /// </remarks>
        public MXDataIterMap()
        {
            var listStatus = NativeMethods.MXListDataIters(out var creatorCount, out var creatorList);
            Logging.CHECK_EQ(listStatus, 0);

            this._DataIterCreators = new Dictionary <string, DataIterHandle>((int)creatorCount);

            var creators = InteropHelper.ToPointerArray(creatorList, creatorCount);
            for (var idx = 0; idx < creatorCount; idx++)
            {
                var creator = creators[idx];

                // Only the iterator's name is needed; the remaining info outputs are discarded.
                var infoStatus = NativeMethods.MXDataIterGetIterInfo(creator,
                                                                     out var namePtr,
                                                                     out _,
                                                                     out _,
                                                                     out _,
                                                                     out _,
                                                                     out _);
                Logging.CHECK_EQ(infoStatus, 0);

                var iterName = Marshal.PtrToStringAnsi(namePtr);
                this._DataIterCreators.Add(iterName, creator);
            }
        }
예제 #3
0
        /// <summary>
        /// Runs a forward pass on the bound executor and refreshes <c>Outputs</c>
        /// with the output arrays reported by <c>MXExecutorOutputs</c>.
        /// </summary>
        /// <param name="isTrain">Passed to the native call as 1 when true, 0 otherwise.</param>
        public void Forward(bool isTrain)
        {
            // Check the forward status: previously it was ignored, so a failed
            // pass went unnoticed.
            Logging.CHECK_EQ(NativeMethods.MXExecutorForward(this.Handle, isTrain ? 1 : 0), NativeMethods.OK);
            Logging.CHECK_EQ(NativeMethods.MXExecutorOutputs(this.Handle, out var outSize, out var outArray), NativeMethods.OK);
            var outArrayArray = InteropHelper.ToPointerArray(outArray, outSize);

            for (var i = 0; i < outSize; ++i)
            {
                if (i < this.Outputs.Count)
                {
                    // Release the wrapper from the previous pass before replacing it.
                    this.Outputs[i]?.Dispose();
                    this.Outputs[i] = new NDArray(outArrayArray[i]);
                }
                else
                {
                    // More outputs than currently tracked: append rather than
                    // indexing past the end of the list.
                    this.Outputs.Add(new NDArray(outArrayArray[i]));
                }
            }
        }
예제 #4
0
        /// <summary>
        /// Returns the output names of this symbol as reported by <c>MXSymbolListOutputs</c>.
        /// </summary>
        /// <returns>One managed string per reported output, in native order.</returns>
        public IList <string> ListOutputs()
        {
            this.ThrowIfDisposed();

            // Check the native status before marshaling: previously a failure was
            // ignored and the (unset) out values were read anyway.
            Logging.CHECK_EQ(NativeMethods.MXSymbolListOutputs(this.GetHandle(), out var size, out var sarry), NativeMethods.OK);
            var sarryArray = InteropHelper.ToPointerArray(sarry, size);
            var ret        = new string[size];

            for (var i = 0; i < size; i++)
            {
                ret[i] = Marshal.PtrToStringAnsi(sarryArray[i]);
            }

            return ret;
        }
예제 #5
0
        /// <summary>
        /// Builds two name-indexed lookup tables from the native library:
        /// <c>_SymbolCreators</c> (atomic-symbol creators keyed by the name from
        /// <c>MXSymbolGetAtomicSymbolInfo</c>) and <c>_OpHandles</c> (handles from
        /// <c>NNGetOpHandle</c> keyed by the names from <c>NNListAllOpNames</c>).
        /// </summary>
        public OpMap()
        {
            // Atomic symbol creators, keyed by their reported name.
            var creatorStatus = NativeMethods.MXSymbolListAtomicSymbolCreators(out var creatorCount, out var creatorList);
            Logging.CHECK_EQ(creatorStatus, NativeMethods.OK);

            this._SymbolCreators = new Dictionary <string, AtomicSymbolCreator>((int)creatorCount);
            var creators = InteropHelper.ToPointerArray(creatorList, creatorCount);

            for (var c = 0; c < creatorCount; c++)
            {
                var creator = creators[c];

                // Re-initialized for every query; only the name output is used.
                var returnType = System.IntPtr.Zero;
                var infoStatus = NativeMethods.MXSymbolGetAtomicSymbolInfo(creator,
                                                                           out var namePtr,
                                                                           out _,
                                                                           out _,
                                                                           out _,
                                                                           out _,
                                                                           out _,
                                                                           out _,
                                                                           ref returnType);
                Logging.CHECK_EQ(infoStatus, NativeMethods.OK);

                var creatorName = Marshal.PtrToStringAnsi(namePtr);
                this._SymbolCreators.Add(creatorName, creator);
            }

            // Operator handles, keyed by op name.
            var listStatus = NativeMethods.NNListAllOpNames(out var opCount, out var opNameList);
            Logging.CHECK_EQ(listStatus, NativeMethods.OK);

            this._OpHandles = new Dictionary <string, AtomicSymbolCreator>((int)opCount);
            var opNamePtrs = InteropHelper.ToPointerArray(opNameList, opCount);

            for (var o = 0; o < opCount; o++)
            {
                var opNamePtr = opNamePtrs[o];

                var handleStatus = NativeMethods.NNGetOpHandle(opNamePtr, out var opHandle);
                Logging.CHECK_EQ(handleStatus, NativeMethods.OK);

                var opName = Marshal.PtrToStringAnsi(opNamePtr);
                this._OpHandles.Add(opName, opHandle);
            }
        }
예제 #6
0
        /// <summary>
        /// Loads arrays from <paramref name="fileName"/> via <c>MXNDArrayLoad</c> and
        /// returns them keyed by their stored names, in sorted key order.
        /// </summary>
        /// <param name="fileName">Path handed to <c>MXNDArrayLoad</c>.</param>
        /// <returns>
        /// A <see cref="SortedDictionary{TKey,TValue}"/> of name to array; empty when
        /// the file reports no names.
        /// </returns>
        public static IDictionary <string, NDArray> LoadToMap(string fileName)
        {
            var result = new SortedDictionary <string, NDArray>();

            var status = NativeMethods.MXNDArrayLoad(fileName,
                                                     out var arrayCount,
                                                     out var arrayHandles,
                                                     out var nameCount,
                                                     out var nameHandles);
            Logging.CHECK_EQ(status, NativeMethods.OK);

            // Without names there is nothing to key by, so the map stays empty.
            if (nameCount <= 0)
            {
                return result;
            }

            var handles = InteropHelper.ToPointerArray(arrayHandles, arrayCount);
            var names   = InteropHelper.ToPointerArray(nameHandles, nameCount);

            // Names and arrays are paired by position, so their counts must agree.
            Logging.CHECK_EQ(nameCount, arrayCount);
            for (mx_uint k = 0; k < arrayCount; ++k)
            {
                var key = Marshal.PtrToStringAnsi(names[k]);
                result[key] = new NDArray(handles[k]);
            }

            return result;
        }
예제 #7
0
        /// <summary>
        /// Binds <paramref name="symbol"/> through <c>MXExecutorBindEX</c> and wraps the
        /// executor's output handles in <see cref="NDArray"/> instances stored in <c>Outputs</c>.
        /// </summary>
        /// <param name="symbol">Symbol to bind.</param>
        /// <param name="context">Default device whose type and id are passed to the bind call.</param>
        /// <param name="argmentArrays">Argument arrays passed to the bind call.</param>
        /// <param name="gradientArrays">Gradient arrays; must contain at least as many entries as <paramref name="argmentArrays"/>.</param>
        /// <param name="gradReqs">Gradient request per argument; must contain at least as many entries as <paramref name="argmentArrays"/>.</param>
        /// <param name="auxiliaryArrays">Auxiliary arrays passed to the bind call.</param>
        /// <param name="groupToCtx">Map from context-group name to device.</param>
        /// <param name="sharedExec">Executor whose handle is passed as the shared executor; may be null.</param>
        /// <exception cref="ArgumentNullException">Any argument other than <paramref name="sharedExec"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="gradientArrays"/> or <paramref name="gradReqs"/> has fewer entries than <paramref name="argmentArrays"/>.
        /// </exception>
        public Executor(Symbol symbol,
                        Context context,
                        IList <NDArray> argmentArrays,
                        IList <NDArray> gradientArrays,
                        IList <OpReqType> gradReqs,
                        IList <NDArray> auxiliaryArrays,
                        IDictionary <string, Context> groupToCtx,
                        Executor sharedExec)
        {
            if (symbol == null)
            {
                throw new ArgumentNullException(nameof(symbol));
            }
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            if (argmentArrays == null)
            {
                throw new ArgumentNullException(nameof(argmentArrays));
            }
            if (gradientArrays == null)
            {
                throw new ArgumentNullException(nameof(gradientArrays));
            }
            if (gradReqs == null)
            {
                throw new ArgumentNullException(nameof(gradReqs));
            }
            if (auxiliaryArrays == null)
            {
                throw new ArgumentNullException(nameof(auxiliaryArrays));
            }
            if (groupToCtx == null)
            {
                throw new ArgumentNullException(nameof(groupToCtx));
            }

            // The bind call below passes one length (the argument count) for the
            // argument, gradient and request arrays, so shorter gradient/request
            // lists would be read past their end by native code.
            if (gradientArrays.Count < argmentArrays.Count)
            {
                throw new ArgumentException("The number of gradient arrays must not be less than the number of argument arrays.", nameof(gradientArrays));
            }
            if (gradReqs.Count < argmentArrays.Count)
            {
                throw new ArgumentException("The number of gradient requests must not be less than the number of argument arrays.", nameof(gradReqs));
            }

            this.ArgmentArrays   = argmentArrays;
            this.GradientArrays  = gradientArrays;
            this.AuxiliaryArrays = auxiliaryArrays;
            this._Symbol         = symbol;

            var argHandles   = argmentArrays.Select(array => array.GetHandle()).ToArray();
            var gradHandles  = gradientArrays.Select(array => array.GetHandle()).ToArray();
            var auxHandles   = auxiliaryArrays.Select(array => array.GetHandle()).ToArray();
            var gradReqsUint = gradReqs.Select(s => (mx_uint)s).ToArray();

            var mapKeys  = new string[groupToCtx.Count];
            var devTypes = new int[groupToCtx.Count];
            var devIds   = new int[groupToCtx.Count];

            // One pass over the pairs keeps the three parallel arrays aligned
            // without a second dictionary lookup per key.
            var index = 0;
            foreach (var pair in groupToCtx)
            {
                mapKeys[index]  = pair.Key;
                devTypes[index] = (int)pair.Value.GetDeviceType();
                devIds[index]   = pair.Value.GetDeviceId();
                index++;
            }

            var sharedExecHandle = sharedExec?.Handle ?? IntPtr.Zero;

            Logging.CHECK_EQ(NativeMethods.MXExecutorBindEX(symbol.GetHandle(),
                                                            (int)context.GetDeviceType(),
                                                            context.GetDeviceId(),
                                                            (uint)groupToCtx.Count,
                                                            mapKeys,
                                                            devTypes,
                                                            devIds,
                                                            (uint)argHandles.Length,
                                                            argHandles,
                                                            gradHandles,
                                                            gradReqsUint,
                                                            (uint)auxHandles.Length,
                                                            auxHandles,
                                                            sharedExecHandle,
                                                            out var handle), NativeMethods.OK);
            this.Handle = handle;

            this.Outputs = new List <NDArray>();
            Logging.CHECK_EQ(NativeMethods.MXExecutorOutputs(this.Handle, out var outSize, out var outArray), NativeMethods.OK);
            var outArrayArray = InteropHelper.ToPointerArray(outArray, outSize);

            for (mx_uint i = 0; i < outSize; ++i)
            {
                this.Outputs.Add(new NDArray(outArrayArray[i]));
            }
        }