/// <summary>
/// Convenience overload: forwards every argument to the overload that also accepts a
/// per-argument gradient-request map, passing a freshly created empty map.
/// </summary>
public void InferExecutorArrays(Context context, NDArrayList argArrays, NDArrayList gradArrays, IList<OpGradReq> gradReqs, NDArrayList auxArrays, NDArrayDict argsMap, NDArrayDict argGradStore)
{
    // No explicit per-argument gradient requests are supplied by this overload.
    var noExplicitGradReqs = new Dictionary<string, OpGradReq>();

    InferExecutorArrays(
        context,
        argArrays,
        gradArrays,
        gradReqs,
        auxArrays,
        argsMap,
        argGradStore,
        noExplicitGradReqs);
}
/// <summary>
/// Invokes the native executor backward call with the handles of the given head gradients.
/// </summary>
/// <param name="headGrads">Head gradient arrays; may be empty but not <c>null</c>.</param>
/// <exception cref="ArgumentNullException"><paramref name="headGrads"/> is <c>null</c>.</exception>
public void Backward(NDArrayList headGrads)
{
    if (headGrads == null)
    {
        throw new ArgumentNullException(nameof(headGrads));
    }

    var gradHandles = headGrads.Select(grad => grad.GetHandle()).ToArray();

    // An empty gradient list is passed to the native layer as (0, null) rather than an empty array.
    if (gradHandles.Length == 0)
    {
        NativeMethods.MXExecutorBackward(Handle, 0, null);
        return;
    }

    NativeMethods.MXExecutorBackward(Handle, (uint)gradHandles.Length, gradHandles);
}
/// <summary>
/// Wraps an already-created native executor handle and collects its output arrays.
/// </summary>
/// <param name="h">Native executor handle; must not be <see cref="IntPtr.Zero"/>.</param>
/// <param name="context">Not read by this constructor.</param>
/// <param name="gradReqs">Not read by this constructor.</param>
/// <param name="groupToCtx">Not read by this constructor.</param>
/// <exception cref="ArgumentException"><paramref name="h"/> is <see cref="IntPtr.Zero"/>.</exception>
public Executor(ExecutorHandle h, Context context, List<OpGradReq> gradReqs, Dictionary<string, Context> groupToCtx)
{
    if (h == IntPtr.Zero)
    {
        throw new ArgumentException("Can not pass IntPtr.Zero", nameof(h));
    }

    Outputs = new NDArrayList();
    Handle = h;

    // Ask the native executor for its outputs and wrap each returned handle in an NDArray.
    var status = NativeMethods.MXExecutorOutputs(Handle, out var outSize, out var outArray);
    Logging.CHECK_EQ(status, 0);

    var outputHandles = InteropHelper.ToPointerArray(outArray, outSize);
    for (uint position = 0; position < outSize; ++position)
    {
        var output = new NDArray(outputHandles[position]);
        Outputs.Add(output);
    }
}
/// <summary>
/// Validates the arguments and constructs an <see cref="Executor"/> for this instance.
/// </summary>
/// <param name="context">Context passed to the executor; required.</param>
/// <param name="argArrays">Argument arrays; required.</param>
/// <param name="gradArrays">Gradient arrays; required.</param>
/// <param name="gradReqs">Gradient request kinds; required.</param>
/// <param name="auxArrays">Auxiliary state arrays; required.</param>
/// <param name="groupToCtx">Group-name to context map; required.</param>
/// <param name="sharedExec">Executor passed through to the constructor; not validated here, so <c>null</c> is accepted.</param>
/// <returns>The newly constructed executor.</returns>
/// <exception cref="ArgumentNullException">Any required argument is <c>null</c>.</exception>
public Executor Bind(Context context, NDArrayList argArrays, NDArrayList gradArrays, IList<OpGradReq> gradReqs, NDArrayList auxArrays, IDictionary<string, Context> groupToCtx, Executor sharedExec)
{
    // Guards are evaluated in parameter order so the first missing argument is the one reported.
    if (context == null)
    {
        throw new ArgumentNullException(nameof(context));
    }

    if (argArrays == null)
    {
        throw new ArgumentNullException(nameof(argArrays));
    }

    if (gradArrays == null)
    {
        throw new ArgumentNullException(nameof(gradArrays));
    }

    if (gradReqs == null)
    {
        throw new ArgumentNullException(nameof(gradReqs));
    }

    if (auxArrays == null)
    {
        throw new ArgumentNullException(nameof(auxArrays));
    }

    if (groupToCtx == null)
    {
        throw new ArgumentNullException(nameof(groupToCtx));
    }

    var bound = new Executor(this, context, argArrays, gradArrays, gradReqs, auxArrays, groupToCtx, sharedExec);
    return bound;
}
/// <summary>
/// Builds a name-to-array dictionary by pairing <paramref name="names"/> with
/// <paramref name="arrays"/> position by position.
/// </summary>
/// <param name="names">Names to use as keys; checked for duplicates.</param>
/// <param name="arrays">Arrays to use as values; its length is checked against the number of distinct names.</param>
/// <returns>A dictionary mapping <c>names[i]</c> to <c>arrays[i]</c>.</returns>
private static NDArrayDict GetDictionary(IList<string> names, NDArrayList arrays)
{
    var result = new NDArrayDict();

    // HashSet.Add reports whether the name was new, which is exactly the "no duplicate" condition.
    var seen = new HashSet<string>();
    foreach (var name in names)
    {
        Logging.CHECK(seen.Add(name), $"Duplicate names detected, {name}");
    }

    Logging.CHECK_EQ(seen.Count, arrays.Length, "names size not equal to arrays size");

    var position = 0;
    foreach (var name in names)
    {
        result[name] = arrays[position];
        position++;
    }

    return result;
}
/// <summary>
/// Calls the native autograd backward entry point for <paramref name="heads"/> with respect to
/// <paramref name="variables"/> and wraps each returned gradient handle in an <see cref="NDArray"/>.
/// </summary>
/// <param name="heads">Head arrays, converted to handles by <c>ParseHead</c>.</param>
/// <param name="variables">Arrays whose handles are passed as the variables of the native call.</param>
/// <param name="head_grads">Optional head gradients, converted to handles by <c>ParseHead</c>.</param>
/// <param name="retain_graph">Forwarded to the native call as 0/1.</param>
/// <param name="create_graph">Forwarded to the native call as 0/1.</param>
/// <param name="train_mode">Forwarded to the native call as 0/1.</param>
/// <returns>One array per gradient handle returned by the native call.</returns>
public static NDArrayList Grad(NDArrayList heads, NDArrayList variables, NDArrayList head_grads = null, bool retain_graph = false, bool create_graph = true, bool train_mode = true)
{
    var (head_handles, head_grads_handles) = ParseHead(heads, head_grads);

    // The native API takes its boolean switches as integers.
    NativeMethods.MXAutogradBackwardEx(
        head_handles.Length,
        head_handles,
        head_grads_handles,
        variables.Length,
        MxUtil.GetNDArrayHandles(variables),
        retain_graph ? 1 : 0,
        create_graph ? 1 : 0,
        train_mode ? 1 : 0,
        out var grad_handles,
        out var grad_stypes);

    var gradients = new NDArrayList();
    foreach (var gradHandle in grad_handles)
    {
        gradients.Add(new NDArray(gradHandle));
    }

    return gradients.ToArray();
}
/// <summary>
/// Updates <paramref name="metric"/> once per executor in <c>train_execs</c>, pairing each
/// executor with the corresponding entry of <c>slices</c>.
/// </summary>
/// <param name="metric">Metric whose <c>Update</c> is called with the labels and the executor outputs.</param>
/// <param name="labels">
/// Labels. When <paramref name="pre_sliced"/> is <c>false</c>, every label is sliced to the
/// executor's range; otherwise the label at the executor's index is used as-is.
/// </param>
/// <param name="pre_sliced">Whether <paramref name="labels"/> is already split per executor.</param>
/// <exception cref="ArgumentNullException"><paramref name="metric"/> or <paramref name="labels"/> is <c>null</c>.</exception>
public void UpdateMetric(EvalMetric metric, NDArrayList labels, bool pre_sliced = false)
{
    if (metric == null)
    {
        throw new ArgumentNullException(nameof(metric));
    }

    if (labels == null)
    {
        throw new ArgumentNullException(nameof(labels));
    }

    // Enumerable.Zip is lazy: the previous version built the sequence but never enumerated it,
    // so the metric was never updated. Iterating with foreach forces the work to happen.
    var execIndex = 0;
    foreach (var (exec, slice) in train_execs.Zip(slices, (e, s) => (e, s)))
    {
        // A fresh list per executor, so labels from earlier executors are not carried over
        // and compared against this executor's outputs.
        var labels_slice = new NDArrayList();
        if (!pre_sliced)
        {
            foreach (var label in labels)
            {
                labels_slice.Add(label.Slice(slice.Begin, slice.End.Value));
            }
        }
        else
        {
            labels_slice.Add(labels[execIndex]);
        }

        metric.Update(labels_slice.ToArray(), exec.Outputs.ToArray());
        execIndex++;
    }
}
/// <summary>
/// Abstract backward hook; the actual computation is supplied by derived classes.
/// </summary>
/// <param name="req">Gradient request kinds (presumably one per entry of <paramref name="in_grad"/> — confirm against implementers).</param>
/// <param name="out_grad">Presumably the gradients with respect to the outputs — confirm against implementers.</param>
/// <param name="in_data">Presumably the inputs of the forward pass — confirm against implementers.</param>
/// <param name="out_data">Presumably the outputs of the forward pass — confirm against implementers.</param>
/// <param name="in_grad">Presumably the destination for the input gradients — confirm against implementers.</param>
/// <param name="aux">Presumably auxiliary state arrays — confirm against implementers.</param>
public abstract void Backward(OpGradReq[] req, NDArrayList out_grad, NDArrayList in_data, NDArrayList out_data, NDArrayList in_grad, NDArrayList aux);
/// <summary>
/// Binds this symbol through the native simple-bind entry point and wraps the resulting handles
/// in an <see cref="Executor"/> together with its argument, gradient and auxiliary arrays.
/// </summary>
/// <param name="ctx">Default context; its device type and id are passed to the native call.</param>
/// <param name="grad_req">Optional per-argument gradient requests, passed as lower-cased enum names.</param>
/// <param name="type_dict">Optional per-argument data types, passed by <c>Index</c>.</param>
/// <param name="stype_dict">Optional per-argument storage types, passed as integers.</param>
/// <param name="group2ctx">Optional group-name to context map.</param>
/// <param name="shared_arg_names">Optional list of shared argument names.</param>
/// <param name="shared_exec">Optional executor whose handle is passed to the native call.</param>
/// <param name="shared_buffer">
/// Optional shared buffer. When non-empty it is passed to the native call and then updated in place
/// with the names/handles the call returns.
/// </param>
/// <param name="kwargs">Optional input descriptions supplying the provided argument names and shapes.</param>
/// <returns>The bound executor.</returns>
/// <exception cref="ArgumentNullException"><paramref name="ctx"/> is <c>null</c>.</exception>
public Executor SimpleBind(Context ctx, Dictionary<string, OpGradReq> grad_req = null, Dictionary<string, DType> type_dict = null, Dictionary<string, StorageStype> stype_dict = null, Dictionary<string, Context> group2ctx = null, string[] shared_arg_names = null, Executor shared_exec = null, NDArrayDict shared_buffer = null, DataDesc[] kwargs = null)
{
    if (ctx == null)
    {
        throw new ArgumentNullException(nameof(ctx));
    }

    // Provided argument data types.
    int num_provided_arg_types = 0;
    string[] provided_arg_type_names = null;
    int[] provided_arg_type_data = null;
    if (type_dict != null)
    {
        provided_arg_type_names = type_dict.Keys.ToArray();
        provided_arg_type_data = type_dict.Values.Select(x => x.Index).ToArray();
        num_provided_arg_types = type_dict.Count;
    }

    // Provided argument storage types.
    int num_provided_arg_stypes = 0;
    string[] provided_arg_stype_names = null;
    int[] provided_arg_stype_data = null;
    if (stype_dict != null)
    {
        provided_arg_stype_names = stype_dict.Keys.ToArray();
        provided_arg_stype_data = stype_dict.Values.Select(x => (int)x).ToArray();
        num_provided_arg_stypes = stype_dict.Count;
    }

    // Provided shapes are flattened into one data list plus an index list marking where each
    // shape ends (the index list starts with 0).
    List<int> provided_arg_shape_data = new List<int>();
    List<int> provided_arg_shape_idx = new List<int>() { 0 };
    List<string> provided_arg_shape_names = new List<string>();
    if (kwargs != null) // kwargs defaults to null; previously this loop threw NullReferenceException
    {
        foreach (var desc in kwargs)
        {
            provided_arg_shape_names.Add(desc.Name);
            provided_arg_shape_data.AddRange(desc.Shape.Data.Where(x => x > 0));
            provided_arg_shape_idx.Add(provided_arg_shape_data.Count);
        }
    }

    // Gradient requests are passed as lower-cased enum member names.
    int provided_req_type_list_len = 0;
    string[] provided_grad_req_names = new string[0];
    string[] provided_grad_req_types = new string[0];
    if (grad_req != null)
    {
        provided_grad_req_names = grad_req.Keys.ToArray();
        // ToLowerInvariant: the strings are machine-readable identifiers and must not depend on
        // the current culture's casing rules.
        provided_grad_req_types = grad_req.Values.Select(x => Enum.GetName(x.GetType(), x).ToLowerInvariant()).ToArray();
        provided_req_type_list_len = grad_req.Count;
    }

    // Group-to-context map, split into parallel arrays.
    int num_ctx_map_keys = 0;
    string[] ctx_map_keys = new string[0];
    int[] ctx_map_dev_types = new int[0];
    int[] ctx_map_dev_ids = new int[0];
    if (group2ctx != null)
    {
        ctx_map_keys = group2ctx.Keys.ToArray();
        ctx_map_dev_types = group2ctx.Values.Select(x => (int)x.GetDeviceType()).ToArray();
        ctx_map_dev_ids = group2ctx.Values.Select(x => x.GetDeviceId()).ToArray();
        num_ctx_map_keys = group2ctx.Count;
    }

    string[] shared_arg_name_list = new string[0];
    if (shared_arg_names != null)
    {
        shared_arg_name_list = shared_arg_names;
    }

    unsafe
    {
        // shared_buffer defaults to null; treat null the same as an empty buffer.
        bool has_shared_buffer = shared_buffer != null && shared_buffer.Count > 0;

        // A length of -1 is what gets passed when no shared buffer is supplied.
        int shared_start = -1;
        int shared_buffer_len = shared_start;
        string[] shared_buffer_names = null;
        IntPtr[] shared_buffer_handles = null;
        if (has_shared_buffer)
        {
            shared_buffer_len = shared_buffer.Count;
            shared_buffer_names = shared_buffer.Keys;
            shared_buffer_handles = shared_buffer.Values.Handles;
        }

        var shared_exec_handle = shared_exec != null ? shared_exec.Handle : new ExecutorHandle();

        char** updated_shared_buffer_names;
        SymbolHandle* updated_shared_buffer_handles;
        int num_in_args;
        SymbolHandle* in_arg_handles;
        SymbolHandle* arg_grad_handles;
        SymbolHandle* aux_state_handles;
        ExecutorHandle exe_handle;
        int num_aux_states;

        // NOTE(review): the result of this native call is not checked here — confirm whether it
        // returns a status code that should be validated like the other executor calls.
        NativeMethods.MXExecutorSimpleBindEx(
            NativePtr,
            (int)ctx.GetDeviceType(),
            ctx.GetDeviceId(),
            num_ctx_map_keys,
            ctx_map_keys,
            ctx_map_dev_types,
            ctx_map_dev_ids,
            provided_req_type_list_len,
            provided_grad_req_names,
            provided_grad_req_types,
            provided_arg_shape_names.Count,
            provided_arg_shape_names.ToArray(),
            provided_arg_shape_data.ToArray(),
            provided_arg_shape_idx.ToArray(),
            num_provided_arg_types,
            provided_arg_type_names,
            provided_arg_type_data,
            num_provided_arg_stypes,
            provided_arg_stype_names,
            provided_arg_stype_data,
            shared_arg_name_list.Length,
            shared_arg_name_list,
            &shared_buffer_len,
            shared_buffer_names,
            shared_buffer_handles,
            &updated_shared_buffer_names,
            &updated_shared_buffer_handles,
            &num_in_args,
            &in_arg_handles,
            &arg_grad_handles,
            &num_aux_states,
            &aux_state_handles,
            shared_exec_handle,
            &exe_handle);

        // Write the (possibly grown) shared buffer back into the caller's dictionary.
        if (has_shared_buffer)
        {
            int l = shared_buffer_len;
            for (int i = 0; i < l; i++)
            {
                // NOTE(review): new string(char*) reads UTF-16 code units — confirm the native
                // side really returns wide strings here.
                string k = new string(updated_shared_buffer_names[i]);
                NDArray v = new NDArray(updated_shared_buffer_handles[i]);
                shared_buffer[k] = v;
            }
        }

        NDArrayList arg_arrays = new NDArrayList();
        for (int i = 0; i < num_in_args; i++)
        {
            arg_arrays.Add(new NDArray(in_arg_handles[i]));
        }

        // Arguments without a gradient come back as a zero handle and are skipped, so
        // grad_arrays may be shorter than arg_arrays.
        NDArrayList grad_arrays = new NDArrayList();
        for (int i = 0; i < num_in_args; i++)
        {
            if (arg_grad_handles[i] == IntPtr.Zero)
            {
                continue;
            }

            grad_arrays.Add(new NDArray(arg_grad_handles[i]));
        }

        NDArrayList aux_arrays = new NDArrayList();
        for (int i = 0; i < num_aux_states; i++)
        {
            aux_arrays.Add(new NDArray(aux_state_handles[i]));
        }

        // grad_req defaults to null; previously grad_req.Values threw NullReferenceException.
        var grad_req_list = grad_req != null ? grad_req.Values.ToList() : new List<OpGradReq>();

        var executor = new Executor(exe_handle, ctx, grad_req_list, group2ctx);
        executor.ArgmentArrays = arg_arrays;
        executor.GradientArrays = grad_arrays;
        executor.AuxiliaryArrays = aux_arrays;
        executor._Symbol = this;
        return executor;
    }
}
/// <summary>
/// Fills the argument, gradient, gradient-request and auxiliary lists needed to bind this symbol.
/// Shapes are inferred from the arrays present in <paramref name="argsMap"/>; arrays that the
/// caller did not supply are allocated here.
/// </summary>
/// <param name="context">Required; only null-checked in this method.</param>
/// <param name="argArrays">Output list: one array per argument, appended in argument order.</param>
/// <param name="gradArrays">Output list: one gradient array per argument, appended in argument order.</param>
/// <param name="gradReqs">Output list: one gradient request per argument, appended in argument order.</param>
/// <param name="auxArrays">Output list: one array per auxiliary state.</param>
/// <param name="argsMap">Caller-supplied argument arrays by name; also the source of the known shapes.</param>
/// <param name="argGradStore">Caller-supplied gradient arrays by name.</param>
/// <param name="gradReqType">Explicit gradient requests by argument name; takes precedence over the defaults.</param>
/// <param name="auxMap">Caller-supplied auxiliary arrays by name.</param>
/// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
public void InferExecutorArrays(Context context, NDArrayList argArrays, NDArrayList gradArrays, IList<OpGradReq> gradReqs, NDArrayList auxArrays, NDArrayDict argsMap, NDArrayDict argGradStore, IDictionary<string, OpGradReq> gradReqType, NDArrayDict auxMap)
{
    if (context == null)
    {
        throw new ArgumentNullException(nameof(context));
    }

    if (argArrays == null)
    {
        throw new ArgumentNullException(nameof(argArrays));
    }

    if (gradArrays == null)
    {
        throw new ArgumentNullException(nameof(gradArrays));
    }

    if (gradReqs == null)
    {
        throw new ArgumentNullException(nameof(gradReqs));
    }

    if (auxArrays == null)
    {
        throw new ArgumentNullException(nameof(auxArrays));
    }

    if (argsMap == null)
    {
        throw new ArgumentNullException(nameof(argsMap));
    }

    if (argGradStore == null)
    {
        throw new ArgumentNullException(nameof(argGradStore));
    }

    if (gradReqType == null)
    {
        throw new ArgumentNullException(nameof(gradReqType));
    }

    if (auxMap == null)
    {
        throw new ArgumentNullException(nameof(auxMap));
    }

    ThrowIfDisposed();

    // Collect the shapes of every argument the caller supplied; they seed shape inference.
    var argNameList = ListArguments();
    var argShapes = new Dictionary<string, Shape>();
    foreach (var argName in argNameList)
    {
        if (argsMap[argName] != null)
        {
            argShapes[argName] = argsMap[argName].Shape;
        }
    }

    var (inShapes, auxShapes, _) = InferShape(argShapes);

    for (var i = 0; i < inShapes.Length; ++i)
    {
        var shape = inShapes[i];
        var argName = argNameList[i];

        // Argument array: reuse the caller's, otherwise allocate and fill with Uniform(0, 1).
        if (argsMap[argName] != null)
        {
            argArrays.Add(argsMap[argName]);
        }
        else
        {
            argArrays.Add(new NDArray(shape, false));
            var argArr = argArrays.Last();
            nd.Random.Uniform(0, 1, argArr.Shape).CopyTo(argArr);
        }

        // Gradient array: reuse the caller's, otherwise allocate an array of the same shape.
        if (argGradStore[argName] != null)
        {
            gradArrays.Add(argGradStore[argName]);
        }
        else
        {
            gradArrays.Add(new NDArray(shape, false));
        }

        // Gradient request: explicit entry wins; names ending in "data"/"label" get Null;
        // everything else gets Write.
        // EndsWith replaces the former `LastIndexOf(..) == Length - k` test, which also matched
        // when LastIndexOf returned -1 and the name was exactly k - 1 characters long
        // (e.g. a 4-character name such as "bias" was treated as a label).
        if (gradReqType.TryGetValue(argName, out var explicitReq))
        {
            gradReqs.Add(explicitReq);
        }
        else if (argName.EndsWith("data", StringComparison.Ordinal) ||
                 argName.EndsWith("label", StringComparison.Ordinal))
        {
            gradReqs.Add(OpGradReq.Null);
        }
        else
        {
            gradReqs.Add(OpGradReq.Write);
        }
    }

    // Auxiliary states: reuse the caller's, otherwise allocate and fill with Uniform(0, 1).
    var auxNameList = ListAuxiliaryStates();
    for (var i = 0; i < auxShapes.Length; ++i)
    {
        var shape = auxShapes[i];
        var auxName = auxNameList[i];
        if (auxMap[auxName] != null)
        {
            auxArrays.Add(auxMap[auxName]);
        }
        else
        {
            auxArrays.Add(new NDArray(shape, false));
            var aux = auxArrays.Last();
            nd.Random.Uniform(0, 1, aux.Shape).CopyTo(aux);
        }
    }
}
/// <summary>
/// Wraps every element of <paramref name="source"/> in an <see cref="NDArrayOrSymbol"/>.
/// </summary>
/// <param name="source">List whose elements are wrapped.</param>
/// <returns>An array with one wrapper per element, in the same order.</returns>
public static NDArrayOrSymbol[] ToNDArrayOrSymbols(this NDArrayList source)
{
    var wrapped = source.Select(item => new NDArrayOrSymbol(item));
    return wrapped.ToArray();
}
/// <summary>
/// Passes the batch's <c>Data</c> and <paramref name="targets"/> to <c>LoadGeneral</c>.
/// </summary>
/// <param name="batch">Batch whose <c>Data</c> is forwarded.</param>
/// <param name="targets">Forwarded unchanged as the second argument.</param>
internal static void LoadData(DataBatch batch, NDArrayList targets) => LoadGeneral(batch.Data, targets);
/// <summary>
/// Creates an instance holding the given arrays and flags it as an NDArray (not a symbol).
/// </summary>
/// <param name="x">Arrays stored in this instance.</param>
public NDArrayOrSymbol(params NDArray[] x)
{
    ndx = x;
    IsSymbol = false;
    IsNDArray = true;
}
/// <summary>
/// Delegates the metric update to <c>curr_execgrp</c>.
/// </summary>
/// <param name="eval_metric">Metric to update.</param>
/// <param name="labels">Labels forwarded unchanged.</param>
/// <param name="pre_sliced">Forwarded unchanged.</param>
public void UpdateMetric(EvalMetric eval_metric, NDArrayList labels, bool pre_sliced = false)
    => curr_execgrp.UpdateMetric(eval_metric, labels, pre_sliced);
/// <summary>
/// Binds <paramref name="symbol"/> through the native bind call and collects the resulting
/// executor's output arrays.
/// </summary>
/// <param name="symbol">Symbol to bind; required.</param>
/// <param name="context">Default context whose device type and id are passed to the native call; required.</param>
/// <param name="argmentArrays">Argument arrays; required.</param>
/// <param name="gradientArrays">Gradient arrays; required.</param>
/// <param name="gradReqs">Gradient requests, passed to the native call as unsigned integers; required.</param>
/// <param name="auxiliaryArrays">Auxiliary state arrays; required.</param>
/// <param name="groupToCtx">Group-name to context map; required.</param>
/// <param name="sharedExec">Optional executor whose handle is passed to the native call; <c>null</c> passes a zero handle.</param>
/// <exception cref="ArgumentNullException">Any required argument is <c>null</c>.</exception>
public Executor(Symbol symbol, Context context, NDArrayList argmentArrays, NDArrayList gradientArrays, IList<OpGradReq> gradReqs, NDArrayList auxiliaryArrays, IDictionary<string, Context> groupToCtx, Executor sharedExec)
{
    if (symbol == null)
    {
        throw new ArgumentNullException(nameof(symbol));
    }

    if (context == null)
    {
        throw new ArgumentNullException(nameof(context));
    }

    if (argmentArrays == null)
    {
        throw new ArgumentNullException(nameof(argmentArrays));
    }

    if (gradientArrays == null)
    {
        throw new ArgumentNullException(nameof(gradientArrays));
    }

    if (gradReqs == null)
    {
        throw new ArgumentNullException(nameof(gradReqs));
    }

    if (auxiliaryArrays == null)
    {
        throw new ArgumentNullException(nameof(auxiliaryArrays));
    }

    if (groupToCtx == null)
    {
        throw new ArgumentNullException(nameof(groupToCtx));
    }

    ArgmentArrays = argmentArrays;
    GradientArrays = gradientArrays;
    AuxiliaryArrays = auxiliaryArrays;
    _Symbol = symbol;

    // Flatten the managed collections into the parallel arrays the native call takes.
    var argHandles = argmentArrays.Select(array => array.GetHandle()).ToArray();
    var gradHandles = gradientArrays.Select(array => array.GetHandle()).ToArray();
    var auxHandles = auxiliaryArrays.Select(array => array.GetHandle()).ToArray();
    var gradReqsUint = gradReqs.Select(s => (uint)s).ToArray();

    var mapKeys = new string[groupToCtx.Count];
    var devTypes = new int[groupToCtx.Count];
    var devIds = new int[groupToCtx.Count];

    var slot = 0;
    foreach (var groupName in groupToCtx.Keys.ToArray())
    {
        mapKeys[slot] = groupName;

        var groupContext = groupToCtx[groupName];
        devTypes[slot] = (int)groupContext.GetDeviceType();
        devIds[slot] = groupContext.GetDeviceId();
        slot++;
    }

    var sharedExecHandle = sharedExec?.Handle ?? IntPtr.Zero;

    var bindStatus = NativeMethods.MXExecutorBindEX(
        symbol.GetHandle(),
        (int)context.GetDeviceType(),
        context.GetDeviceId(),
        (uint)groupToCtx.Count,
        mapKeys,
        devTypes,
        devIds,
        (uint)argHandles.Length,
        argHandles,
        gradHandles,
        gradReqsUint,
        (uint)auxHandles.Length,
        auxHandles,
        sharedExecHandle,
        out var handle);
    Logging.CHECK_EQ(bindStatus, NativeMethods.OK);

    Handle = handle;
    Outputs = new NDArrayList();

    // Ask the freshly bound executor for its outputs and wrap each returned handle.
    var outputsStatus = NativeMethods.MXExecutorOutputs(Handle, out var outSize, out var outArray);
    Logging.CHECK_EQ(outputsStatus, 0);

    var outputHandles = InteropHelper.ToPointerArray(outArray, outSize);
    for (uint position = 0; position < outSize; ++position)
    {
        Outputs.Add(new NDArray(outputHandles[position]));
    }
}
/// <summary>
/// Abstract forward hook; the actual computation is supplied by derived classes.
/// </summary>
/// <param name="is_train">Presumably indicates whether the call is part of training — confirm against implementers.</param>
/// <param name="req">Request kinds (presumably one per entry of <paramref name="out_data"/> — confirm against implementers).</param>
/// <param name="in_data">Presumably the input arrays — confirm against implementers.</param>
/// <param name="out_data">Presumably the output arrays to be written — confirm against implementers.</param>
/// <param name="aux">Presumably auxiliary state arrays — confirm against implementers.</param>
/// <returns>An <see cref="NDArray"/>; its meaning is defined by the implementing class.</returns>
public abstract NDArray Forward(bool is_train, OpGradReq[] req, NDArrayList in_data, NDArrayList out_data, NDArrayList aux);
/// <summary>
/// Passes the batch's <c>Label</c> and <paramref name="targets"/> to <c>LoadGeneral</c>.
/// </summary>
/// <param name="batch">Batch whose <c>Label</c> is forwarded.</param>
/// <param name="targets">Forwarded unchanged as the second argument.</param>
internal static void LoadLabel(DataBatch batch, NDArrayList targets) => LoadGeneral(batch.Label, targets);
/// <summary>
/// Allocates (or reuses) the argument, gradient and auxiliary arrays for <paramref name="sym"/>
/// and binds them into an <see cref="Executor"/>.
/// </summary>
/// <param name="sym">Symbol to bind.</param>
/// <param name="ctx">Context on which new arrays are allocated and the symbol is bound.</param>
/// <param name="input_shapes">Shapes of the input (non-parameter) arguments, by name.</param>
/// <param name="param_names">Names treated as model parameters; every other argument is treated as input data.</param>
/// <param name="need_grad">When <c>true</c>, every argument not listed in <paramref name="input_shapes"/> gets a gradient array.</param>
/// <param name="base_exec">Optional executor whose parameter, gradient and auxiliary arrays are reused.</param>
/// <param name="shared_data_arrays">Optional pool of input arrays, reused when large enough and updated with newly allocated arrays.</param>
/// <param name="input_types">Optional input data types; defaults to <c>DType.Float32</c> for every key of <paramref name="input_shapes"/>.</param>
/// <param name="logger">Not read by this method.</param>
/// <returns>The bound executor.</returns>
/// <exception cref="ArgumentNullException">Shape or type inference returned <c>null</c> argument information.</exception>
/// <exception cref="ArgumentException">A reused array does not match the inferred shape or data type.</exception>
internal static Executor BindExec(Symbol sym, Context ctx, Dictionary<string, Shape> input_shapes, string[] param_names, bool need_grad = false, Executor base_exec = null, NDArrayDict shared_data_arrays = null, Dictionary<string, DType> input_types = null, Logger logger = null)
{
    // NOTE(review): the third tuple element is taken as the auxiliary shapes/types here —
    // confirm this matches the element order InferShape/InferType actually return.
    var (arg_shape, _, aux_shape) = sym.InferShape(input_shapes);
    if (arg_shape == null)
    {
        throw new ArgumentNullException("arg_shape");
    }

    if (input_types == null)
    {
        input_types = new Dictionary<string, DType>();
        foreach (var item in input_shapes.Keys)
        {
            input_types.Add(item, DType.Float32);
        }
    }

    var (arg_types, _, aux_types) = sym.InferType(input_types);
    if (arg_types == null)
    {
        throw new ArgumentNullException("arg_types");
    }

    var arg_arrays = new NDArrayList();
    var aux_arrays = new NDArrayList();

    // Always a real dictionary (empty when need_grad is false) so the final
    // grad_arrays.Values access cannot hit a null reference.
    var grad_arrays = new NDArrayDict();

    var arg_names = sym.ListArguments();

    // Arguments that need a gradient: everything that is not an input, and only when requested.
    var needGradSet = new List<string>();
    if (need_grad)
    {
        foreach (var item in arg_names)
        {
            if (!input_shapes.ContainsKey(item))
            {
                needGradSet.Add(item);
            }
        }

        needGradSet = MxUtil.Set(needGradSet);
    }

    var grad_req = new Dictionary<string, OpGradReq>();
    foreach (var item in arg_names)
    {
        if (needGradSet.Contains(item))
        {
            grad_req.Add(item, OpGradReq.Write);
        }
    }

    for (var i = 0; i < arg_names.Count; i++)
    {
        var name = arg_names[i];
        NDArray arg_arr = null;
        NDArray grad_arr = null;

        if (!param_names.Contains(name))
        {
            // Input data: try the shared pool first.
            if (shared_data_arrays != null && shared_data_arrays.Contains(name))
            {
                arg_arr = shared_data_arrays[name];
                if (np.prod(arg_arr.Shape.Data) >= np.prod(arg_shape[i].Data))
                {
                    // The pooled array is big enough: reuse its storage under the new shape.
                    if (arg_types[i].Name != arg_arr.DataType.Name)
                    {
                        throw new ArgumentException("arg_type and arg_arr datatype mismatch");
                    }

                    arg_arr = arg_arr.Reshape(arg_shape[i]);
                }
                else
                {
                    var logmsg = new StringBuilder();
                    logmsg.AppendFormat("bucketing: data \"{0}\" has a shape {1}", name, arg_shape[i]);
                    logmsg.AppendFormat(", which is larger than already allocated ");
                    logmsg.AppendFormat("shape {0}", arg_arr.Shape);
                    logmsg.AppendFormat(". Need to re-allocate. Consider putting default_bucket_key " + "to be the bucket taking the largest input for better memory sharing.");
                    Logger.Warning(logmsg.ToString());

                    arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                    shared_data_arrays[name] = arg_arr;
                }
            }
            else
            {
                arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                if (shared_data_arrays != null)
                {
                    shared_data_arrays[name] = arg_arr;
                }
            }

            arg_arrays.Add(arg_arr);
        }
        else
        {
            // Model parameter.
            // The two branches below were previously mis-nested: the "reuse from base_exec" code
            // sat in the else of the needGradSet check inside the base_exec == null branch, so it
            // dereferenced a null base_exec and was never reached when base_exec was non-null.
            if (base_exec == null)
            {
                arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                if (needGradSet.Contains(name))
                {
                    grad_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                    grad_arrays[name] = grad_arr;
                }
            }
            else
            {
                arg_arr = base_exec.ArgmentDictionary()[name];
                if (arg_arr.Shape != arg_shape[i])
                {
                    throw new ArgumentException("arg_arr.Shape != arg_shape[i]");
                }

                if (arg_arr.DataType != arg_types[i])
                {
                    throw new ArgumentException("arg_arr.DataType != arg_types[i]");
                }

                if (needGradSet.Contains(name))
                {
                    grad_arrays[name] = base_exec.GradientDictionary()[name];
                }
            }

            arg_arrays.Add(arg_arr);
        }
    }

    // Auxiliary states: allocate fresh ones without a base executor, otherwise reuse the base
    // executor's. (The condition was previously inverted, dereferencing a null base_exec.)
    if (base_exec == null)
    {
        for (var i = 0; i < aux_shape.Length; i++)
        {
            var s = aux_shape[i];
            var t = aux_types[i];
            aux_arrays.Add(nd.Zeros(s, ctx, t));
        }
    }
    else
    {
        foreach (var item in base_exec.AuxiliaryDictionary())
        {
            aux_arrays.Add(item.Value);
        }
    }

    // NOTE(review): grad_arrays/grad_req only contain entries for arguments that need a
    // gradient, so they can be shorter than arg_arrays — confirm this Bind overload accepts that.
    var executor = sym.Bind(ctx, arg_arrays, grad_arrays.Values.ToList(), grad_req.Values.ToList(), aux_arrays);
    return executor;
}
/// <summary>
/// Collects the handle of every array in <paramref name="list"/>.
/// </summary>
/// <param name="list">Arrays whose handles are collected.</param>
/// <returns>One handle per array, in list order.</returns>
public static IntPtr[] GetNDArrayHandles(NDArrayList list)
{
    var handles = list.Select(array => array.GetHandle());
    return handles.ToArray();
}