/// <summary>
/// Runs this operator through <c>NativeMethods.MXImperativeInvoke</c> using the
/// inputs in <c>_InputNdarrays</c> and the key/value pairs in <c>_Params</c>.
/// </summary>
/// <param name="outputs">
/// Output list. When it is non-empty, the handles of its arrays are pinned and the
/// address of that pinned array is handed to the native call; nothing is appended.
/// When it is empty, one <c>NDArray</c> is created for each handle the native call
/// reports and appended to this list.
/// </param>
public void Invoke(NDArrayList outputs)
{
    var paramKeys = new List<string>();
    var paramValues = new List<string>();
    foreach (var data in _Params)
    {
        paramKeys.Add(data.Key);
        paramValues.Add(data.Value);
    }

    var numInputs = _InputNdarrays.Count;
    var numOutputs = outputs.Length;
    var hasCallerOutputs = numOutputs > 0;
    var outputHandles = outputs.Select(s => s.GetHandle()).ToArray();
    var outputsReceiver = IntPtr.Zero;

    // GCHandle is a struct: keep exactly one copy and release it exactly once in
    // the finally block. Freeing copies of the same handle more than once releases
    // the underlying GC handle slot repeatedly.
    var pinnedOutputs = default(GCHandle);
    try
    {
        if (hasCallerOutputs)
        {
            pinnedOutputs = GCHandle.Alloc(outputHandles, GCHandleType.Pinned);
            outputsReceiver = pinnedOutputs.AddrOfPinnedObject();
        }

        NativeMethods.MXImperativeInvoke(_Handle,
                                         numInputs,
                                         _InputNdarrays.ToArray(),
                                         ref numOutputs,
                                         ref outputsReceiver,
                                         paramKeys.Count,
                                         paramKeys.ToArray(),
                                         paramValues.ToArray());

        if (hasCallerOutputs)
        {
            return;
        }

        // No outputs were supplied: read the handles reported by the native call
        // and wrap each one.
        outputHandles = new NDArrayHandle[numOutputs];
        Marshal.Copy(outputsReceiver, outputHandles, 0, numOutputs);
        foreach (var outputHandle in outputHandles)
        {
            outputs.Add(new NDArray(outputHandle));
        }
    }
    finally
    {
        // Exceptions propagate untouched (no catch/rethrow), so the original
        // stack trace is preserved.
        if (pinnedOutputs.IsAllocated)
        {
            pinnedOutputs.Free();
        }
    }
}
/// <summary>
/// Calls <c>NativeMethods.MXExecutorBackward</c> for this executor with the
/// handles of the supplied head gradients.
/// </summary>
/// <param name="headGrads">
/// Head-gradient arrays. An empty list results in a call with a count of 0 and a
/// <c>null</c> handle array.
/// </param>
/// <exception cref="ArgumentNullException"><paramref name="headGrads"/> is <c>null</c>.</exception>
public void Backward(NDArrayList headGrads)
{
    if (headGrads == null)
    {
        throw new ArgumentNullException(nameof(headGrads));
    }

    var gradientHandles = headGrads.Select(grad => grad.GetHandle()).ToArray();

    // Empty case first: the native call receives no array at all.
    if (gradientHandles.Length <= 0)
    {
        NativeMethods.MXExecutorBackward(Handle, 0, null);
        return;
    }

    NativeMethods.MXExecutorBackward(Handle, (uint)gradientHandles.Length, gradientHandles);
}
/// <summary>
/// Creates an executor by binding <paramref name="symbol"/> through
/// <c>NativeMethods.MXExecutorBindEX</c>, then fills <c>Outputs</c> from
/// <c>NativeMethods.MXExecutorOutputs</c>.
/// </summary>
/// <param name="symbol">Symbol to bind; stored in <c>_Symbol</c>.</param>
/// <param name="context">Context whose device type and device id are passed to the bind call.</param>
/// <param name="argmentArrays">Argument arrays; stored in <c>ArgmentArrays</c>.</param>
/// <param name="gradientArrays">Gradient arrays; stored in <c>GradientArrays</c>.</param>
/// <param name="gradReqs">Gradient requests, each cast to <c>uint</c> for the native call.</param>
/// <param name="auxiliaryArrays">Auxiliary arrays; stored in <c>AuxiliaryArrays</c>.</param>
/// <param name="groupToCtx">Map from group name to context, flattened into parallel key / device-type / device-id arrays.</param>
/// <param name="sharedExec">Optional executor; when <c>null</c>, <c>IntPtr.Zero</c> is passed as the shared handle.</param>
/// <exception cref="ArgumentNullException">Any argument other than <paramref name="sharedExec"/> is <c>null</c>.</exception>
public Executor(Symbol symbol,
                Context context,
                NDArrayList argmentArrays,
                NDArrayList gradientArrays,
                IList<OpGradReq> gradReqs,
                NDArrayList auxiliaryArrays,
                IDictionary<string, Context> groupToCtx,
                Executor sharedExec)
{
    // Guard clauses, checked in parameter order.
    if (symbol == null)
    {
        throw new ArgumentNullException(nameof(symbol));
    }

    if (context == null)
    {
        throw new ArgumentNullException(nameof(context));
    }

    if (argmentArrays == null)
    {
        throw new ArgumentNullException(nameof(argmentArrays));
    }

    if (gradientArrays == null)
    {
        throw new ArgumentNullException(nameof(gradientArrays));
    }

    if (gradReqs == null)
    {
        throw new ArgumentNullException(nameof(gradReqs));
    }

    if (auxiliaryArrays == null)
    {
        throw new ArgumentNullException(nameof(auxiliaryArrays));
    }

    if (groupToCtx == null)
    {
        throw new ArgumentNullException(nameof(groupToCtx));
    }

    ArgmentArrays = argmentArrays;
    GradientArrays = gradientArrays;
    AuxiliaryArrays = auxiliaryArrays;
    _Symbol = symbol;

    // Native handles for each array list.
    // NOTE(review): the bind call is given only the argument count for the
    // argument, gradient and request arrays; it looks like all three are expected
    // to have the same length, but nothing here enforces that — confirm with callers.
    var argumentHandles = argmentArrays.Select(item => item.GetHandle()).ToArray();
    var gradientHandles = gradientArrays.Select(item => item.GetHandle()).ToArray();
    var auxiliaryHandles = auxiliaryArrays.Select(item => item.GetHandle()).ToArray();
    var requestCodes = gradReqs.Select(req => (uint)req).ToArray();

    // Flatten the group -> context map into three parallel arrays.
    var groupCount = groupToCtx.Count;
    var groupNames = new string[groupCount];
    var groupDeviceTypes = new int[groupCount];
    var groupDeviceIds = new int[groupCount];

    var slot = 0;
    foreach (var groupName in groupToCtx.Keys.ToArray())
    {
        groupNames[slot] = groupName;

        var groupContext = groupToCtx[groupName];
        groupDeviceTypes[slot] = (int)groupContext.GetDeviceType();
        groupDeviceIds[slot] = groupContext.GetDeviceId();

        slot++;
    }

    var sharedExecHandle = sharedExec?.Handle ?? IntPtr.Zero;

    var bindStatus = NativeMethods.MXExecutorBindEX(symbol.GetHandle(),
                                                    (int)context.GetDeviceType(),
                                                    context.GetDeviceId(),
                                                    (uint)groupCount,
                                                    groupNames,
                                                    groupDeviceTypes,
                                                    groupDeviceIds,
                                                    (uint)argumentHandles.Length,
                                                    argumentHandles,
                                                    gradientHandles,
                                                    requestCodes,
                                                    (uint)auxiliaryHandles.Length,
                                                    auxiliaryHandles,
                                                    sharedExecHandle,
                                                    out var handle);
    Logging.CHECK_EQ(bindStatus, NativeMethods.OK);
    Handle = handle;

    // Wrap every output handle reported for the bound executor.
    Outputs = new NDArrayList();
    var outputsStatus = NativeMethods.MXExecutorOutputs(Handle, out var outSize, out var outArray);
    Logging.CHECK_EQ(outputsStatus, 0);

    var outputPointers = InteropHelper.ToPointerArray(outArray, outSize);
    for (uint outputIndex = 0; outputIndex < outSize; ++outputIndex)
    {
        Outputs.Add(new NDArray(outputPointers[outputIndex]));
    }
}
/// <summary>
/// Projects every array in <paramref name="list"/> to the value returned by its
/// <c>GetHandle()</c> call.
/// </summary>
/// <param name="list">Arrays whose handles are collected.</param>
/// <returns>A new array with one handle per element, in enumeration order.</returns>
public static IntPtr[] GetNDArrayHandles(NDArrayList list)
    => list.Select(ndArray => ndArray.GetHandle()).ToArray();
/// <summary>
/// Wraps every element of <paramref name="source"/> in a new <c>NDArrayOrSymbol</c>.
/// </summary>
/// <param name="source">Arrays to wrap.</param>
/// <returns>A new array with one wrapper per element, in enumeration order.</returns>
public static NDArrayOrSymbol[] ToNDArrayOrSymbols(this NDArrayList source)
    => source.Select(ndArray => new NDArrayOrSymbol(ndArray)).ToArray();