Esempio n. 1
0
        /// <summary>
        /// Builds a name-to-handle lookup of every data-iterator creator exposed by the native library.
        /// </summary>
        public MXDataIterMap()
        {
            var status = NativeMethods.MXListDataIters(out var creatorCount, out var creatorsPtr);
            Logging.CHECK_EQ(status, 0);

            _DataIterCreators = new Dictionary <string, DataIterHandle>((int)creatorCount);
            var creators = InteropHelper.ToPointerArray(creatorsPtr, creatorCount);

            for (var index = 0; index < creatorCount; index++)
            {
                var creator = creators[index];

                // Only the iterator name is needed; the remaining metadata outputs are discarded.
                status = NativeMethods.MXDataIterGetIterInfo(creator,
                                                             out var namePtr,
                                                             out _,
                                                             out _,
                                                             out _,
                                                             out _,
                                                             out _);
                Logging.CHECK_EQ(status, 0);

                var iterName = Marshal.PtrToStringAnsi(namePtr);
                _DataIterCreators.Add(iterName, creator);
            }
        }
Esempio n. 2
0
        /// <summary>
        /// Creates an operator wrapper for <paramref name="operatorName"/> and records the names of its arguments.
        /// </summary>
        /// <param name="operatorName">Name used to look up the atomic symbol creator in <c>OpMap</c>.</param>
        public Operator(string operatorName)
        {
            _OpName = operatorName;
            _Handle = OpMap.GetSymbolCreator(operatorName);

            var returnType = SymbolHandle.Zero;

            // Only the argument count and argument-name array are consumed here.
            var status = NativeMethods.MXSymbolGetAtomicSymbolInfo(_Handle,
                                                                 out _,
                                                                 out _,
                                                                 out var argCount,
                                                                 out var argNamesPtr,
                                                                 out _,
                                                                 out _,
                                                                 out _,
                                                                 ref returnType);
            Logging.CHECK_EQ(status, NativeMethods.OK);

            var namePointers = InteropHelper.ToPointerArray(argNamesPtr, argCount);
            for (var k = 0; k < argCount; k++)
            {
                var argName = Marshal.PtrToStringAnsi(namePointers[k]);
                _ArgNames.Add(argName);
            }
        }
Esempio n. 3
0
        /// <summary>
        /// Returns the attributes reported by the native attribute listing, grouped by node name.
        /// </summary>
        /// <returns>A map from node name to that node's attribute key/value pairs.</returns>
        public Dictionary <string, Dictionary <string, string> > ListAttributeDict()
        {
            ThrowIfDisposed();

            // MXSymbolListAttr reports the number of attribute pairs; the string array holds
            // 2 * size entries laid out as "name$key", value, "name$key", value, ...
            // (The previous call to MXSymbolListAuxiliaryStates returned aux-state names with no '$'.)
            Logging.CHECK_EQ(NativeMethods.MXSymbolListAttr(GetHandle(), out var size, out var sarry), NativeMethods.OK);
            var sarryArray = InteropHelper.ToPointerArray(sarry, size * 2);

            var ret = new Dictionary <string, Dictionary <string, string> >();

            for (var i = 0; i < size; i++)
            {
                string[] pair = Marshal.PtrToStringAnsi(sarryArray[i * 2]).Split('$');
                string   name = pair[0];
                string   key  = pair[1];
                string   val  = Marshal.PtrToStringAnsi(sarryArray[i * 2 + 1]);

                if (!ret.TryGetValue(name, out var attrs))
                {
                    attrs = new Dictionary <string, string>();
                    ret.Add(name, attrs);
                }

                attrs[key] = val;
            }

            return ret;
        }
Esempio n. 4
0
        /// <summary>
        /// Runs the native forward pass and refreshes <c>Outputs</c> with the executor's output arrays.
        /// </summary>
        /// <param name="isTrain">Passed to the native forward call as 1 (true) or 0 (false).</param>
        public void Forward(bool isTrain)
        {
            // The forward status used to be discarded, so a failing forward pass went unnoticed.
            Logging.CHECK_EQ(NativeMethods.MXExecutorForward(Handle, isTrain ? 1 : 0), NativeMethods.OK);
            Logging.CHECK_EQ(NativeMethods.MXExecutorOutputs(Handle, out var outSize, out var outArray), NativeMethods.OK);
            var outArrayArray = InteropHelper.ToPointerArray(outArray, outSize);

            for (var i = 0; i < outSize; ++i)
            {
                // Dispose the wrapper from the previous call before replacing it with a fresh one.
                Outputs[i]?.Dispose();
                Outputs[i] = new NDArray(outArrayArray[i]);
            }
        }
Esempio n. 5
0
        /// <summary>
        /// Returns the output names of this symbol.
        /// </summary>
        /// <returns>An array with one entry per native output name.</returns>
        public IList <string> ListOutputs()
        {
            ThrowIfDisposed();

            // Check the native status instead of silently returning whatever the out-params hold on failure.
            Logging.CHECK_EQ(NativeMethods.MXSymbolListOutputs(GetHandle(), out var size, out var sarry), NativeMethods.OK);
            var sarryArray = InteropHelper.ToPointerArray(sarry, size);
            var ret        = new string[size];

            for (var i = 0; i < size; i++)
            {
                ret[i] = Marshal.PtrToStringAnsi(sarryArray[i]);
            }

            return ret;
        }
Esempio n. 6
0
        /// <summary>
        /// Wraps an existing native executor handle and collects wrappers for its output arrays.
        /// </summary>
        /// <param name="h">Native executor handle; must not be <see cref="IntPtr.Zero"/>.</param>
        /// <param name="context">Not used by this constructor.</param>
        /// <param name="gradReqs">Not used by this constructor.</param>
        /// <param name="groupToCtx">Not used by this constructor.</param>
        /// <exception cref="ArgumentException"><paramref name="h"/> is <see cref="IntPtr.Zero"/>.</exception>
        public Executor(ExecutorHandle h, Context context,
                        List <OpGradReq> gradReqs,
                        Dictionary <string, Context> groupToCtx)
        {
            if (h == IntPtr.Zero)
            {
                throw new ArgumentException("Can not pass IntPtr.Zero", nameof(h));
            }

            Outputs = new NDArrayList();
            Handle  = h;

            var status = NativeMethods.MXExecutorOutputs(Handle, out var count, out var outputsPtr);
            Logging.CHECK_EQ(status, 0);

            var handles = InteropHelper.ToPointerArray(outputsPtr, count);
            for (uint k = 0; k < count; k++)
            {
                var wrapped = new NDArray(handles[k]);
                Outputs.Add(wrapped);
            }
        }
Esempio n. 7
0
        /// <summary>
        /// Populates two lookups from the native library: atomic symbol creators keyed by name,
        /// and operator handles keyed by operator name.
        /// </summary>
        public OpMap()
        {
            var status = NativeMethods.MXSymbolListAtomicSymbolCreators(out var creatorCount, out var creatorsPtr);
            Logging.CHECK_EQ(status, NativeMethods.OK);

            _SymbolCreators = new Dictionary <string, AtomicSymbolCreator>((int)creatorCount);
            var creators = InteropHelper.ToPointerArray(creatorsPtr, creatorCount);

            for (var idx = 0; idx < creatorCount; idx++)
            {
                var creator    = creators[idx];
                var returnType = AtomicSymbolCreator.Zero;

                // Only the creator's name is needed; every other metadata output is discarded.
                status = NativeMethods.MXSymbolGetAtomicSymbolInfo(creator,
                                                                   out var namePtr,
                                                                   out _,
                                                                   out _,
                                                                   out _,
                                                                   out _,
                                                                   out _,
                                                                   out _,
                                                                   ref returnType);
                Logging.CHECK_EQ(status, NativeMethods.OK);

                var creatorName = Marshal.PtrToStringAnsi(namePtr);
                _SymbolCreators.Add(creatorName, creator);
            }

            status = NativeMethods.NNListAllOpNames(out var opCount, out var opNamesPtr);
            Logging.CHECK_EQ(status, NativeMethods.OK);

            _OpHandles = new Dictionary <string, AtomicSymbolCreator>((int)opCount);
            var opNames = InteropHelper.ToPointerArray(opNamesPtr, opCount);

            for (var idx = 0; idx < opCount; idx++)
            {
                var opNamePtr = opNames[idx];

                status = NativeMethods.NNGetOpHandle(opNamePtr, out var opHandle);
                Logging.CHECK_EQ(status, NativeMethods.OK);

                var opName = Marshal.PtrToStringAnsi(opNamePtr);
                _OpHandles.Add(opName, opHandle);
            }
        }
Esempio n. 8
0
        /// <summary>
        /// Returns the attributes reported by <c>MXSymbolListAttr</c>, grouped by node name.
        /// </summary>
        /// <returns>A map from node name to that node's attribute key/value pairs.</returns>
        public Dictionary <string, Dictionary <string, string> > AttrDict()
        {
            Logging.CHECK_EQ(NativeMethods.MXSymbolListAttr(GetHandle(), out var out_size, out var sarray), NativeMethods.OK);

            // out_size is the number of attribute pairs; the array holds 2 * out_size strings laid out as
            // "name$key", value, ... Reading only out_size strings (as before) dropped half the attributes.
            var array = InteropHelper.ToPointerArray(sarray, out_size * 2);
            var dict  = new Dictionary <string, Dictionary <string, string> >();

            for (var i = 0; i < out_size; i++)
            {
                var keys  = Marshal.PtrToStringAnsi(array[i * 2]).Split('$');
                var value = Marshal.PtrToStringAnsi(array[i * 2 + 1]);

                if (!dict.TryGetValue(keys[0], out var attrs))
                {
                    attrs         = new Dictionary <string, string>();
                    dict[keys[0]] = attrs;
                }

                attrs[keys[1]] = value;
            }

            return dict;
        }
Esempio n. 9
0
        /// <summary>
        /// Binds <paramref name="symbol"/> to the given arrays through <c>MXExecutorBindEX</c> and
        /// collects wrappers for the resulting executor's output arrays.
        /// </summary>
        /// <param name="symbol">Symbol to bind.</param>
        /// <param name="context">Default device for the executor.</param>
        /// <param name="argmentArrays">Argument arrays, one per bound argument.</param>
        /// <param name="gradientArrays">
        /// Gradient arrays, same count as <paramref name="argmentArrays"/>. A null entry is passed as a
        /// null handle, which the MXNet C API treats as "no gradient" for that argument.
        /// </param>
        /// <param name="gradReqs">Gradient request per argument, same count as <paramref name="argmentArrays"/>.</param>
        /// <param name="auxiliaryArrays">Auxiliary state arrays.</param>
        /// <param name="groupToCtx">Map from context-group name to device.</param>
        /// <param name="sharedExec">Optional executor whose memory may be shared; may be null.</param>
        /// <exception cref="ArgumentNullException">Any required argument is null.</exception>
        /// <exception cref="ArgumentException">Gradient arrays or gradient requests do not match the argument count.</exception>
        public Executor(Symbol symbol,
                        Context context,
                        NDArrayList argmentArrays,
                        NDArrayList gradientArrays,
                        IList <OpGradReq> gradReqs,
                        NDArrayList auxiliaryArrays,
                        IDictionary <string, Context> groupToCtx,
                        Executor sharedExec)
        {
            if (symbol == null)
            {
                throw new ArgumentNullException(nameof(symbol));
            }
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            if (argmentArrays == null)
            {
                throw new ArgumentNullException(nameof(argmentArrays));
            }
            if (gradientArrays == null)
            {
                throw new ArgumentNullException(nameof(gradientArrays));
            }
            if (gradReqs == null)
            {
                throw new ArgumentNullException(nameof(gradReqs));
            }
            if (auxiliaryArrays == null)
            {
                throw new ArgumentNullException(nameof(auxiliaryArrays));
            }
            if (groupToCtx == null)
            {
                throw new ArgumentNullException(nameof(groupToCtx));
            }

            ArgmentArrays   = argmentArrays;
            GradientArrays  = gradientArrays;
            AuxiliaryArrays = auxiliaryArrays;
            _Symbol         = symbol;

            var argHandles   = argmentArrays.Select(array => array.GetHandle()).ToArray();
            // Null gradient entries become null handles instead of throwing NullReferenceException.
            var gradHandles  = gradientArrays.Select(array => array?.GetHandle() ?? IntPtr.Zero).ToArray();
            var auxHandles   = auxiliaryArrays.Select(array => array.GetHandle()).ToArray();
            var gradReqsUint = gradReqs.Select(s => (uint)s).ToArray();

            // MXExecutorBindEX reads argHandles.Length entries from the gradient-handle and grad-req arrays;
            // shorter arrays would make the native side read past their end.
            if (gradHandles.Length != argHandles.Length)
            {
                throw new ArgumentException("The number of gradient arrays must match the number of argument arrays.", nameof(gradientArrays));
            }
            if (gradReqsUint.Length != argHandles.Length)
            {
                throw new ArgumentException("The number of gradient requests must match the number of argument arrays.", nameof(gradReqs));
            }

            var mapKeys  = new string[groupToCtx.Count];
            var devTypes = new int[groupToCtx.Count];
            var devIds   = new int[groupToCtx.Count];
            var keys     = groupToCtx.Keys.ToArray();

            for (var index = 0; index < keys.Length; index++)
            {
                var key = keys[index];
                mapKeys[index] = key;
                var value = groupToCtx[key];
                devTypes[index] = (int)value.GetDeviceType();
                devIds[index]   = value.GetDeviceId();
            }

            var sharedExecHandle = sharedExec?.Handle ?? IntPtr.Zero;

            Logging.CHECK_EQ(NativeMethods.MXExecutorBindEX(symbol.GetHandle(),
                                                            (int)context.GetDeviceType(),
                                                            context.GetDeviceId(),
                                                            (uint)groupToCtx.Count,
                                                            mapKeys,
                                                            devTypes,
                                                            devIds,
                                                            (uint)argHandles.Length,
                                                            argHandles,
                                                            gradHandles,
                                                            gradReqsUint,
                                                            (uint)auxHandles.Length,
                                                            auxHandles,
                                                            sharedExecHandle,
                                                            out var handle), NativeMethods.OK);
            Handle = handle;

            Outputs = new NDArrayList();
            Logging.CHECK_EQ(NativeMethods.MXExecutorOutputs(Handle, out var outSize, out var outArray), 0);
            var outArrayArray = InteropHelper.ToPointerArray(outArray, outSize);

            for (uint i = 0; i < outSize; ++i)
            {
                Outputs.Add(new NDArray(outArrayArray[i]));
            }
        }