/// <summary>
/// Builds the name-to-handle table of every data-iterator creator reported by the native library.
/// </summary>
public MXDataIterMap()
{
    Logging.CHECK_EQ(NativeMethods.MXListDataIters(out var creatorCount, out var creatorsPtr), 0);

    _DataIterCreators = new Dictionary<string, DataIterHandle>((int)creatorCount);
    var creators = InteropHelper.ToPointerArray(creatorsPtr, creatorCount);

    for (var index = 0; index < creatorCount; index++)
    {
        var creator = creators[index];

        // Only the iterator name is needed; the other out-parameters exist because the native signature requires them.
        Logging.CHECK_EQ(
            NativeMethods.MXDataIterGetIterInfo(creator, out var namePtr, out _, out _, out _, out _, out _),
            0);

        _DataIterCreators.Add(Marshal.PtrToStringAnsi(namePtr), creator);
    }
}
/// <summary>
/// Creates an operator wrapper for <paramref name="operatorName"/> and records the argument names
/// reported by the native atomic-symbol info query.
/// </summary>
/// <param name="operatorName">Name used to look up the symbol creator via <c>OpMap.GetSymbolCreator</c>.</param>
public Operator(string operatorName)
{
    _OpName = operatorName;
    _Handle = OpMap.GetSymbolCreator(operatorName);

    var returnType = SymbolHandle.Zero;
    var status = NativeMethods.MXSymbolGetAtomicSymbolInfo(
        _Handle,
        out _,
        out _,
        out var argCount,
        out var argNamesPtr,
        out _,
        out _,
        out _,
        ref returnType);
    Logging.CHECK_EQ(status, NativeMethods.OK);

    var namePointers = InteropHelper.ToPointerArray(argNamesPtr, argCount);
    for (var index = 0; index < argCount; index++)
    {
        _ArgNames.Add(Marshal.PtrToStringAnsi(namePointers[index]));
    }
}
/// <summary>
/// Returns the attributes reported by the native <c>MXSymbolListAttr</c> call, grouped by node name.
/// </summary>
/// <returns>
/// A dictionary mapping each node name to a dictionary of that node's attribute key/value pairs.
/// </returns>
public Dictionary<string, Dictionary<string, string>> ListAttributeDict()
{
    ThrowIfDisposed();

    // The native call previously used here (MXSymbolListAuxiliaryStates) returns plain aux-state names,
    // which cannot be parsed as "name$key" / value pairs. MXSymbolListAttr returns `size` pairs laid out
    // as 2 * size strings: key0, value0, key1, value1, ...
    Logging.CHECK_EQ(NativeMethods.MXSymbolListAttr(GetHandle(), out var size, out var sarry), NativeMethods.OK);
    var sarryArray = InteropHelper.ToPointerArray(sarry, size * 2);

    var ret = new Dictionary<string, Dictionary<string, string>>();
    for (var i = 0; i < size; i++)
    {
        // Each key has the form "<node name>$<attribute key>".
        var pair = Marshal.PtrToStringAnsi(sarryArray[i * 2]).Split('$');
        var name = pair[0];
        var key = pair[1];
        var val = Marshal.PtrToStringAnsi(sarryArray[i * 2 + 1]);

        if (!ret.TryGetValue(name, out var attrs))
        {
            attrs = new Dictionary<string, string>();
            ret.Add(name, attrs);
        }

        attrs[key] = val;
    }

    return ret;
}
/// <summary>
/// Runs the native forward pass, then replaces each entry of <c>Outputs</c> with a wrapper
/// around the output handles the executor currently reports.
/// </summary>
/// <param name="isTrain">Passed to <c>MXExecutorForward</c> as 1 when true, 0 otherwise.</param>
public void Forward(bool isTrain)
{
    // Check the forward status instead of discarding it, so a native failure surfaces here
    // rather than being silently ignored.
    Logging.CHECK_EQ(NativeMethods.MXExecutorForward(Handle, isTrain ? 1 : 0), NativeMethods.OK);

    Logging.CHECK_EQ(NativeMethods.MXExecutorOutputs(Handle, out var outSize, out var outArray), 0);
    var outArrayArray = InteropHelper.ToPointerArray(outArray, outSize);
    for (var i = 0; i < outSize; ++i)
    {
        // Dispose the previous wrapper before replacing it with the newly returned handle.
        Outputs[i]?.Dispose();
        Outputs[i] = new NDArray(outArrayArray[i]);
    }
}
/// <summary>
/// Returns the output names reported by the native <c>MXSymbolListOutputs</c> call.
/// </summary>
/// <returns>An array with one name per reported output.</returns>
public IList<string> ListOutputs()
{
    ThrowIfDisposed();

    // The native status was previously ignored; fail loudly instead of returning whatever the out-params hold.
    Logging.CHECK_EQ(NativeMethods.MXSymbolListOutputs(GetHandle(), out var size, out var sarry), NativeMethods.OK);
    var sarryArray = InteropHelper.ToPointerArray(sarry, size);

    var ret = new string[size];
    for (var i = 0; i < size; i++)
    {
        ret[i] = Marshal.PtrToStringAnsi(sarryArray[i]);
    }

    return ret;
}
/// <summary>
/// Wraps an existing native executor handle and wraps each of its output handles in an <c>NDArray</c>.
/// </summary>
/// <param name="h">Native executor handle; must not be <c>IntPtr.Zero</c>.</param>
/// <param name="context">NOTE(review): not read by this constructor.</param>
/// <param name="gradReqs">NOTE(review): not read by this constructor.</param>
/// <param name="groupToCtx">NOTE(review): not read by this constructor.</param>
/// <exception cref="ArgumentException">Thrown when <paramref name="h"/> is <c>IntPtr.Zero</c>.</exception>
public Executor(ExecutorHandle h, Context context, List<OpGradReq> gradReqs, Dictionary<string, Context> groupToCtx)
{
    if (h == IntPtr.Zero)
    {
        throw new ArgumentException("Can not pass IntPtr.Zero", nameof(h));
    }

    Outputs = new NDArrayList();
    Handle = h;

    Logging.CHECK_EQ(NativeMethods.MXExecutorOutputs(Handle, out var outputCount, out var outputsPtr), 0);
    var outputHandles = InteropHelper.ToPointerArray(outputsPtr, outputCount);

    var position = 0u;
    while (position < outputCount)
    {
        Outputs.Add(new NDArray(outputHandles[position]));
        position++;
    }
}
/// <summary>
/// Builds two lookup tables from the native library: atomic symbol creators keyed by their
/// reported name, and NNVM operator handles keyed by operator name.
/// </summary>
public OpMap()
{
    _SymbolCreators = CollectSymbolCreators();
    _OpHandles = CollectOpHandles();

    // Enumerates atomic symbol creators and keys each one by the name from MXSymbolGetAtomicSymbolInfo.
    Dictionary<string, AtomicSymbolCreator> CollectSymbolCreators()
    {
        Logging.CHECK_EQ(
            NativeMethods.MXSymbolListAtomicSymbolCreators(out var creatorCount, out var creatorsPtr),
            NativeMethods.OK);

        var byName = new Dictionary<string, AtomicSymbolCreator>((int)creatorCount);
        var creators = InteropHelper.ToPointerArray(creatorsPtr, creatorCount);

        for (var index = 0; index < creatorCount; index++)
        {
            var returnType = AtomicSymbolCreator.Zero;

            // Only the name is used; the other out-parameters are required by the native signature.
            Logging.CHECK_EQ(
                NativeMethods.MXSymbolGetAtomicSymbolInfo(
                    creators[index], out var namePtr, out _, out _, out _, out _, out _, out _, ref returnType),
                NativeMethods.OK);

            byName.Add(Marshal.PtrToStringAnsi(namePtr), creators[index]);
        }

        return byName;
    }

    // Enumerates all NNVM operator names and resolves each to its operator handle.
    Dictionary<string, AtomicSymbolCreator> CollectOpHandles()
    {
        Logging.CHECK_EQ(NativeMethods.NNListAllOpNames(out var opCount, out var opNamesPtr), NativeMethods.OK);

        var byName = new Dictionary<string, AtomicSymbolCreator>((int)opCount);
        var opNamePointers = InteropHelper.ToPointerArray(opNamesPtr, opCount);

        for (var index = 0; index < opCount; index++)
        {
            var opNamePtr = opNamePointers[index];
            Logging.CHECK_EQ(NativeMethods.NNGetOpHandle(opNamePtr, out var opHandle), NativeMethods.OK);
            byName.Add(Marshal.PtrToStringAnsi(opNamePtr), opHandle);
        }

        return byName;
    }
}
/// <summary>
/// Returns the attributes reported by the native <c>MXSymbolListAttr</c> call, grouped by node name.
/// </summary>
/// <returns>
/// A dictionary mapping each node name to a dictionary of that node's attribute key/value pairs.
/// </returns>
public Dictionary<string, Dictionary<string, string>> AttrDict()
{
    Logging.CHECK_EQ(NativeMethods.MXSymbolListAttr(GetHandle(), out var out_size, out var sarray), NativeMethods.OK);

    // out_size is the number of key/value PAIRS; the native array holds 2 * out_size strings
    // (key0, value0, key1, value1, ...). Converting only out_size pointers dropped half the attributes.
    var array = InteropHelper.ToPointerArray(sarray, out_size * 2);

    var dict = new Dictionary<string, Dictionary<string, string>>();
    for (var i = 0; i < out_size; i++)
    {
        // Each key has the form "<node name>$<attribute key>".
        var keys = Marshal.PtrToStringAnsi(array[i * 2]).Split('$');
        var value = Marshal.PtrToStringAnsi(array[i * 2 + 1]);

        if (!dict.TryGetValue(keys[0], out var attrs))
        {
            attrs = new Dictionary<string, string>();
            dict[keys[0]] = attrs;
        }

        attrs[keys[1]] = value;
    }

    return dict;
}
/// <summary>
/// Binds <paramref name="symbol"/> through <c>MXExecutorBindEX</c> and wraps each resulting
/// output handle in an <c>NDArray</c>.
/// </summary>
/// <param name="symbol">Symbol to bind.</param>
/// <param name="context">Device type and id passed to the native bind call.</param>
/// <param name="argmentArrays">Argument arrays passed to the native bind call.</param>
/// <param name="gradientArrays">Gradient arrays; must contain as many entries as <paramref name="argmentArrays"/>.</param>
/// <param name="gradReqs">Gradient requests; must contain as many entries as <paramref name="argmentArrays"/>.</param>
/// <param name="auxiliaryArrays">Auxiliary arrays passed to the native bind call.</param>
/// <param name="groupToCtx">Group name to context map, flattened into parallel key/device-type/device-id arrays.</param>
/// <param name="sharedExec">Optional executor whose handle is passed as the shared executor; <c>IntPtr.Zero</c> is passed when null.</param>
/// <exception cref="ArgumentNullException">Thrown when any argument other than <paramref name="sharedExec"/> is null.</exception>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="gradientArrays"/> or <paramref name="gradReqs"/> do not match the argument count.
/// </exception>
public Executor(Symbol symbol, Context context, NDArrayList argmentArrays, NDArrayList gradientArrays, IList<OpGradReq> gradReqs, NDArrayList auxiliaryArrays, IDictionary<string, Context> groupToCtx, Executor sharedExec)
{
    if (symbol == null) { throw new ArgumentNullException(nameof(symbol)); }
    if (context == null) { throw new ArgumentNullException(nameof(context)); }
    if (argmentArrays == null) { throw new ArgumentNullException(nameof(argmentArrays)); }
    if (gradientArrays == null) { throw new ArgumentNullException(nameof(gradientArrays)); }
    if (gradReqs == null) { throw new ArgumentNullException(nameof(gradReqs)); }
    if (auxiliaryArrays == null) { throw new ArgumentNullException(nameof(auxiliaryArrays)); }
    if (groupToCtx == null) { throw new ArgumentNullException(nameof(groupToCtx)); }

    ArgmentArrays = argmentArrays;
    GradientArrays = gradientArrays;
    AuxiliaryArrays = auxiliaryArrays;
    _Symbol = symbol;

    var argHandles = argmentArrays.Select(array => array.GetHandle()).ToArray();
    var gradHandles = gradientArrays.Select(array => array.GetHandle()).ToArray();
    var auxHandles = auxiliaryArrays.Select(array => array.GetHandle()).ToArray();
    var gradReqsUint = gradReqs.Select(s => (uint)s).ToArray();

    // The native bind reads argHandles.Length entries from the gradient and grad-req arrays as well;
    // shorter managed arrays would cause an out-of-bounds native read, so reject mismatches here.
    if (gradHandles.Length != argHandles.Length)
    {
        throw new ArgumentException("gradientArrays must contain the same number of arrays as argmentArrays.", nameof(gradientArrays));
    }
    if (gradReqsUint.Length != argHandles.Length)
    {
        throw new ArgumentException("gradReqs must contain the same number of entries as argmentArrays.", nameof(gradReqs));
    }

    // Flatten the group-to-context map into the parallel arrays the native API expects.
    var mapKeys = new string[groupToCtx.Count];
    var devTypes = new int[groupToCtx.Count];
    var devIds = new int[groupToCtx.Count];
    var index = 0;
    foreach (var entry in groupToCtx)
    {
        mapKeys[index] = entry.Key;
        devTypes[index] = (int)entry.Value.GetDeviceType();
        devIds[index] = entry.Value.GetDeviceId();
        index++;
    }

    var sharedExecHandle = sharedExec?.Handle ?? IntPtr.Zero;
    Logging.CHECK_EQ(NativeMethods.MXExecutorBindEX(symbol.GetHandle(),
                                                    (int)context.GetDeviceType(),
                                                    context.GetDeviceId(),
                                                    (uint)groupToCtx.Count,
                                                    mapKeys,
                                                    devTypes,
                                                    devIds,
                                                    (uint)argHandles.Length,
                                                    argHandles,
                                                    gradHandles,
                                                    gradReqsUint,
                                                    (uint)auxHandles.Length,
                                                    auxHandles,
                                                    sharedExecHandle,
                                                    out var handle), NativeMethods.OK);
    Handle = handle;

    Outputs = new NDArrayList();
    Logging.CHECK_EQ(NativeMethods.MXExecutorOutputs(Handle, out var outSize, out var outArray), 0);
    var outArrayArray = InteropHelper.ToPointerArray(outArray, outSize);
    for (uint i = 0; i < outSize; ++i)
    {
        Outputs.Add(new NDArray(outArrayArray[i]));
    }
}