private static (IntPtr[], IntPtr[]) ParseHead(NDArrayList heads, NDArrayList head_grads)
{
    // Translate the output arrays (and optional output gradients) into the
    // raw native handles expected by the autograd C API.
    var headHandles = MxUtil.GetNDArrayHandles(heads);
    IntPtr[] headGradHandles;

    if (head_grads == null)
    {
        // No output gradients supplied: hand the backend one null handle per
        // head (presumably interpreted as default gradients — confirm against
        // the MXAutogradBackwardEx contract).
        headGradHandles = new IntPtr[heads.Length];
        for (var idx = 0; idx < headGradHandles.Length; idx++)
        {
            headGradHandles[idx] = IntPtr.Zero;
        }
    }
    else
    {
        if (heads.Length != head_grads.Length)
        {
            throw new ArgumentException("heads and head_grads must be lists of the same length");
        }

        headGradHandles = MxUtil.GetNDArrayHandles(head_grads);
    }

    return (headHandles, headGradHandles);
}
/// <summary>
/// Marks <paramref name="variables"/> as inputs to the autograd graph and
/// attaches <paramref name="gradients"/> as their gradient buffers.
/// </summary>
/// <param name="variables">Arrays to record gradients for.</param>
/// <param name="gradients">Gradient buffers, one per variable.</param>
/// <param name="grad_reqs">How gradients are accumulated into the buffers.</param>
public static void MarkVariables(NDArrayList variables, NDArrayList gradients, OpGradReq grad_reqs = OpGradReq.Write)
{
    var gradReqs = new int[variables.Length];
    for (var i = 0; i < gradReqs.Length; i++)
    {
        // Bug fix: the original hard-coded OpGradReq.Write here, silently
        // ignoring the caller-supplied grad_reqs argument.
        gradReqs[i] = (int)grad_reqs;
    }

    NativeMethods.MXAutogradMarkVariables(variables.Length, MxUtil.GetNDArrayHandles(variables), gradReqs, MxUtil.GetNDArrayHandles(gradients));
}
public NDArrayList Call(NDArrayList args)
{
    // Invoke the cached op on the given inputs and wrap the returned native
    // handles back into NDArrays, restoring each output's storage type.
    NativeMethods.MXInvokeCachedOpEx(
        handle,
        args.Length,
        MxUtil.GetNDArrayHandles(args),
        out var numOutputs,
        out var outputHandles,
        out var outputStypes);

    var outputs = new NDArrayList();
    for (var idx = 0; idx < numOutputs; idx++)
    {
        var output = new NDArray(outputHandles[idx]).ToSType((StorageStype)outputStypes[idx]);
        outputs.Add(output);
    }

    return outputs.ToArray();
}
/// <summary>
/// Computes the gradients of <paramref name="heads"/> with respect to
/// <paramref name="variables"/> and returns them (one per variable) instead of
/// writing them into the variables' attached gradient buffers.
/// </summary>
/// <param name="heads">Output arrays to differentiate.</param>
/// <param name="variables">Input arrays to compute gradients for.</param>
/// <param name="head_grads">Optional gradients w.r.t. the heads; must match <paramref name="heads"/> in length when given.</param>
/// <param name="retain_graph">Whether the recorded graph is kept for further backward passes.</param>
/// <param name="create_graph">Whether the gradient computation itself is recorded (for higher-order gradients).</param>
/// <param name="train_mode">Whether to differentiate under training mode.</param>
/// <returns>The gradient arrays, one per entry in <paramref name="variables"/>.</returns>
public static NDArrayList Grad(NDArrayList heads, NDArrayList variables, NDArrayList head_grads = null, bool retain_graph = false, bool create_graph = true, bool train_mode = true)
{
    var (head_handles, head_grads_handles) = ParseHead(heads, head_grads);

    NativeMethods.MXAutogradBackwardEx(
        head_handles.Length,
        head_handles,
        head_grads_handles,
        variables.Length,
        MxUtil.GetNDArrayHandles(variables),
        Convert.ToInt32(retain_graph),
        Convert.ToInt32(create_graph),
        Convert.ToInt32(train_mode),
        out var grad_handles,
        out var grad_stypes);

    // NOTE(review): grad_stypes is discarded, so sparse gradients are wrapped
    // as plain NDArrays here; Call() applies ToSType for its outputs — confirm
    // whether the same conversion is needed for gradients.
    var result = new NDArrayList();
    foreach (var item in grad_handles)
    {
        result.Add(new NDArray(item));
    }

    return result.ToArray();
}
/// <summary>
/// Binds <paramref name="sym"/> into an <see cref="Executor"/>, allocating or
/// borrowing argument, gradient, and auxiliary arrays. Data/label buffers may
/// be shared across buckets via <paramref name="shared_data_arrays"/>, and
/// parameters/aux states may be borrowed from <paramref name="base_exec"/>.
/// </summary>
/// <param name="sym">Symbol to bind.</param>
/// <param name="ctx">Device context for all allocated arrays.</param>
/// <param name="input_shapes">Shapes of the data/label inputs, keyed by name.</param>
/// <param name="param_names">Names of the model parameters among the symbol's arguments.</param>
/// <param name="need_grad">Whether gradient buffers are allocated for non-input arguments.</param>
/// <param name="base_exec">Optional executor to borrow parameter and aux arrays from.</param>
/// <param name="shared_data_arrays">Optional cross-bucket cache of data/label buffers.</param>
/// <param name="input_types">Optional input dtypes; defaults every input to float32.</param>
/// <param name="logger">Unused; retained for interface compatibility (warnings go through Logger.Warning).</param>
/// <returns>The bound executor.</returns>
internal static Executor BindExec(Symbol sym, Context ctx, Dictionary<string, Shape> input_shapes, string[] param_names, bool need_grad = false, Executor base_exec = null, NDArrayDict shared_data_arrays = null, Dictionary<string, DType> input_types = null, Logger logger = null)
{
    var (arg_shape, _, aux_shape) = sym.InferShape(input_shapes);
    if (arg_shape == null)
    {
        throw new ArgumentNullException("arg_shape");
    }

    if (input_types == null)
    {
        // Default every declared input to float32 when no dtypes are given.
        input_types = new Dictionary<string, DType>();
        foreach (var item in input_shapes.Keys)
        {
            input_types.Add(item, DType.Float32);
        }
    }

    var (arg_types, _, aux_types) = sym.InferType(input_types);
    if (arg_types == null)
    {
        throw new ArgumentNullException("arg_types");
    }

    var arg_arrays = new NDArrayList();
    var aux_arrays = new NDArrayList();
    var grad_arrays = need_grad ? new NDArrayDict() : null;
    var arg_names = sym.ListArguments();

    // Gradients are needed for every argument that is not a data/label input.
    var needGradSet = new List<string>();
    if (need_grad)
    {
        foreach (var item in arg_names)
        {
            if (!input_shapes.ContainsKey(item))
            {
                needGradSet.Add(item);
            }
        }

        needGradSet = MxUtil.Set(needGradSet);
    }

    // Bug fix: the original only added Write entries, so the positional
    // grad_req list passed to Bind was shorter than arg_arrays and could
    // misalign; the MXNet reference fills a "null" request for the rest.
    var grad_req = new Dictionary<string, OpGradReq>();
    foreach (var item in arg_names)
    {
        grad_req.Add(item, needGradSet.Contains(item) ? OpGradReq.Write : OpGradReq.Null);
    }

    // Create or borrow arguments and their gradients.
    for (var i = 0; i < arg_names.Count; i++)
    {
        var name = arg_names[i];
        NDArray arg_arr = null;

        if (!param_names.Contains(name))
        {
            // Data or label input: reuse a shared buffer when one is big enough.
            if (shared_data_arrays != null && shared_data_arrays.Contains(name))
            {
                arg_arr = shared_data_arrays[name];
                if (np.prod(arg_arr.Shape.Data) >= np.prod(arg_shape[i].Data))
                {
                    // The shared buffer can hold this bucket; view it with the new shape.
                    if (arg_types[i].Name != arg_arr.DataType.Name)
                    {
                        throw new ArgumentException("arg_type and arg_arr datatype mismatch");
                    }

                    arg_arr = arg_arr.Reshape(arg_shape[i]);
                }
                else
                {
                    var logmsg = new StringBuilder();
                    logmsg.AppendFormat("bucketing: data \"{0}\" has a shape {1}", name, arg_shape[i]);
                    logmsg.AppendFormat(", which is larger than already allocated ");
                    logmsg.AppendFormat("shape {0}", arg_arr.Shape);
                    logmsg.AppendFormat(". Need to re-allocate. Consider putting default_bucket_key " + "to be the bucket taking the largest input for better memory sharing.");
                    Logger.Warning(logmsg.ToString());
                    arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                    shared_data_arrays[name] = arg_arr;
                }
            }
            else
            {
                arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                if (shared_data_arrays != null)
                {
                    shared_data_arrays[name] = arg_arr;
                }
            }

            arg_arrays.Add(arg_arr);
        }
        else
        {
            // Model parameter: allocate fresh, or borrow from base_exec.
            if (base_exec == null)
            {
                arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                if (needGradSet.Contains(name))
                {
                    var grad_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                    grad_arrays[name] = grad_arr;
                }
            }
            else
            {
                // Bug fix: in the original this branch was attached to the
                // needGradSet check inside the base_exec == null path, so it
                // dereferenced base_exec while it was known to be null, and
                // arg_arrays.Add was skipped entirely when base_exec != null.
                arg_arr = base_exec.ArgmentDictionary()[name];
                if (arg_arr.Shape != arg_shape[i])
                {
                    throw new ArgumentException("arg_arr.Shape != arg_shape[i]");
                }

                if (arg_arr.DataType != arg_types[i])
                {
                    throw new ArgumentException("arg_arr.DataType != arg_types[i]");
                }

                if (needGradSet.Contains(name))
                {
                    grad_arrays[name] = base_exec.GradientDictionary()[name];
                }
            }

            arg_arrays.Add(arg_arr);
        }
    }

    // Bug fix: the original inverted this condition — it allocated fresh aux
    // arrays when base_exec was available and dereferenced base_exec when it
    // was null. Allocate only when there is no executor to borrow from.
    if (base_exec == null)
    {
        for (var i = 0; i < aux_shape.Length; i++)
        {
            var s = aux_shape[i];
            var t = aux_types[i];
            aux_arrays.Add(nd.Zeros(s, ctx, t));
        }
    }
    else
    {
        foreach (var item in base_exec.AuxiliaryDictionary())
        {
            aux_arrays.Add(item.Value);
        }
    }

    // Bug fix: grad_arrays is null when need_grad is false; the original
    // unconditionally dereferenced it here.
    var executor = sym.Bind(ctx, arg_arrays, grad_arrays?.Values.ToList(), grad_req.Values.ToList(), aux_arrays);
    return executor;
}
public CachedOp(Symbol sym, NDArrayDict flags)
{
    // Create the native cached-op for this symbol, forwarding each flag as a
    // parallel key/handle pair.
    // NOTE(review): flag values are passed as NDArray handles — confirm this
    // matches the MXCreateCachedOpEx signature in NativeMethods.
    handle = IntPtr.Zero;
    var flagKeys = flags.Keys.ToArray();
    var flagValues = MxUtil.GetNDArrayHandles(flags.Values.ToArray());
    NativeMethods.MXCreateCachedOpEx(sym.GetHandle(), flags.Count, flagKeys, flagValues, out handle);
}