/// <summary>
/// Invokes the imperative operator on the accumulated input arrays with the accumulated parameters.
/// </summary>
/// <param name="outputs">
/// When non-empty, the handles of these arrays are pinned and handed to the native call as the output
/// receiver. When empty, the output handles reported by the native call are copied out and wrapped
/// as new <c>NDArray</c>s appended to this list.
/// </param>
public void Invoke(NDArrayList outputs)
{
    var paramKeys = new List<string>();
    var paramValues = new List<string>();
    foreach (var data in _Params)
    {
        paramKeys.Add(data.Key);
        paramValues.Add(data.Value);
    }

    var numInputs = _InputNdarrays.Count;
    var numOutputs = outputs.Length;
    var outputHandles = outputs.Select(s => s.GetHandle()).ToArray();
    var outputsReceiver = IntPtr.Zero;
    GCHandle? gcHandle = null;
    try
    {
        if (outputs.Length > 0)
        {
            // Pin the handle array so the native side can write into it in place.
            gcHandle = GCHandle.Alloc(outputHandles, GCHandleType.Pinned);
            outputsReceiver = gcHandle.Value.AddrOfPinnedObject();
        }

        NativeMethods.MXImperativeInvoke(_Handle, numInputs, _InputNdarrays.ToArray(), ref numOutputs,
            ref outputsReceiver, paramKeys.Count, paramKeys.ToArray(), paramValues.ToArray());

        if (outputs.Length > 0)
        {
            // Caller-provided outputs: nothing further to collect.
            return;
        }

        // Native-allocated outputs: copy the returned handle array and wrap each handle.
        outputHandles = new NDArrayHandle[numOutputs];
        Marshal.Copy(outputsReceiver, outputHandles, 0, numOutputs);
        foreach (var outputHandle in outputHandles)
        {
            outputs.Add(new NDArray(outputHandle));
        }
    }
    finally
    {
        // Single release point. GCHandle? yields a copy on every access, so freeing it on more than
        // one path (as the early-return path used to) would free the same handle twice.
        gcHandle?.Free();
    }
}
/// <summary>
/// Runs the cached op on <paramref name="args"/> and returns its outputs, each converted to the
/// storage type reported for it by <c>MXInvokeCachedOpEx</c>.
/// </summary>
/// <param name="args">Input arrays, passed to the native call by handle.</param>
/// <returns>One array per native output, in native order.</returns>
public NDArrayList Call(NDArrayList args)
{
    var inputHandles = MxUtil.GetNDArrayHandles(args);
    NativeMethods.MXInvokeCachedOpEx(handle, args.Length, inputHandles,
        out var outputCount, out var outputHandles, out var outputStypes);

    var collected = new NDArrayList();
    for (var k = 0; k < outputCount; k++)
    {
        var storage = (StorageStype)outputStypes[k];
        var wrapped = new NDArray(outputHandles[k]);
        collected.Add(wrapped.ToSType(storage));
    }

    return collected.ToArray();
}
/// <summary>
/// Updates <paramref name="metric"/> once per executor, comparing that executor's outputs with
/// its share of <paramref name="labels"/>.
/// </summary>
/// <param name="metric">Metric to update; called once per (executor, slice) pair.</param>
/// <param name="labels">
/// When <paramref name="pre_sliced"/> is false: full-batch labels, each sliced to the executor's
/// batch range. When true: one label array per executor, matched by position.
/// </param>
/// <param name="pre_sliced">Whether <paramref name="labels"/> is already split per executor.</param>
public void UpdateMetric(EvalMetric metric, NDArrayList labels, bool pre_sliced = false)
{
    var i = 0;
    // Zip is lazy: it must be enumerated for the per-executor work to actually run.
    foreach (var (exec, slice) in train_execs.Zip(slices, (e, s) => (e, s)))
    {
        // Fresh list per executor so labels from earlier executors do not accumulate.
        var labels_slice = new NDArrayList();
        if (!pre_sliced)
        {
            foreach (var label in labels)
            {
                labels_slice.Add(label.Slice(slice.Begin, slice.End.Value));
            }
        }
        else
        {
            labels_slice.Add(labels[i]);
        }

        metric.Update(labels_slice.ToArray(), exec.Outputs.ToArray());
        i++;
    }
}
/// <summary>
/// Wraps an existing native executor handle and collects its output arrays.
/// </summary>
/// <param name="h">Native executor handle; must not be <see cref="IntPtr.Zero"/>.</param>
/// <param name="context">Not used by this constructor.</param>
/// <param name="gradReqs">Not used by this constructor.</param>
/// <param name="groupToCtx">Not used by this constructor.</param>
/// <exception cref="ArgumentException"><paramref name="h"/> is <see cref="IntPtr.Zero"/>.</exception>
public Executor(ExecutorHandle h, Context context, List<OpGradReq> gradReqs, Dictionary<string, Context> groupToCtx)
{
    if (h == IntPtr.Zero)
    {
        throw new ArgumentException("Can not pass IntPtr.Zero", nameof(h));
    }

    Outputs = new NDArrayList();
    Handle = h;

    var status = NativeMethods.MXExecutorOutputs(Handle, out var outputCount, out var rawOutputs);
    Logging.CHECK_EQ(status, 0);

    var outputPointers = InteropHelper.ToPointerArray(rawOutputs, outputCount);
    for (uint index = 0; index < outputCount; index++)
    {
        var wrapped = new NDArray(outputPointers[index]);
        Outputs.Add(wrapped);
    }
}
/// <summary>
/// Computes gradients of <paramref name="heads"/> with respect to <paramref name="variables"/>
/// through <c>MXAutogradBackwardEx</c> and wraps each returned handle as an <c>NDArray</c>.
/// </summary>
/// <param name="heads">Output arrays to differentiate.</param>
/// <param name="variables">Arrays to take gradients with respect to.</param>
/// <param name="head_grads">Optional head gradients, forwarded to <c>ParseHead</c>.</param>
/// <param name="retain_graph">Passed to the native call as 1/0.</param>
/// <param name="create_graph">Passed to the native call as 1/0.</param>
/// <param name="train_mode">Passed to the native call as 1/0.</param>
/// <returns>One array per handle returned by the native call.</returns>
public static NDArrayList Grad(NDArrayList heads, NDArrayList variables, NDArrayList head_grads = null,
    bool retain_graph = false, bool create_graph = true, bool train_mode = true)
{
    var (headHandles, headGradHandles) = ParseHead(heads, head_grads);
    var variableHandles = MxUtil.GetNDArrayHandles(variables);

    // NOTE(review): the returned storage types (gradStypes) are not applied, and the native status
    // code is not checked — confirm this is intended.
    NativeMethods.MXAutogradBackwardEx(headHandles.Length, headHandles, headGradHandles,
        variables.Length, variableHandles,
        retain_graph ? 1 : 0, create_graph ? 1 : 0, train_mode ? 1 : 0,
        out var gradHandles, out var gradStypes);

    var gradients = new NDArrayList();
    foreach (var gradHandle in gradHandles)
    {
        gradients.Add(new NDArray(gradHandle));
    }

    return gradients.ToArray();
}
/// <summary>
/// Binds this symbol to a new executor through <c>MXExecutorSimpleBindEx</c>. The runtime allocates
/// the argument, gradient and auxiliary arrays from the supplied shape, dtype and storage hints.
/// </summary>
/// <param name="ctx">Device context of the executor.</param>
/// <param name="grad_req">Optional per-argument gradient request (sent lower-cased by enum name); may be null.</param>
/// <param name="type_dict">Optional per-argument dtype hints.</param>
/// <param name="stype_dict">Optional per-argument storage-type hints.</param>
/// <param name="group2ctx">Optional context-group to device mapping.</param>
/// <param name="shared_arg_names">Optional argument names passed as the shared-argument list.</param>
/// <param name="shared_exec">Optional executor whose handle is passed for memory sharing.</param>
/// <param name="shared_buffer">
/// Optional named-array buffer. When non-null (including empty) it is passed to the runtime, and the
/// entries the runtime reports back are written into it.
/// </param>
/// <param name="kwargs">Optional argument shape descriptors; may be null.</param>
/// <returns>The bound executor with its argument, gradient and auxiliary arrays populated.</returns>
public Executor SimpleBind(Context ctx, Dictionary<string, OpGradReq> grad_req = null, Dictionary<string, DType> type_dict = null,
    Dictionary<string, StorageStype> stype_dict = null, Dictionary<string, Context> group2ctx = null,
    string[] shared_arg_names = null, Executor shared_exec = null, NDArrayDict shared_buffer = null,
    DataDesc[] kwargs = null)
{
    // dtype hints
    int num_provided_arg_types = 0;
    string[] provided_arg_type_names = null;
    int[] provided_arg_type_data = null;
    if (type_dict != null)
    {
        provided_arg_type_names = type_dict.Keys.ToArray();
        provided_arg_type_data = type_dict.Values.Select(x => x.Index).ToArray();
        num_provided_arg_types = type_dict.Count;
    }

    // storage-type hints
    int num_provided_arg_stypes = 0;
    string[] provided_arg_stype_names = null;
    int[] provided_arg_stype_data = null;
    if (stype_dict != null)
    {
        provided_arg_stype_names = stype_dict.Keys.ToArray();
        provided_arg_stype_data = stype_dict.Values.Select(x => (int)x).ToArray();
        num_provided_arg_stypes = stype_dict.Count;
    }

    // Shapes are flattened into one data list plus offsets: argument k owns data[idx[k]..idx[k+1]).
    var provided_arg_shape_data = new List<int>();
    var provided_arg_shape_idx = new List<int> { 0 };
    var provided_arg_shape_names = new List<string>();
    if (kwargs != null)
    {
        foreach (var desc in kwargs)
        {
            provided_arg_shape_names.Add(desc.Name);
            // Non-positive entries are dropped (presumably unused/unknown dims — confirm against Shape).
            provided_arg_shape_data.AddRange(desc.Shape.Data.Where(x => x > 0));
            provided_arg_shape_idx.Add(provided_arg_shape_data.Count);
        }
    }

    int provided_req_type_list_len = 0;
    string[] provided_grad_req_names = new string[0];
    string[] provided_grad_req_types = new string[0];
    if (grad_req != null)
    {
        provided_grad_req_names = grad_req.Keys.ToArray();
        provided_grad_req_types = grad_req.Values.Select(x => Enum.GetName(x.GetType(), x).ToLowerInvariant()).ToArray();
        provided_req_type_list_len = grad_req.Count;
    }

    int num_ctx_map_keys = 0;
    string[] ctx_map_keys = new string[0];
    int[] ctx_map_dev_types = new int[0];
    int[] ctx_map_dev_ids = new int[0];
    if (group2ctx != null)
    {
        ctx_map_keys = group2ctx.Keys.ToArray();
        ctx_map_dev_types = group2ctx.Values.Select(x => (int)x.GetDeviceType()).ToArray();
        ctx_map_dev_ids = group2ctx.Values.Select(x => x.GetDeviceId()).ToArray();
        num_ctx_map_keys = group2ctx.Count;
    }

    string[] shared_arg_name_list = shared_arg_names ?? new string[0];

    unsafe
    {
        // -1 is passed when no shared buffer is supplied. A non-null buffer is always passed, even
        // when empty, so the runtime can populate it.
        int shared_buffer_len = -1;
        string[] shared_buffer_names = null;
        IntPtr[] shared_buffer_handles = null;
        if (shared_buffer != null)
        {
            shared_buffer_len = shared_buffer.Count;
            shared_buffer_names = shared_buffer.Keys;
            shared_buffer_handles = shared_buffer.Values.Handles;
        }

        var shared_exec_handle = shared_exec != null ? shared_exec.Handle : new ExecutorHandle();
        char** updated_shared_buffer_names;
        SymbolHandle* updated_shared_buffer_handles;
        int num_in_args;
        SymbolHandle* in_arg_handles;
        SymbolHandle* arg_grad_handles;
        SymbolHandle* aux_state_handles;
        ExecutorHandle exe_handle;
        int num_aux_states;
        NativeMethods.MXExecutorSimpleBindEx(NativePtr, (int)ctx.GetDeviceType(), ctx.GetDeviceId(),
            num_ctx_map_keys, ctx_map_keys, ctx_map_dev_types, ctx_map_dev_ids,
            provided_req_type_list_len, provided_grad_req_names, provided_grad_req_types,
            provided_arg_shape_names.Count, provided_arg_shape_names.ToArray(), provided_arg_shape_data.ToArray(), provided_arg_shape_idx.ToArray(),
            num_provided_arg_types, provided_arg_type_names, provided_arg_type_data,
            num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stype_data,
            shared_arg_name_list.Length, shared_arg_name_list,
            &shared_buffer_len, shared_buffer_names, shared_buffer_handles,
            &updated_shared_buffer_names, &updated_shared_buffer_handles,
            &num_in_args, &in_arg_handles, &arg_grad_handles,
            &num_aux_states, &aux_state_handles,
            shared_exec_handle, &exe_handle);

        if (shared_buffer != null)
        {
            for (int i = 0; i < shared_buffer_len; i++)
            {
                // MXNet's C API returns these names as 8-bit `const char*`. new string(char*) would
                // misread them as UTF-16.
                string k = System.Runtime.InteropServices.Marshal.PtrToStringAnsi((IntPtr)updated_shared_buffer_names[i]);
                shared_buffer[k] = new NDArray(updated_shared_buffer_handles[i]);
            }
        }

        var arg_arrays = new NDArrayList();
        for (int i = 0; i < num_in_args; i++)
        {
            arg_arrays.Add(new NDArray(in_arg_handles[i]));
        }

        // NOTE(review): null gradient handles are skipped, so GradientArrays is not index-aligned with
        // ArgmentArrays when some arguments have no gradient — confirm callers expect this.
        var grad_arrays = new NDArrayList();
        for (int i = 0; i < num_in_args; i++)
        {
            if (arg_grad_handles[i] == IntPtr.Zero)
            {
                continue;
            }

            grad_arrays.Add(new NDArray(arg_grad_handles[i]));
        }

        var aux_arrays = new NDArrayList();
        for (int i = 0; i < num_aux_states; i++)
        {
            aux_arrays.Add(new NDArray(aux_state_handles[i]));
        }

        var grad_req_list = grad_req != null ? grad_req.Values.ToList() : new List<OpGradReq>();
        var executor = new Executor(exe_handle, ctx, grad_req_list, group2ctx);
        executor.ArgmentArrays = arg_arrays;
        executor.GradientArrays = grad_arrays;
        executor.AuxiliaryArrays = aux_arrays;
        executor._Symbol = this;
        return executor;
    }
}
/// <summary>
/// Fills the output lists with the argument, gradient and auxiliary arrays needed to bind this symbol.
/// Arrays already present in the supplied maps are reused; the rest are allocated from the inferred shapes.
/// </summary>
/// <remarks>
/// Missing argument and auxiliary arrays are created with <c>new NDArray(shape, false)</c> and filled
/// from <c>nd.Random.Uniform(0, 1, ...)</c>. Missing gradient arrays are created but not filled here.
/// Arguments without an entry in <paramref name="gradReqType"/> get <see cref="OpGradReq.Null"/> when
/// their name ends with "data" or "label", otherwise <see cref="OpGradReq.Write"/>.
/// </remarks>
/// <param name="context">Device context; must not be null.</param>
/// <param name="argArrays">Receives one array per argument.</param>
/// <param name="gradArrays">Receives one gradient array per argument.</param>
/// <param name="gradReqs">Receives one gradient request per argument.</param>
/// <param name="auxArrays">Receives one array per auxiliary state.</param>
/// <param name="argsMap">Known argument arrays by name; also the source of shape hints.</param>
/// <param name="argGradStore">Known gradient arrays by name.</param>
/// <param name="gradReqType">Explicit gradient requests by argument name.</param>
/// <param name="auxMap">Known auxiliary arrays by name.</param>
/// <exception cref="ArgumentNullException">Any argument is null.</exception>
public void InferExecutorArrays(Context context, NDArrayList argArrays, NDArrayList gradArrays, IList<OpGradReq> gradReqs,
    NDArrayList auxArrays, NDArrayDict argsMap, NDArrayDict argGradStore, IDictionary<string, OpGradReq> gradReqType,
    NDArrayDict auxMap)
{
    if (context == null) { throw new ArgumentNullException(nameof(context)); }
    if (argArrays == null) { throw new ArgumentNullException(nameof(argArrays)); }
    if (gradArrays == null) { throw new ArgumentNullException(nameof(gradArrays)); }
    if (gradReqs == null) { throw new ArgumentNullException(nameof(gradReqs)); }
    if (auxArrays == null) { throw new ArgumentNullException(nameof(auxArrays)); }
    if (argsMap == null) { throw new ArgumentNullException(nameof(argsMap)); }
    if (argGradStore == null) { throw new ArgumentNullException(nameof(argGradStore)); }
    if (gradReqType == null) { throw new ArgumentNullException(nameof(gradReqType)); }
    if (auxMap == null) { throw new ArgumentNullException(nameof(auxMap)); }

    ThrowIfDisposed();

    var argNameList = ListArguments();
    var argShapes = new Dictionary<string, Shape>();
    foreach (var argName in argNameList)
    {
        var known = argsMap[argName];
        if (known != null)
        {
            argShapes[argName] = known.Shape;
        }
    }

    // NOTE(review): this deconstructs InferShape as (in, aux, out); other callers deconstruct the same
    // method as (arg, out, aux). Confirm the tuple order before relying on auxShapes.
    var (inShapes, auxShapes, outShapes) = InferShape(argShapes);
    for (var i = 0; i < inShapes.Length; ++i)
    {
        var shape = inShapes[i];
        var argName = argNameList[i];

        var argArr = argsMap[argName];
        if (argArr != null)
        {
            argArrays.Add(argArr);
        }
        else
        {
            argArr = new NDArray(shape, false);
            argArrays.Add(argArr);
            nd.Random.Uniform(0, 1, argArr.Shape).CopyTo(argArr);
        }

        var gradArr = argGradStore[argName];
        if (gradArr != null)
        {
            gradArrays.Add(gradArr);
        }
        else
        {
            gradArrays.Add(new NDArray(shape, false));
        }

        if (gradReqType.TryGetValue(argName, out var req))
        {
            gradReqs.Add(req);
        }
        else if (argName.EndsWith("data", StringComparison.Ordinal) || argName.EndsWith("label", StringComparison.Ordinal))
        {
            // A true suffix test. The former LastIndexOf(x) == Length - len(x) check also matched
            // short names without the suffix (-1 == -1), e.g. "beta".
            gradReqs.Add(OpGradReq.Null);
        }
        else
        {
            gradReqs.Add(OpGradReq.Write);
        }
    }

    var auxNameList = ListAuxiliaryStates();
    for (var i = 0; i < auxShapes.Length; ++i)
    {
        var shape = auxShapes[i];
        var auxName = auxNameList[i];

        var aux = auxMap[auxName];
        if (aux != null)
        {
            auxArrays.Add(aux);
        }
        else
        {
            aux = new NDArray(shape, false);
            auxArrays.Add(aux);
            nd.Random.Uniform(0, 1, aux.Shape).CopyTo(aux);
        }
    }
}
/// <summary>
/// Binds one training executor per context, each sized to its slice of the batch. Then collects, per
/// name or index, the data, label, parameter, gradient and auxiliary arrays of every executor.
/// </summary>
/// <param name="sym">Symbol to bind.</param>
/// <param name="arg_names">All argument names of <paramref name="sym"/>, in argument order.</param>
/// <param name="param_names">Names among <paramref name="arg_names"/> that are parameters.</param>
/// <param name="ctxlist">One context per executor.</param>
/// <param name="slices">Batch slice for each context; its length sets that executor's leading dimension.</param>
/// <param name="train_data">Source of the data/label descriptors (names, shapes, dtypes).</param>
/// <param name="shared_group">Optional group whose shared data arrays and executors are reused.</param>
public DataParallelExecutorGroup(Symbol sym, string[] arg_names, string[] param_names, Context[] ctxlist, Slice[] slices, DataIter train_data, DataParallelExecutorGroup shared_group = null)
{
    ExecuterManager.CheckArguments(sym);

    if (shared_group == null)
    {
        for (var i = 0; i < ctxlist.Length; i++)
        {
            shared_data_arrays.Add(new NDArrayDict());
        }
    }
    else
    {
        shared_data_arrays = shared_group.shared_data_arrays;
    }

    foreach (var item in train_data.ProvideData)
    {
        data_names.Add(item.Name);
    }

    foreach (var item in train_data.ProvideLabel)
    {
        label_names.Add(item.Name);
    }

    aux_names = sym.ListAuxiliaryStates().ToList();

    for (var i = 0; i < arg_names.Length; i++)
    {
        if (param_names.Contains(arg_names[i]))
        {
            param_idx.Add(i);
            this.param_names.Add(arg_names[i]);
        }
    }

    for (var i = 0; i < ctxlist.Length; i++)
    {
        var batch = slices[i].End.Value - slices[i].Begin;
        var data_shapes = new Dictionary<string, Shape>();
        var data_types = new Dictionary<string, DType>();
        foreach (var item in train_data.ProvideData)
        {
            data_shapes[item.Name] = WithLeadingDim(item.Shape, batch);
            data_types[item.Name] = item.DataType;
        }

        foreach (var item in train_data.ProvideLabel)
        {
            data_shapes[item.Name] = WithLeadingDim(item.Shape, batch);
            data_types[item.Name] = item.DataType;
        }

        var shared_exec = shared_group == null ? null : shared_group.train_execs[i];
        var train_exec = ExecuterManager.BindExec(sym, ctxlist[i], data_shapes, param_names, true, shared_exec, shared_data_arrays[i], data_types);
        train_execs.Add(train_exec);
    }

    foreach (var name in data_names)
    {
        for (var i = 0; i < train_execs.Count; i++)
        {
            data_arrays.Add(train_execs[i].ArgmentDictionary()[name]);
        }
    }

    foreach (var name in label_names)
    {
        for (var i = 0; i < train_execs.Count; i++)
        {
            label_arrays.Add(train_execs[i].ArgmentDictionary()[name]);
        }
    }

    foreach (var idx in param_idx)
    {
        for (var i = 0; i < train_execs.Count; i++)
        {
            param_arrays.Add(train_execs[i].ArgmentArrays[idx]);
        }
    }

    foreach (var idx in param_idx)
    {
        for (var i = 0; i < train_execs.Count; i++)
        {
            grad_arrays.Add(train_execs[i].GradientArrays[idx]);
        }
    }

    for (var idx = 0; idx < aux_names.Count; idx++)
    {
        for (var i = 0; i < train_execs.Count; i++)
        {
            // Index by aux-state position (idx), not by executor position (i).
            aux_arrays.Add(train_execs[i].AuxiliaryArrays[idx]);
        }
    }

    this.slices = slices;

    // Copy of `shape` with its first (batch) dimension replaced by `leading`.
    Shape WithLeadingDim(Shape shape, int leading)
    {
        var dims = shape.Data.ToList();
        dims[0] = leading;
        return new Shape(dims);
    }
}
/// <summary>
/// Binds <paramref name="symbol"/> into a new native executor (<c>MXExecutorBindEX</c>) using the
/// supplied argument, gradient and auxiliary arrays, then wraps the executor's outputs.
/// </summary>
/// <param name="symbol">Symbol to bind.</param>
/// <param name="context">Device context of the executor.</param>
/// <param name="argmentArrays">Argument arrays, passed by handle.</param>
/// <param name="gradientArrays">Gradient arrays, passed by handle.</param>
/// <param name="gradReqs">Gradient requests, passed as their unsigned integer values.</param>
/// <param name="auxiliaryArrays">Auxiliary-state arrays, passed by handle.</param>
/// <param name="groupToCtx">Context-group to device mapping (may be empty).</param>
/// <param name="sharedExec">Optional executor to share with; <see cref="IntPtr.Zero"/> is passed when null.</param>
/// <exception cref="ArgumentNullException">Any argument other than <paramref name="sharedExec"/> is null.</exception>
public Executor(Symbol symbol, Context context, NDArrayList argmentArrays, NDArrayList gradientArrays, IList<OpGradReq> gradReqs, NDArrayList auxiliaryArrays, IDictionary<string, Context> groupToCtx, Executor sharedExec)
{
    if (symbol == null)
    {
        throw new ArgumentNullException(nameof(symbol));
    }
    if (context == null)
    {
        throw new ArgumentNullException(nameof(context));
    }
    if (argmentArrays == null)
    {
        throw new ArgumentNullException(nameof(argmentArrays));
    }
    if (gradientArrays == null)
    {
        throw new ArgumentNullException(nameof(gradientArrays));
    }
    if (gradReqs == null)
    {
        throw new ArgumentNullException(nameof(gradReqs));
    }
    if (auxiliaryArrays == null)
    {
        throw new ArgumentNullException(nameof(auxiliaryArrays));
    }
    if (groupToCtx == null)
    {
        throw new ArgumentNullException(nameof(groupToCtx));
    }

    ArgmentArrays = argmentArrays;
    GradientArrays = gradientArrays;
    AuxiliaryArrays = auxiliaryArrays;
    _Symbol = symbol;

    IntPtr[] HandlesOf(NDArrayList list) => list.Select(array => array.GetHandle()).ToArray();
    var argHandles = HandlesOf(argmentArrays);
    var gradHandles = HandlesOf(gradientArrays);
    var auxHandles = HandlesOf(auxiliaryArrays);
    var reqCodes = gradReqs.Select(req => (uint)req).ToArray();

    // Parallel arrays describing the group -> device mapping, all in Keys order.
    var groupNames = groupToCtx.Keys.ToArray();
    var groupDevTypes = groupNames.Select(name => (int)groupToCtx[name].GetDeviceType()).ToArray();
    var groupDevIds = groupNames.Select(name => groupToCtx[name].GetDeviceId()).ToArray();

    var sharedHandle = sharedExec?.Handle ?? IntPtr.Zero;
    var bindStatus = NativeMethods.MXExecutorBindEX(symbol.GetHandle(), (int)context.GetDeviceType(), context.GetDeviceId(),
        (uint)groupToCtx.Count, groupNames, groupDevTypes, groupDevIds,
        (uint)argHandles.Length, argHandles, gradHandles, reqCodes,
        (uint)auxHandles.Length, auxHandles, sharedHandle, out var boundHandle);
    Logging.CHECK_EQ(bindStatus, NativeMethods.OK);
    Handle = boundHandle;

    Outputs = new NDArrayList();
    var outputsStatus = NativeMethods.MXExecutorOutputs(Handle, out var outputCount, out var rawOutputs);
    Logging.CHECK_EQ(outputsStatus, 0);
    var outputPointers = InteropHelper.ToPointerArray(rawOutputs, outputCount);
    for (uint n = 0; n < outputCount; n++)
    {
        Outputs.Add(new NDArray(outputPointers[n]));
    }
}
/// <summary>
/// Binds <paramref name="sym"/> on <paramref name="ctx"/> for the given input shapes. Argument,
/// gradient and auxiliary arrays are allocated, reused from <paramref name="base_exec"/>, or reused
/// from <paramref name="shared_data_arrays"/>.
/// </summary>
/// <param name="sym">Symbol to bind.</param>
/// <param name="ctx">Device context.</param>
/// <param name="input_shapes">Shapes of the bound inputs, keyed by name.</param>
/// <param name="param_names">Argument names treated as parameters.</param>
/// <param name="need_grad">When true, every argument not in <paramref name="input_shapes"/> gets a gradient array and a Write request.</param>
/// <param name="base_exec">Optional executor whose parameter, gradient and auxiliary arrays are reused.</param>
/// <param name="shared_data_arrays">Optional pool for non-parameter arrays: reused (reshaped) when large enough, updated with new allocations.</param>
/// <param name="input_types">Optional input dtypes; defaults to Float32 for every key of <paramref name="input_shapes"/>.</param>
/// <param name="logger">Not used by this method (warnings go through the static <c>Logger</c>).</param>
/// <returns>The executor returned by <c>sym.Bind</c>.</returns>
/// <exception cref="ArgumentNullException">Shape or type inference returned null.</exception>
/// <exception cref="ArgumentException">A reused array's shape or dtype does not match the inferred one.</exception>
internal static Executor BindExec(Symbol sym, Context ctx, Dictionary<string, Shape> input_shapes, string[] param_names, bool need_grad = false, Executor base_exec = null, NDArrayDict shared_data_arrays = null, Dictionary<string, DType> input_types = null, Logger logger = null)
{
    var (arg_shape, _, aux_shape) = sym.InferShape(input_shapes);
    if (arg_shape == null)
    {
        throw new ArgumentNullException("arg_shape");
    }

    if (input_types == null)
    {
        input_types = new Dictionary<string, DType>();
        foreach (var item in input_shapes.Keys)
        {
            input_types.Add(item, DType.Float32);
        }
    }

    var (arg_types, _, aux_types) = sym.InferType(input_types);
    if (arg_types == null)
    {
        throw new ArgumentNullException("arg_types");
    }

    var arg_arrays = new NDArrayList();
    var aux_arrays = new NDArrayList();
    // Always a dictionary (empty when need_grad is false) so the Bind call below never dereferences null.
    var grad_arrays = new NDArrayDict();
    var arg_names = sym.ListArguments();

    var needGradSet = new List<string>();
    if (need_grad)
    {
        foreach (var item in arg_names)
        {
            if (!input_shapes.ContainsKey(item))
            {
                needGradSet.Add(item);
            }
        }

        needGradSet = MxUtil.Set(needGradSet);
    }

    var grad_req = new Dictionary<string, OpGradReq>();
    foreach (var item in arg_names)
    {
        if (needGradSet.Contains(item))
        {
            grad_req.Add(item, OpGradReq.Write);
        }
    }

    for (var i = 0; i < arg_names.Count; i++)
    {
        var name = arg_names[i];
        NDArray arg_arr;
        if (!param_names.Contains(name))
        {
            // Non-parameter input: reuse the shared array when it is large enough, else (re)allocate.
            if (shared_data_arrays != null && shared_data_arrays.Contains(name))
            {
                arg_arr = shared_data_arrays[name];
                if (np.prod(arg_arr.Shape.Data) >= np.prod(arg_shape[i].Data))
                {
                    if (arg_types[i].Name != arg_arr.DataType.Name)
                    {
                        throw new ArgumentException("arg_type and arg_arr datatype mismatch");
                    }

                    arg_arr = arg_arr.Reshape(arg_shape[i]);
                }
                else
                {
                    var logmsg = new StringBuilder();
                    logmsg.AppendFormat("bucketing: data \"{0}\" has a shape {1}", name, arg_shape[i]);
                    logmsg.AppendFormat(", which is larger than already allocated ");
                    logmsg.AppendFormat("shape {0}", arg_arr.Shape);
                    logmsg.AppendFormat(". Need to re-allocate. Consider putting default_bucket_key " + "to be the bucket taking the largest input for better memory sharing.");
                    Logger.Warning(logmsg.ToString());
                    arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                    shared_data_arrays[name] = arg_arr;
                }
            }
            else
            {
                arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                if (shared_data_arrays != null)
                {
                    shared_data_arrays[name] = arg_arr;
                }
            }
        }
        else if (base_exec == null)
        {
            // Parameter with no base executor: allocate it (and its gradient, if required).
            arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
            if (needGradSet.Contains(name))
            {
                grad_arrays[name] = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
            }
        }
        else
        {
            // Parameter shared with the base executor: reuse its argument and gradient arrays.
            arg_arr = base_exec.ArgmentDictionary()[name];
            if (arg_arr.Shape != arg_shape[i])
            {
                throw new ArgumentException("arg_arr.Shape != arg_shape[i]");
            }

            if (arg_arr.DataType.Name != arg_types[i].Name)
            {
                throw new ArgumentException("arg_arr.DataType != arg_types[i]");
            }

            if (needGradSet.Contains(name))
            {
                grad_arrays[name] = base_exec.GradientDictionary()[name];
            }
        }

        // Every argument contributes exactly one array, in argument order.
        arg_arrays.Add(arg_arr);
    }

    if (base_exec == null)
    {
        for (var i = 0; i < aux_shape.Length; i++)
        {
            aux_arrays.Add(nd.Zeros(aux_shape[i], ctx, aux_types[i]));
        }
    }
    else
    {
        // Reuse the base executor's auxiliary states in its positional order.
        foreach (var item in base_exec.AuxiliaryArrays)
        {
            aux_arrays.Add(item);
        }
    }

    // NOTE(review): grad_arrays / grad_req only hold entries for arguments in needGradSet, yet they are
    // passed positionally next to arg_arrays (one entry per argument). Unless Symbol.Bind realigns them,
    // they line up only when every argument needs a gradient — confirm against Symbol.Bind.
    var executor = sym.Bind(ctx, arg_arrays, grad_arrays.Values.ToList(), grad_req.Values.ToList(), aux_arrays);
    return executor;
}