/// <summary>
/// Loads a checkpoint: the symbol from "<paramref name="prefix"/>-symbol.json" and the parameters
/// from "<paramref name="prefix"/>-NNNN.params", where NNNN is <paramref name="epoch"/> zero-padded to four digits.
/// </summary>
/// <param name="prefix">Path prefix shared by the symbol file and the parameter file.</param>
/// <param name="epoch">Epoch number used to build the parameter file name.</param>
/// <returns>
/// The loaded symbol, the argument parameters (entries stored under an "arg:" tag) and the
/// auxiliary states (entries stored under an "aux:" tag), with the leading tag removed from each key.
/// Entries with any other tag are skipped with a warning.
/// </returns>
public static (Symbol, NDArrayDict, NDArrayDict) LoadCheckpoint(string prefix, int epoch)
{
    const string argPrefix = "arg:";
    const string auxPrefix = "aux:";

    Symbol sym = Symbol.Load($"{prefix}-symbol.json");
    string param_name = $"{prefix}-{epoch.ToString("D4")}.params";
    var save_dict = NDArray.Load(param_name);
    var arg_params = new NDArrayDict();
    var aux_params = new NDArrayDict();

    if (save_dict == null)
    {
        Logger.Warning($"Params file '{param_name}' is empty");
        return (sym, arg_params, aux_params);
    }

    foreach (var item in save_dict)
    {
        // Remove only the leading tag (ordinal comparison). String.Replace would also strip any
        // "arg:"/"aux:" occurring later in the name and break the round trip with SaveCheckpoint.
        if (item.Key.StartsWith(argPrefix, StringComparison.Ordinal))
        {
            arg_params.Add(item.Key.Substring(argPrefix.Length), item.Value);
        }
        else if (item.Key.StartsWith(auxPrefix, StringComparison.Ordinal))
        {
            aux_params.Add(item.Key.Substring(auxPrefix.Length), item.Value);
        }
        else
        {
            Logger.Warning($"Params file '{param_name}' contains unknown param '{item.Key}'");
        }
    }

    return (sym, arg_params, aux_params);
}
/// <summary>
/// Pushes the same argument and auxiliary parameter values into every training executor
/// held by <c>execgrp</c>.
/// </summary>
/// <param name="arg_params">Argument parameters to copy.</param>
/// <param name="aux_params">Auxiliary states to copy.</param>
public void SetParams(NDArrayDict arg_params, NDArrayDict aux_params)
{
    var trainExecutors = execgrp.train_execs;
    foreach (var executor in trainExecutors)
    {
        executor.CopyFromParams(arg_params, aux_params);
    }
}
/// <summary>
/// Writes a checkpoint: the symbol (when given) to "<paramref name="prefix"/>-symbol.json" and all
/// parameters to "<paramref name="prefix"/>-NNNN.params", where NNNN is <paramref name="epoch"/> zero-padded to four digits.
/// Argument keys are stored as "arg:name" and auxiliary keys as "aux:name".
/// </summary>
/// <param name="prefix">Path prefix for both output files.</param>
/// <param name="epoch">Epoch number embedded in the parameter file name.</param>
/// <param name="symbol">Symbol to save; skipped when null.</param>
/// <param name="arg_params">Argument parameters to save.</param>
/// <param name="aux_params">Auxiliary states to save.</param>
/// <param name="remove_amp_cast">Forwarded to <c>Symbol.Save</c>.</param>
public static void SaveCheckpoint(string prefix, int epoch, Symbol symbol, NDArrayDict arg_params, NDArrayDict aux_params, bool remove_amp_cast = true)
{
    if (symbol != null)
    {
        symbol.Save($"{prefix}-symbol.json", remove_amp_cast);
    }

    var tagged = new NDArrayDict();
    foreach (var entry in arg_params)
    {
        tagged.Add($"arg:{entry.Key}", entry.Value);
    }

    foreach (var entry in aux_params)
    {
        tagged.Add($"aux:{entry.Key}", entry.Value);
    }

    var paramFile = $"{prefix}-{epoch.ToString("D4")}.params";
    NDArray.Save(paramFile, tagged);
    Logger.Info($"Saved checkpoint to \"{paramFile}\"");
}
/// <summary>
/// Fills <paramref name="argsMap"/> with one array per argument of this symbol. Arrays present in
/// <paramref name="knownArgs"/> are reused. The remaining arguments get new arrays of the inferred
/// shape, filled with uniform random values in [0, 1).
/// </summary>
/// <param name="context">Required, though not otherwise used by this method.</param>
/// <param name="argsMap">Output map that receives an array for every argument.</param>
/// <param name="knownArgs">Arrays whose shapes seed shape inference and which are reused as-is.</param>
/// <exception cref="ArgumentNullException">Any parameter is null.</exception>
public void InferArgsMap(Context context, NDArrayDict argsMap, NDArrayDict knownArgs)
{
    if (context == null)
    {
        throw new ArgumentNullException(nameof(context));
    }

    if (argsMap == null)
    {
        throw new ArgumentNullException(nameof(argsMap));
    }

    if (knownArgs == null)
    {
        throw new ArgumentNullException(nameof(knownArgs));
    }

    ThrowIfDisposed();

    var names = ListArguments();
    var knownShapes = new Dictionary<string, Shape>();
    foreach (var name in names)
    {
        var known = knownArgs[name];
        if (known != null)
        {
            knownShapes[name] = known.Shape;
        }
    }

    // Only the argument shapes are needed; output and auxiliary shapes are discarded.
    var (argShapes, _, _) = InferShape(knownShapes);
    for (var idx = 0; idx < argShapes.Length; idx++)
    {
        var name = names[idx];
        var known = knownArgs[name];
        if (known != null)
        {
            argsMap[name] = known;
            continue;
        }

        var fresh = new NDArray(argShapes[idx], false);
        argsMap[name] = fresh;
        nd.Random.Uniform(0, 1, fresh.Shape).CopyTo(fresh);
    }
}
/// <summary>
/// Convenience overload that calls the overload taking a gradient store, passing an empty gradient store.
/// </summary>
public void InferExecutorArrays(Context context, NDArrayList argArrays, NDArrayList gradArrays, IList<OpGradReq> gradReqs, NDArrayList auxArrays, NDArrayDict argsMap)
{
    var emptyGradStore = new NDArrayDict();
    InferExecutorArrays(context, argArrays, gradArrays, gradReqs, auxArrays, argsMap, emptyGradStore);
}
/// <summary>
/// Copies the current parameter arrays (<c>ParamArrays</c>, paired with <c>param_names</c>) into
/// <paramref name="arg_params"/>, and the auxiliary arrays (<c>AuxArrays</c>, paired with
/// <c>aux_names</c>) into <paramref name="aux_params"/>. Each value is converted to the data type
/// of its destination array.
/// </summary>
/// <param name="arg_params">Destination arrays; must contain every name in <c>param_names</c>.</param>
/// <param name="aux_params">Destination arrays; must contain every name in <c>aux_names</c>.</param>
public void CopyTo(NDArrayDict arg_params, NDArrayDict aux_params)
{
    // Iterate eagerly with foreach. A bare Enumerable.Zip(...) is lazy, and if its result is
    // never enumerated, nothing gets copied.
    // NOTE(review): each entry of ParamArrays/AuxArrays is treated as the single array for that
    // name, so no cross-device averaging happens here. Confirm this for multi-device setups.
    foreach (var (name, block) in param_names.Zip(ParamArrays, (n, b) => (n, b)))
    {
        var dst = arg_params[name];
        block.AsType(dst.DataType).CopyTo(dst);
    }

    foreach (var (name, block) in aux_names.Zip(AuxArrays, (n, b) => (n, b)))
    {
        var dst = aux_params[name];
        block.AsType(dst.DataType).CopyTo(dst);
    }
}
/// <summary>
/// Convenience overload that calls the overload taking explicit gradient requirements, passing an
/// empty requirement map.
/// </summary>
public void InferExecutorArrays(Context context, NDArrayList argArrays, NDArrayList gradArrays, IList<OpGradReq> gradReqs, NDArrayList auxArrays, NDArrayDict argsMap, NDArrayDict argGradStore)
{
    var noGradReqOverrides = new Dictionary<string, OpGradReq>();
    InferExecutorArrays(context, argArrays, gradArrays, gradReqs, auxArrays, argsMap, argGradStore, noGradReqOverrides);
}
/// <summary>
/// Reads the MNIST train (60000 items) and test (10000 items) sets via <c>read_data</c> from the
/// gzip files under http://data.mxnet.io/data/mnist/.
/// </summary>
/// <returns>
/// A dictionary with the keys "train_data", "train_label", "test_data" and "test_label".
/// </returns>
public static NDArrayDict GetMNIST()
{
    const string baseUrl = "http://data.mxnet.io/data/mnist/";

    var (trainLabels, trainImages) = read_data(baseUrl + "train-labels-idx1-ubyte.gz", baseUrl + "train-images-idx3-ubyte.gz", 60000);
    var (testLabels, testImages) = read_data(baseUrl + "t10k-labels-idx1-ubyte.gz", baseUrl + "t10k-images-idx3-ubyte.gz", 10000);

    var mnist = new NDArrayDict();
    mnist.Add("train_data", trainImages);
    mnist.Add("train_label", trainLabels);
    mnist.Add("test_data", testImages);
    mnist.Add("test_label", testLabels);
    return mnist;
}
/// <summary>
/// Pairs each name with the array at the same position.
/// </summary>
/// <param name="names">Names, which must be unique; each duplicate is reported through <c>Logging.CHECK</c>.</param>
/// <param name="arrays">Arrays, whose count is checked against the number of distinct names.</param>
/// <returns>A dictionary mapping <c>names[i]</c> to <c>arrays[i]</c>.</returns>
private static NDArrayDict GetDictionary(IList<string> names, NDArrayList arrays)
{
    var seen = new HashSet<string>();
    foreach (var name in names)
    {
        Logging.CHECK(!seen.Contains(name), $"Duplicate names detected, {name}");
        seen.Add(name);
    }

    Logging.CHECK_EQ(seen.Count, arrays.Length, "names size not equal to arrays size");

    var result = new NDArrayDict();
    var position = 0;
    while (position < names.Count)
    {
        result[names[position]] = arrays[position];
        position++;
    }

    return result;
}
/// <summary>
/// Copies the given values into this executor's argument and auxiliary arrays, converting each value
/// to the destination array's data type.
/// </summary>
/// <param name="arg_params">Argument values keyed by argument name.</param>
/// <param name="aux_params">Optional auxiliary-state values keyed by name; skipped when null.</param>
/// <param name="allow_extra_params">When true, names this executor does not know are ignored instead of rejected.</param>
/// <exception cref="ArgumentNullException"><paramref name="arg_params"/> is null.</exception>
/// <exception cref="ArgumentException">
/// A name is not an argument or auxiliary state of this executor and
/// <paramref name="allow_extra_params"/> is false.
/// </exception>
public void CopyFromParams(NDArrayDict arg_params, NDArrayDict aux_params = null, bool allow_extra_params = false)
{
    if (arg_params == null)
    {
        throw new ArgumentNullException(nameof(arg_params));
    }

    var arg_dict = ArgmentDictionary();
    var aux_dict = AuxiliaryDictionary();

    foreach (var item in arg_params)
    {
        if (arg_dict.Contains(item.Key))
        {
            var dst = arg_dict[item.Key];
            item.Value.AsType(dst.DataType).CopyTo(dst);
        }
        else if (!allow_extra_params)
        {
            // ArgumentException (rather than base Exception) so callers can catch it specifically;
            // existing catch (Exception) handlers still match.
            throw new ArgumentException($"Find name \"{item.Key}\" that is not in the arguments");
        }
    }

    if (aux_params == null)
    {
        return;
    }

    foreach (var item in aux_params)
    {
        if (aux_dict.Contains(item.Key))
        {
            var dst = aux_dict[item.Key];
            item.Value.AsType(dst.DataType).CopyTo(dst);
        }
        else if (!allow_extra_params)
        {
            throw new ArgumentException($"Find name \"{item.Key}\" that is not in the auxiliary states");
        }
    }
}
/// <summary>
/// Decides whether parameters are updated on the supplied KVStore.
/// </summary>
/// <param name="kvstore">Existing store; returned unchanged (may be null).</param>
/// <param name="num_device">Not used by this overload.</param>
/// <param name="arg_params">Not used by this overload.</param>
/// <returns>
/// The store, and <c>true</c> when updates should happen on it. The result defaults to <c>true</c>,
/// can be overridden by the MXNET_UPDATE_ON_KVSTORE environment variable ("0"/"1", or
/// "true"/"false"), and is always <c>false</c> when <paramref name="kvstore"/> is null.
/// </returns>
/// <exception cref="FormatException">MXNET_UPDATE_ON_KVSTORE is neither an integer nor a boolean.</exception>
internal static (KVStore, bool) CreateKVStore(KVStore kvstore, int num_device, NDArrayDict arg_params)
{
    var update_on_kvstore = true;

    var env = Environment.GetEnvironmentVariable("MXNET_UPDATE_ON_KVSTORE");
    if (!string.IsNullOrWhiteSpace(env))
    {
        var trimmed = env.Trim();
        // The variable is conventionally 0/1, which Convert.ToBoolean rejects with a
        // FormatException, so parse integers first and fall back to "true"/"false".
        update_on_kvstore = int.TryParse(trimmed, out var flag) ? flag != 0 : Convert.ToBoolean(trimmed);
    }

    if (kvstore == null)
    {
        update_on_kvstore = false;
    }

    return (kvstore, update_on_kvstore);
}
/// <summary>
/// Trains <paramref name="symbol"/> over <paramref name="ctx"/> using a
/// <c>DataParallelExecutorManager</c>. Updates go through <paramref name="kvstore"/> when
/// <paramref name="update_on_kvstore"/> is set, and through a local updater otherwise. When
/// <paramref name="eval_data"/> is supplied, it is evaluated after every epoch.
/// </summary>
/// <param name="begin_epoch">First epoch index (inclusive).</param>
/// <param name="end_epoch">Last epoch index (exclusive).</param>
/// <param name="epoch_size">
/// Batches per epoch. When null, an epoch is exactly one pass over <paramref name="train_data"/>.
/// Otherwise the iterator is reset and reused until this many batches have been processed.
/// </param>
/// <param name="kvstore">Optional KVStore; when null, updates are always applied locally.</param>
/// <param name="update_on_kvstore">Requests KVStore-side updates; ignored when <paramref name="kvstore"/> is null.</param>
/// <param name="eval_metric">Metric updated on training and evaluation batches; must not be null.</param>
/// <remarks>
/// On the last epoch, or whenever <paramref name="epoch_end_callback"/> is set, the trained values
/// are copied back into <paramref name="arg_params"/> and <paramref name="aux_params"/>.
/// </remarks>
internal static void TrainMultiDevice(Symbol symbol, Context[] ctx, string[] arg_names, string[] param_names, string[] aux_names, NDArrayDict arg_params, NDArrayDict aux_params, int begin_epoch, int end_epoch, int? epoch_size, Optimizer optimizer, KVStore kvstore, bool update_on_kvstore, DataIter train_data, DataIter eval_data = null, EvalMetric eval_metric = null, IEpochEndCallback epoch_end_callback = null, IBatchEndCallback batch_end_callback = null, int[] work_load_list = null, Monitor monitor = null, IEvalEndCallback eval_end_callback = null, IEvalBatchEndCallback eval_batch_end_callback = null, Func<int, Symbol> sym_gen = null)
{
    // Without a store there is nothing to update on; otherwise SetOptimizer/Type would throw NRE.
    if (kvstore == null)
    {
        update_on_kvstore = false;
    }

    var executor_manager = new DataParallelExecutorManager(symbol: symbol, ctx: ctx, train_data: train_data, arg_names: arg_names, param_names: param_names, aux_names: aux_names, work_load_list: work_load_list, sym_gen: sym_gen);
    if (monitor != null)
    {
        executor_manager.InstallMonitor(monitor);
    }

    executor_manager.SetParams(arg_params, aux_params);

    Updater updater = null;
    if (!update_on_kvstore)
    {
        updater = Optimizer.GetUpdater(optimizer);
    }
    else
    {
        kvstore.SetOptimizer(optimizer);
    }

    if (kvstore != null)
    {
        InitializeKVStore(kvstore: kvstore, param_arrays: new List<NDArrayList>() { executor_manager.ParamArrays }, arg_params: arg_params, param_names: executor_manager.param_names, update_on_kvstore: update_on_kvstore);
    }

    train_data.Reset();
    for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
    {
        var tic = DateTime.UtcNow;
        eval_metric.Reset();
        int nbatch = 0;
        while (true)
        {
            bool do_reset = true;
            int batches_this_pass = 0;
            while (!train_data.End())
            {
                var data_batch = train_data.Next();
                executor_manager.LoadDataBatch(data_batch);
                if (monitor != null)
                {
                    monitor.Tic();
                }

                executor_manager.Forward(true);
                executor_manager.Backward();

                var param_arrays = new List<NDArrayList>() { executor_manager.ParamArrays };
                var grad_arrays = new List<NDArrayList>() { executor_manager.GradArrays };
                if (update_on_kvstore)
                {
                    if (kvstore.Type.Contains("nccl"))
                    {
                        UpdateParamsOnKVStoreNCCL(param_arrays, grad_arrays, kvstore, executor_manager.param_names);
                    }
                    else
                    {
                        UpdateParamsOnKVStore(param_arrays, grad_arrays, kvstore, executor_manager.param_names);
                    }
                }
                else
                {
                    UpdateParams(param_arrays, grad_arrays, updater, ctx.Length, kvstore, executor_manager.param_names);
                }

                if (monitor != null)
                {
                    monitor.TocPrint();
                }

                executor_manager.UpdateMetric(eval_metric, data_batch.Label);
                nbatch++;
                batches_this_pass++;
                if (batch_end_callback != null)
                {
                    MultipleCallbacks(batch_end_callback, epoch, nbatch, eval_metric);
                }

                // The epoch may finish part-way through a pass over the data.
                if (epoch_size.HasValue && nbatch >= epoch_size.Value)
                {
                    do_reset = false;
                    break;
                }
            }

            if (do_reset)
            {
                Logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                train_data.Reset();
            }

            // The epoch is over after one full pass when no epoch_size is given,
            // or once epoch_size batches have been consumed.
            if (!epoch_size.HasValue || nbatch >= epoch_size.Value)
            {
                break;
            }

            // An iterator that yields no batches would otherwise keep this loop spinning forever.
            if (batches_this_pass == 0)
            {
                break;
            }
        }

        var toc = DateTime.UtcNow;
        Logger.Info($"Epoch[{epoch}] Time cost={(toc - tic).TotalSeconds}");

        if (epoch_end_callback != null || epoch + 1 == end_epoch)
        {
            executor_manager.CopyTo(arg_params, aux_params);
        }

        MultipleCallbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params);

        if (eval_data != null)
        {
            eval_metric.Reset();
            eval_data.Reset();
            int total_num_batch = 0;
            while (!eval_data.End())
            {
                var eval_batch = eval_data.Next();
                executor_manager.LoadDataBatch(eval_batch);
                executor_manager.Forward();
                executor_manager.UpdateMetric(eval_metric, eval_batch.Label);
                if (eval_batch_end_callback != null)
                {
                    // Pass the 0-based index of this evaluation batch.
                    MultipleCallbacks(eval_batch_end_callback, epoch, total_num_batch, eval_metric);
                }

                total_num_batch++;
            }

            if (eval_end_callback != null)
            {
                MultipleCallbacks(eval_end_callback, epoch, eval_metric);
            }
        }
    }
}
/// <summary>
/// Binds this symbol through <c>NativeMethods.MXExecutorSimpleBindEx</c>. The native side allocates
/// the argument, gradient and auxiliary arrays from the provided shapes, dtypes and storage types.
/// </summary>
/// <param name="ctx">Device the executor is bound to.</param>
/// <param name="grad_req">Optional per-argument gradient requirements; null sends none.</param>
/// <param name="type_dict">Optional per-argument dtypes.</param>
/// <param name="stype_dict">Optional per-argument storage types.</param>
/// <param name="group2ctx">Optional mapping from context-group name to device.</param>
/// <param name="shared_arg_names">Optional names of arguments to share.</param>
/// <param name="shared_exec">Optional executor whose handle is passed for memory sharing.</param>
/// <param name="shared_buffer">
/// Optional named arrays shared between executors. When non-null (even if empty), the buffer is
/// passed to the native call and updated in place with the entries the native side returns.
/// </param>
/// <param name="kwargs">Argument shape descriptions; null is treated as "no shapes provided".</param>
/// <returns>The bound executor with its argument, gradient and auxiliary arrays populated.</returns>
public Executor SimpleBind(Context ctx, Dictionary<string, OpGradReq> grad_req = null, Dictionary<string, DType> type_dict = null, Dictionary<string, StorageStype> stype_dict = null, Dictionary<string, Context> group2ctx = null, string[] shared_arg_names = null, Executor shared_exec = null, NDArrayDict shared_buffer = null, DataDesc[] kwargs = null)
{
    int num_provided_arg_types = 0;
    string[] provided_arg_type_names = null;
    int[] provided_arg_type_data = null;
    if (type_dict != null)
    {
        provided_arg_type_names = type_dict.Keys.ToArray();
        provided_arg_type_data = type_dict.Values.Select(x => x.Index).ToArray();
        num_provided_arg_types = type_dict.Count;
    }

    int num_provided_arg_stypes = 0;
    string[] provided_arg_stype_names = null;
    int[] provided_arg_stype_data = null;
    if (stype_dict != null)
    {
        provided_arg_stype_names = stype_dict.Keys.ToArray();
        provided_arg_stype_data = stype_dict.Values.Select(x => (int)x).ToArray();
        num_provided_arg_stypes = stype_dict.Count;
    }

    // All shapes are flattened into provided_arg_shape_data. provided_arg_shape_idx holds running
    // offsets into it, starting at 0, with one extra entry per argument.
    var provided_arg_shape_data = new List<int>();
    var provided_arg_shape_idx = new List<int>() { 0 };
    var provided_arg_shape_names = new List<string>();
    if (kwargs != null)
    {
        foreach (var desc in kwargs)
        {
            provided_arg_shape_names.Add(desc.Name);
            // Only positive dimensions are forwarded.
            provided_arg_shape_data.AddRange(desc.Shape.Data.Where(x => x > 0));
            provided_arg_shape_idx.Add(provided_arg_shape_data.Count);
        }
    }

    int provided_req_type_list_len = 0;
    string[] provided_grad_req_names = new string[0];
    string[] provided_grad_req_types = new string[0];
    if (grad_req != null)
    {
        provided_grad_req_names = grad_req.Keys.ToArray();
        // Invariant lower-casing: the native side expects fixed identifiers, independent of the current culture.
        provided_grad_req_types = grad_req.Values.Select(x => Enum.GetName(x.GetType(), x).ToLowerInvariant()).ToArray();
        provided_req_type_list_len = grad_req.Count;
    }

    int num_ctx_map_keys = 0;
    string[] ctx_map_keys = new string[0];
    int[] ctx_map_dev_types = new int[0];
    int[] ctx_map_dev_ids = new int[0];
    if (group2ctx != null)
    {
        ctx_map_keys = group2ctx.Keys.ToArray();
        ctx_map_dev_types = group2ctx.Values.Select(x => (int)x.GetDeviceType()).ToArray();
        ctx_map_dev_ids = group2ctx.Values.Select(x => x.GetDeviceId()).ToArray();
        num_ctx_map_keys = group2ctx.Count;
    }

    string[] shared_arg_name_list = shared_arg_names ?? new string[0];

    unsafe
    {
        // A length of -1 tells the native side that no shared buffer is supplied.
        int shared_buffer_len = -1;
        string[] shared_buffer_names = null;
        IntPtr[] shared_buffer_handles = null;
        if (shared_buffer != null)
        {
            shared_buffer_len = shared_buffer.Count;
            shared_buffer_names = shared_buffer.Keys;
            shared_buffer_handles = shared_buffer.Values.Handles;
        }

        var shared_exec_handle = shared_exec != null ? shared_exec.Handle : new ExecutorHandle();
        char** updated_shared_buffer_names;
        SymbolHandle* updated_shared_buffer_handles;
        int num_in_args;
        SymbolHandle* in_arg_handles;
        SymbolHandle* arg_grad_handles;
        SymbolHandle* aux_state_handles;
        ExecutorHandle exe_handle;
        int num_aux_states;

        // NOTE(review): the native status code is not checked here. Confirm whether
        // NativeMethods.MXExecutorSimpleBindEx reports failures some other way.
        NativeMethods.MXExecutorSimpleBindEx(NativePtr, (int)ctx.GetDeviceType(), ctx.GetDeviceId(), num_ctx_map_keys, ctx_map_keys, ctx_map_dev_types, ctx_map_dev_ids, provided_req_type_list_len, provided_grad_req_names, provided_grad_req_types, provided_arg_shape_names.Count, provided_arg_shape_names.ToArray(), provided_arg_shape_data.ToArray(), provided_arg_shape_idx.ToArray(), num_provided_arg_types, provided_arg_type_names, provided_arg_type_data, num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stype_data, shared_arg_name_list.Length, shared_arg_name_list, &shared_buffer_len, shared_buffer_names, shared_buffer_handles, &updated_shared_buffer_names, &updated_shared_buffer_handles, &num_in_args, &in_arg_handles, &arg_grad_handles, &num_aux_states, &aux_state_handles, shared_exec_handle, &exe_handle);

        if (shared_buffer != null)
        {
            for (int i = 0; i < shared_buffer_len; i++)
            {
                // The MXNet C API hands back byte (const char*) strings. Reading them through a
                // UTF-16 char* would pair the bytes into garbage characters.
                string k = new string((sbyte*)updated_shared_buffer_names[i]);
                shared_buffer[k] = new NDArray(updated_shared_buffer_handles[i]);
            }
        }

        var arg_arrays = new NDArrayList();
        for (int i = 0; i < num_in_args; i++)
        {
            arg_arrays.Add(new NDArray(in_arg_handles[i]));
        }

        // Arguments without a gradient (null handle) are skipped, so grad_arrays can be shorter
        // than arg_arrays.
        var grad_arrays = new NDArrayList();
        for (int i = 0; i < num_in_args; i++)
        {
            if (arg_grad_handles[i] == IntPtr.Zero)
            {
                continue;
            }

            grad_arrays.Add(new NDArray(arg_grad_handles[i]));
        }

        var aux_arrays = new NDArrayList();
        for (int i = 0; i < num_aux_states; i++)
        {
            aux_arrays.Add(new NDArray(aux_state_handles[i]));
        }

        var grad_req_list = grad_req != null ? grad_req.Values.ToList() : new List<OpGradReq>();
        var executor = new Executor(exe_handle, ctx, grad_req_list, group2ctx);
        executor.ArgmentArrays = arg_arrays;
        executor.GradientArrays = grad_arrays;
        executor.AuxiliaryArrays = aux_arrays;
        executor._Symbol = this;
        return executor;
    }
}
/// <summary>
/// Appends, in argument order, the arrays needed to bind this symbol to the output lists:
/// argument arrays, gradient arrays and gradient requirements, followed by auxiliary arrays.
/// Arrays found in the supplied maps are reused. Missing argument and auxiliary arrays are created
/// with the inferred shape and filled with uniform random values in [0, 1). Missing gradient
/// arrays are created with the argument's shape.
/// </summary>
/// <param name="context">Required, though not otherwise used by this method.</param>
/// <param name="argArrays">Receives one array per argument.</param>
/// <param name="gradArrays">Receives one gradient array per argument.</param>
/// <param name="gradReqs">
/// Receives one requirement per argument. An explicit entry in <paramref name="gradReqType"/> is
/// used when present. Otherwise the requirement is Null for names ending in "data" or "label",
/// and Write for all other names.
/// </param>
/// <param name="auxArrays">Receives one array per auxiliary state.</param>
/// <param name="argsMap">Known argument arrays; their shapes also seed shape inference.</param>
/// <param name="argGradStore">Known gradient arrays.</param>
/// <param name="gradReqType">Explicit gradient requirements by argument name.</param>
/// <param name="auxMap">Known auxiliary arrays.</param>
/// <exception cref="ArgumentNullException">Any parameter is null.</exception>
public void InferExecutorArrays(Context context, NDArrayList argArrays, NDArrayList gradArrays, IList<OpGradReq> gradReqs, NDArrayList auxArrays, NDArrayDict argsMap, NDArrayDict argGradStore, IDictionary<string, OpGradReq> gradReqType, NDArrayDict auxMap)
{
    if (context == null)
    {
        throw new ArgumentNullException(nameof(context));
    }

    if (argArrays == null)
    {
        throw new ArgumentNullException(nameof(argArrays));
    }

    if (gradArrays == null)
    {
        throw new ArgumentNullException(nameof(gradArrays));
    }

    if (gradReqs == null)
    {
        throw new ArgumentNullException(nameof(gradReqs));
    }

    if (auxArrays == null)
    {
        throw new ArgumentNullException(nameof(auxArrays));
    }

    if (argsMap == null)
    {
        throw new ArgumentNullException(nameof(argsMap));
    }

    if (argGradStore == null)
    {
        throw new ArgumentNullException(nameof(argGradStore));
    }

    if (gradReqType == null)
    {
        throw new ArgumentNullException(nameof(gradReqType));
    }

    if (auxMap == null)
    {
        throw new ArgumentNullException(nameof(auxMap));
    }

    ThrowIfDisposed();

    var argNameList = ListArguments();
    var argShapes = new Dictionary<string, Shape>();
    foreach (var argName in argNameList)
    {
        if (argsMap[argName] != null)
        {
            argShapes[argName] = argsMap[argName].Shape;
        }
    }

    // InferShape returns (argument, output, auxiliary) shapes. The outputs are not needed here.
    // The aux shapes must come from the third element, not the second.
    var (inShapes, _, auxShapes) = InferShape(argShapes);

    for (var i = 0; i < inShapes.Length; ++i)
    {
        var shape = inShapes[i];
        var argName = argNameList[i];

        if (argsMap[argName] != null)
        {
            argArrays.Add(argsMap[argName]);
        }
        else
        {
            argArrays.Add(new NDArray(shape, false));
            var argArr = argArrays.Last();
            nd.Random.Uniform(0, 1, argArr.Shape).CopyTo(argArr);
        }

        if (argGradStore[argName] != null)
        {
            gradArrays.Add(argGradStore[argName]);
        }
        else
        {
            gradArrays.Add(new NDArray(shape, false));
        }

        if (gradReqType.TryGetValue(argName, out var value3))
        {
            gradReqs.Add(value3);
        }
        // EndsWith rather than comparing LastIndexOf with Length-4/5. That comparison wrongly
        // matched short names such as "abc" or "bias" (-1 == Length - 4).
        else if (argName.EndsWith("data", StringComparison.Ordinal) || argName.EndsWith("label", StringComparison.Ordinal))
        {
            gradReqs.Add(OpGradReq.Null);
        }
        else
        {
            gradReqs.Add(OpGradReq.Write);
        }
    }

    var auxNameList = ListAuxiliaryStates();
    for (var i = 0; i < auxShapes.Length; ++i)
    {
        var shape = auxShapes[i];
        var auxName = auxNameList[i];
        if (auxMap[auxName] != null)
        {
            auxArrays.Add(auxMap[auxName]);
        }
        else
        {
            auxArrays.Add(new NDArray(shape, false));
            var aux = auxArrays.Last();
            nd.Random.Uniform(0, 1, aux.Shape).CopyTo(aux);
        }
    }
}
/// <summary>
/// Binds <paramref name="sym"/> for the given input shapes, allocating or reusing the argument,
/// gradient and auxiliary arrays.
/// </summary>
/// <param name="sym">Symbol to bind.</param>
/// <param name="ctx">Device on which new arrays are allocated.</param>
/// <param name="input_shapes">Shapes of the non-parameter inputs, used for shape inference.</param>
/// <param name="param_names">Names treated as model parameters; every other argument is an input.</param>
/// <param name="need_grad">When true, every argument not listed in <paramref name="input_shapes"/> gets a Write gradient.</param>
/// <param name="base_exec">Optional executor whose parameter, gradient and auxiliary arrays are reused.</param>
/// <param name="shared_data_arrays">
/// Optional pool of input arrays. An input is reshaped and reused when the pooled array is large
/// enough. Otherwise a new array is allocated and stored back into the pool.
/// </param>
/// <param name="input_types">Optional input dtypes; defaults to Float32 for every input shape.</param>
/// <param name="logger">Not used; warnings go through the static <c>Logger</c>.</param>
/// <returns>The bound executor.</returns>
internal static Executor BindExec(Symbol sym, Context ctx, Dictionary<string, Shape> input_shapes, string[] param_names, bool need_grad = false, Executor base_exec = null, NDArrayDict shared_data_arrays = null, Dictionary<string, DType> input_types = null, Logger logger = null)
{
    var (arg_shape, _, aux_shape) = sym.InferShape(input_shapes);
    if (arg_shape == null)
    {
        throw new ArgumentNullException("arg_shape");
    }

    if (input_types == null)
    {
        input_types = new Dictionary<string, DType>();
        foreach (var item in input_shapes.Keys)
        {
            input_types.Add(item, DType.Float32);
        }
    }

    var (arg_types, _, aux_types) = sym.InferType(input_types);
    if (arg_types == null)
    {
        throw new ArgumentNullException("arg_types");
    }

    var arg_arrays = new NDArrayList();
    var aux_arrays = new NDArrayList();
    // Always allocated (empty when need_grad is false) so the Bind call below never dereferences null.
    var grad_arrays = new NDArrayDict();
    var arg_names = sym.ListArguments();

    var needGradSet = new List<string>();
    if (need_grad)
    {
        foreach (var item in arg_names)
        {
            if (!input_shapes.ContainsKey(item))
            {
                needGradSet.Add(item);
            }
        }

        needGradSet = MxUtil.Set(needGradSet);
    }

    var grad_req = new Dictionary<string, OpGradReq>();
    foreach (var item in arg_names)
    {
        if (needGradSet.Contains(item))
        {
            grad_req.Add(item, OpGradReq.Write);
        }
    }

    for (var i = 0; i < arg_names.Count; i++)
    {
        var name = arg_names[i];
        NDArray arg_arr;
        if (!param_names.Contains(name))
        {
            // Non-parameter input (data/label).
            if (shared_data_arrays != null && shared_data_arrays.Contains(name))
            {
                arg_arr = shared_data_arrays[name];
                if (np.prod(arg_arr.Shape.Data) >= np.prod(arg_shape[i].Data))
                {
                    if (arg_types[i].Name != arg_arr.DataType.Name)
                    {
                        throw new ArgumentException("arg_type and arg_arr datatype mismatch");
                    }

                    arg_arr = arg_arr.Reshape(arg_shape[i]);
                }
                else
                {
                    var logmsg = new StringBuilder();
                    logmsg.AppendFormat("bucketing: data \"{0}\" has a shape {1}", name, arg_shape[i]);
                    logmsg.AppendFormat(", which is larger than already allocated ");
                    logmsg.AppendFormat("shape {0}", arg_arr.Shape);
                    logmsg.AppendFormat(". Need to re-allocate. Consider putting default_bucket_key " + "to be the bucket taking the largest input for better memory sharing.");
                    Logger.Warning(logmsg.ToString());
                    arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                    shared_data_arrays[name] = arg_arr;
                }
            }
            else
            {
                arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
                if (shared_data_arrays != null)
                {
                    shared_data_arrays[name] = arg_arr;
                }
            }
        }
        else if (base_exec == null)
        {
            // New model parameter, plus a gradient buffer when one is required.
            arg_arr = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
            if (needGradSet.Contains(name))
            {
                grad_arrays[name] = nd.Zeros(arg_shape[i], ctx, arg_types[i]);
            }
        }
        else
        {
            // Model parameter (and its gradient) taken from base_exec.
            arg_arr = base_exec.ArgmentDictionary()[name];
            if (arg_arr.Shape != arg_shape[i])
            {
                throw new ArgumentException("arg_arr.Shape != arg_shape[i]");
            }

            if (arg_arr.DataType != arg_types[i])
            {
                throw new ArgumentException("arg_arr.DataType != arg_types[i]");
            }

            if (needGradSet.Contains(name))
            {
                grad_arrays[name] = base_exec.GradientDictionary()[name];
            }
        }

        // Every argument, parameter or not, contributes exactly one entry.
        arg_arrays.Add(arg_arr);
    }

    if (base_exec == null)
    {
        for (var i = 0; i < aux_shape.Length; i++)
        {
            aux_arrays.Add(nd.Zeros(aux_shape[i], ctx, aux_types[i]));
        }
    }
    else
    {
        foreach (var item in base_exec.AuxiliaryDictionary())
        {
            aux_arrays.Add(item.Value);
        }
    }

    // NOTE(review): the gradient list and the grad_req list only hold entries for arguments that
    // need gradients, while arg_arrays covers every argument. Confirm whether Symbol.Bind pairs
    // them by position.
    return sym.Bind(ctx, arg_arrays, grad_arrays.Values.ToList(), grad_req.Values.ToList(), aux_arrays);
}
/// <summary>
/// Creates a KVStore from its type name and decides whether parameters are updated on it.
/// </summary>
/// <param name="kvstore">Store type name (for example "local" or a name containing "dist"); null means no store.</param>
/// <param name="num_device">
/// Number of devices. With a single device and a non-"dist" type, no store is created.
/// </param>
/// <param name="arg_params">Parameters; for "local", any array larger than 16M elements disables updates on the store.</param>
/// <returns>
/// The created store (or null), and whether updates should happen on it. The flag can be
/// overridden by MXNET_UPDATE_ON_KVSTORE ("0"/"1", or "true"/"false") and is always <c>false</c>
/// when no store is created.
/// </returns>
/// <exception cref="FormatException">MXNET_UPDATE_ON_KVSTORE is neither an integer nor a boolean.</exception>
internal static (KVStore, bool) CreateKVStore(string kvstore, int num_device, NDArrayDict arg_params)
{
    var update_on_kvstore = true;

    var env = Environment.GetEnvironmentVariable("MXNET_UPDATE_ON_KVSTORE");
    if (!string.IsNullOrWhiteSpace(env))
    {
        var trimmed = env.Trim();
        // Accept the conventional 0/1 form, as well as "true"/"false".
        update_on_kvstore = int.TryParse(trimmed, out var flag) ? flag != 0 : Convert.ToBoolean(trimmed);
    }

    KVStore kV = null;
    // A single device on a single machine does not need a store.
    if (kvstore != null && (num_device != 1 || kvstore.Contains("dist")))
    {
        kV = KVStoreBase.Create(kvstore);
        if (kvstore == "local" && arg_params != null && arg_params.Count > 0)
        {
            var max_size = arg_params.Values.Select(x => x.Shape.Size).Max();
            if (max_size > 1024 * 1024 * 16)
            {
                update_on_kvstore = false;
            }
        }
    }

    // Without a store, updates cannot happen on it. Callers such as TrainMultiDevice would otherwise
    // call into a null store.
    if (kV == null)
    {
        update_on_kvstore = false;
    }

    return (kV, update_on_kvstore);
}
/// <summary>
/// Registers each parameter with <paramref name="kvstore"/>, using the initial value from
/// <paramref name="arg_params"/>. When <paramref name="update_on_kvstore"/> is set, also pulls the
/// stored value back into the per-device arrays, using priority -index.
/// </summary>
/// <param name="kvstore">Store to initialise.</param>
/// <param name="param_arrays">Per-index device arrays; entries that are empty or start with null are skipped.</param>
/// <param name="arg_params">Initial values keyed by parameter name.</param>
/// <param name="param_names">Parameter names, aligned by index with <paramref name="param_arrays"/>.</param>
/// <param name="update_on_kvstore">Whether to pull the initial values back into the device arrays.</param>
internal static void InitializeKVStore(KVStore kvstore, List<NDArrayList> param_arrays, NDArrayDict arg_params, string[] param_names, bool update_on_kvstore)
{
    for (int idx = 0; idx < param_arrays.Count; idx++)
    {
        var param_on_devs = param_arrays[idx];
        if (param_on_devs.Length == 0 || param_on_devs[0] == null)
        {
            continue;
        }

        var name = param_names[idx];
        kvstore.Init(name, arg_params[name]);

        if (update_on_kvstore)
        {
            kvstore.Pull(name, param_on_devs, -idx);
        }
    }
}
/// <summary>
/// Creates a cached op for <paramref name="sym"/> through <c>NativeMethods.MXCreateCachedOpEx</c>.
/// The keys of <paramref name="flags"/> and the handles of its values are passed to the native call.
/// </summary>
/// <param name="sym">Symbol to cache.</param>
/// <param name="flags">Named entries forwarded to the native call.</param>
public CachedOp(Symbol sym, NDArrayDict flags)
{
    handle = IntPtr.Zero;

    var symbolHandle = sym.GetHandle();
    var flagCount = flags.Count;
    var flagKeys = flags.Keys.ToArray();
    var flagHandles = MxUtil.GetNDArrayHandles(flags.Values.ToArray());

    NativeMethods.MXCreateCachedOpEx(symbolHandle, flagCount, flagKeys, flagHandles, out handle);
}