/// <summary>
/// Resolves the head arrays, and the optional head gradients, into native handle arrays
/// for the autograd native calls.
/// </summary>
/// <param name="heads">Output arrays whose gradients are being propagated.</param>
/// <param name="head_grads">
/// Optional gradients for each head. When <c>null</c>, the returned gradient array holds
/// one <see cref="IntPtr.Zero"/> entry per head.
/// </param>
/// <returns>A tuple of (head handles, head-gradient handles).</returns>
/// <exception cref="ArgumentException">
/// <paramref name="head_grads"/> is non-null and its length differs from <paramref name="heads"/>.
/// </exception>
private static (IntPtr[], IntPtr[]) ParseHead(NDArrayList heads, NDArrayList head_grads)
{
    IntPtr[] resolvedHeads = MxUtil.GetNDArrayHandles(heads);

    if (head_grads != null)
    {
        // Explicit gradients must pair one-to-one with the heads.
        if (heads.Length != head_grads.Length)
        {
            throw new ArgumentException("heads and head_grads must be lists of the same length");
        }

        return (resolvedHeads, MxUtil.GetNDArrayHandles(head_grads));
    }

    // No gradients supplied: hand the native side a null handle for every head.
    var placeholderGrads = new IntPtr[heads.Length];
    for (var slot = 0; slot < resolvedHeads.Length; slot++)
    {
        placeholderGrads[slot] = IntPtr.Zero;
    }

    return (resolvedHeads, placeholderGrads);
}
/// <summary>
/// Marks NDArrays as autograd variables and attaches a gradient buffer to each one.
/// </summary>
/// <param name="variables">Arrays to mark as variables.</param>
/// <param name="gradients">Gradient buffers, one per entry in <paramref name="variables"/>.</param>
/// <param name="grad_reqs">
/// Gradient request applied to every variable. Defaults to <see cref="OpGradReq.Write"/>.
/// </param>
/// <exception cref="ArgumentNullException"><paramref name="variables"/> or <paramref name="gradients"/> is null.</exception>
/// <exception cref="ArgumentException">The two lists have different lengths.</exception>
public static void MarkVariables(NDArrayList variables, NDArrayList gradients, OpGradReq grad_reqs = OpGradReq.Write)
{
    if (variables == null)
    {
        throw new ArgumentNullException(nameof(variables));
    }

    if (gradients == null)
    {
        throw new ArgumentNullException(nameof(gradients));
    }

    // The native call pairs variables[i] with gradients[i]; a mismatch would read past one of the arrays.
    if (variables.Length != gradients.Length)
    {
        throw new ArgumentException("variables and gradients must be lists of the same length");
    }

    // One request code per variable. Previously this was hard-coded to Write, so the
    // caller's grad_reqs was silently ignored.
    var gradReqs = new int[variables.Length];
    for (var i = 0; i < gradReqs.Length; i++)
    {
        gradReqs[i] = (int)grad_reqs;
    }

    NativeMethods.MXAutogradMarkVariables(variables.Length, MxUtil.GetNDArrayHandles(variables), gradReqs,
        MxUtil.GetNDArrayHandles(gradients));
}
/// <summary>
/// Runs this cached operator on the given inputs.
/// </summary>
/// <param name="args">Input arrays passed to the cached operator.</param>
/// <returns>
/// One array per native output, each converted with <c>ToSType</c> to the storage type the
/// native call reported for it.
/// </returns>
public NDArrayList Call(NDArrayList args)
{
    var inputCount = args.Length;
    var inputHandles = MxUtil.GetNDArrayHandles(args);

    NativeMethods.MXInvokeCachedOpEx(handle, inputCount, inputHandles,
        out var outputCount, out var outputHandles, out var outputStypes);

    var produced = new NDArrayList();
    for (var k = 0; k < outputCount; k++)
    {
        var wrapped = new NDArray(outputHandles[k]);
        var storage = (StorageStype)outputStypes[k];
        produced.Add(wrapped.ToSType(storage));
    }

    return produced.ToArray();
}
/// <summary>
/// Computes the gradients of <paramref name="heads"/> with respect to <paramref name="variables"/>
/// through <c>MXAutogradBackwardEx</c> and returns them as new arrays.
/// </summary>
/// <param name="heads">Output arrays to differentiate.</param>
/// <param name="variables">Arrays to compute gradients for; must be non-empty.</param>
/// <param name="head_grads">Optional gradients for each head; see <c>ParseHead</c>.</param>
/// <param name="retain_graph">Forwarded to the native call as 0/1.</param>
/// <param name="create_graph">Forwarded to the native call as 0/1.</param>
/// <param name="train_mode">Forwarded to the native call as 0/1.</param>
/// <returns>One array per gradient handle returned by the native call.</returns>
/// <exception cref="ArgumentException">
/// <paramref name="variables"/> is null or empty, or <paramref name="head_grads"/> length mismatches <paramref name="heads"/>.
/// </exception>
public static NDArrayList Grad(NDArrayList heads, NDArrayList variables, NDArrayList head_grads = null,
    bool retain_graph = false, bool create_graph = true, bool train_mode = true)
{
    // With no variables there is nothing to return gradients for. Fail clearly here
    // instead of with a NullReferenceException or a zero-variable native backward pass.
    if (variables == null || variables.Length == 0)
    {
        throw new ArgumentException("variables cannot be an empty list.", nameof(variables));
    }

    var (head_handles, head_grads_handles) = ParseHead(heads, head_grads);

    NativeMethods.MXAutogradBackwardEx(head_handles.Length, head_handles, head_grads_handles,
        variables.Length, MxUtil.GetNDArrayHandles(variables),
        retain_graph ? 1 : 0, create_graph ? 1 : 0, train_mode ? 1 : 0,
        out var grad_handles, out var grad_stypes);

    var result = new NDArrayList();
    foreach (var item in grad_handles)
    {
        result.Add(new NDArray(item));
    }

    return result.ToArray();
}
/// <summary>
/// Creates a cached operator for <paramref name="sym"/> and stores the native handle in <c>handle</c>.
/// </summary>
/// <param name="sym">Symbol the cached operator is built from.</param>
/// <param name="flags">Flag entries forwarded to the native call as parallel key/value arrays.</param>
/// <remarks>
/// NOTE(review): the MXNet C API documents MXCreateCachedOpEx's flag values as strings
/// (<c>const char**</c>), but here the values are passed as NDArray handles. Confirm against the
/// NativeMethods declaration and the native library before relying on non-empty flags.
/// </remarks>
public CachedOp(Symbol sym, NDArrayDict flags)
{
    handle = IntPtr.Zero;

    var symbolHandle = sym.GetHandle();
    var flagCount = flags.Count;
    var flagKeys = flags.Keys.ToArray();
    var flagValueHandles = MxUtil.GetNDArrayHandles(flags.Values.ToArray());

    NativeMethods.MXCreateCachedOpEx(symbolHandle, flagCount, flagKeys, flagValueHandles, out handle);
}