public IEnumerable<(NDArrayList, int, DataBatch)> IterPredict(DataIter eval_data, int? num_batch = null, bool reset = true, int epoch = 0, Func<DataBatch, NDArrayDict> sparse_row_id_fn = null)
{
    // Both bind and parameter initialization are required before prediction.
    if (!Binded || !ParamsInitialized)
        throw new Exception("Module not bound or parameters not initialized");

    if (reset)
        eval_data.Reset();

    // Iterate until the data iterator is exhausted (the original looped on End() instead of !End()).
    while (!eval_data.End())
    {
        if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
            break;

        var eval_batch = eval_data.Next();
        Prepare(eval_batch, sparse_row_id_fn);
        Forward(eval_batch, false);

        // Trim the padded rows the iterator appends to fill the last mini-batch.
        var pad = eval_batch.Pad.Value;
        var outputs = new NDArrayList();
        foreach (var list in GetOutputs())
        {
            foreach (var @out in list)
                outputs.Add(@out[$"0:{@out.Shape[0] - pad}"]);
        }

        yield return (outputs.ToArray(), eval_data.Cursor, eval_batch);
    }
}
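// Usage sketch for IterPredict (illustrative only): 'module' and 'valIter' are
// hypothetical names for a bound, initialized Module and a validation DataIter.
// Each yielded tuple carries the pad-trimmed outputs, the iterator cursor, and the raw batch.
public static void IterPredictExample(Module module, DataIter valIter)
{
    foreach (var (outputs, cursor, batch) in module.IterPredict(valIter, num_batch: 10))
        Console.WriteLine($"Batch {cursor}: {outputs.Length} output(s)");
}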
internal static void TrainMultiDevice(Symbol symbol, Context[] ctx, string[] arg_names, string[] param_names, string[] aux_names,
    NDArrayDict arg_params, NDArrayDict aux_params, int begin_epoch, int end_epoch, int? epoch_size, Optimizer optimizer,
    KVStore kvstore, bool update_on_kvstore, DataIter train_data, DataIter eval_data = null, EvalMetric eval_metric = null,
    IEpochEndCallback epoch_end_callback = null, IBatchEndCallback batch_end_callback = null, int[] work_load_list = null,
    Monitor monitor = null, IEvalEndCallback eval_end_callback = null, IEvalBatchEndCallback eval_batch_end_callback = null,
    Func<int, Symbol> sym_gen = null)
{
    var executor_manager = new DataParallelExecutorManager(symbol: symbol,
                                                           ctx: ctx,
                                                           train_data: train_data,
                                                           arg_names: arg_names,
                                                           param_names: param_names,
                                                           aux_names: aux_names,
                                                           work_load_list: work_load_list,
                                                           sym_gen: sym_gen);
    if (monitor != null)
        executor_manager.InstallMonitor(monitor);

    executor_manager.SetParams(arg_params, aux_params);

    // With update_on_kvstore the KVStore applies the optimizer; otherwise updates run through a local updater.
    Updater updater = null;
    if (!update_on_kvstore)
        updater = Optimizer.GetUpdater(optimizer);
    else
        kvstore.SetOptimizer(optimizer);

    if (kvstore != null)
    {
        InitializeKVStore(kvstore: kvstore,
                          param_arrays: new List<NDArrayList>() { executor_manager.ParamArrays },
                          arg_params: arg_params,
                          param_names: executor_manager.param_names,
                          update_on_kvstore: update_on_kvstore);
    }

    train_data.Reset();
    for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
    {
        var tic = DateTime.Now;
        eval_metric.Reset();
        int nbatch = 0;
        while (true)
        {
            bool do_reset = true;
            while (!train_data.End())
            {
                var data_batch = train_data.Next();
                executor_manager.LoadDataBatch(data_batch);

                if (monitor != null)
                    monitor.Tic();

                executor_manager.Forward(true);
                executor_manager.Backward();

                if (update_on_kvstore)
                {
                    if (kvstore.Type.Contains("nccl"))
                        UpdateParamsOnKVStoreNCCL(new List<NDArrayList>() { executor_manager.ParamArrays },
                                                  new List<NDArrayList>() { executor_manager.GradArrays },
                                                  kvstore, executor_manager.param_names);
                    else
                        UpdateParamsOnKVStore(new List<NDArrayList>() { executor_manager.ParamArrays },
                                              new List<NDArrayList>() { executor_manager.GradArrays },
                                              kvstore, executor_manager.param_names);
                }
                else
                {
                    UpdateParams(new List<NDArrayList>() { executor_manager.ParamArrays },
                                 new List<NDArrayList>() { executor_manager.GradArrays },
                                 updater, ctx.Length, kvstore, executor_manager.param_names);
                }

                if (monitor != null)
                    monitor.TocPrint();

                executor_manager.UpdateMetric(eval_metric, data_batch.Label);

                nbatch++;
                if (batch_end_callback != null)
                    MultipleCallbacks(batch_end_callback, epoch, nbatch, eval_metric);

                // An explicit epoch_size ends the epoch early without resetting the iterator.
                if (epoch_size.HasValue && nbatch >= epoch_size.Value)
                {
                    do_reset = false;
                    break;
                }
            }

            if (do_reset)
            {
                Logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                train_data.Reset();
            }

            // Stop unless an epoch_size is set and has not been reached yet
            // (the original broke out of the loop on both branches).
            if (!epoch_size.HasValue || nbatch >= epoch_size.Value)
                break;
        }

        var toc = DateTime.Now;
        Logger.Info($"Epoch[{epoch}] Time cost={(toc - tic).TotalSeconds}");

        if (epoch_end_callback != null || epoch + 1 == end_epoch)
            executor_manager.CopyTo(arg_params, aux_params);

        MultipleCallbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params);

        if (eval_data != null)
        {
            eval_metric.Reset();
            eval_data.Reset();
            int total_num_batch = 0;
            int i = 0;
            while (!eval_data.End())
            {
                var eval_batch = eval_data.Next();
                executor_manager.LoadDataBatch(eval_batch);
                executor_manager.Forward();
                executor_manager.UpdateMetric(eval_metric, eval_batch.Label);
                if (eval_batch_end_callback != null)
                    MultipleCallbacks(eval_batch_end_callback, epoch, i, eval_metric);

                i++; // batch index for the callback (was never incremented in the original)
                total_num_batch++;
            }

            if (eval_end_callback != null)
                MultipleCallbacks(eval_end_callback, epoch, eval_metric);
        }
    }
}
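// Minimal call sketch for TrainMultiDevice (all names below are illustrative
// assumptions, not a fixed API: 'net' is a constructed Symbol, the name/param
// arrays are already derived, 'kv' is a created KVStore, and the iterators exist):
//
//   TrainMultiDevice(net, new[] { Context.Gpu(0), Context.Gpu(1) },
//                    arg_names, param_names, aux_names, arg_params, aux_params,
//                    begin_epoch: 0, end_epoch: 10, epoch_size: null,
//                    optimizer: new SGD(), kvstore: kv, update_on_kvstore: true,
//                    train_data: trainIter, eval_data: valIter, eval_metric: new Accuracy());
//
// With update_on_kvstore=true the KVStore applies the optimizer to the aggregated
// gradients; otherwise each device's update goes through the local Updater.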
public void Fit(DataIter train_data, DataIter eval_data = null, string eval_metric = "acc",
    IEpochEndCallback[] epoch_end_callback = null, IBatchEndCallback[] batch_end_callback = null,
    string kvstore = "local", Optimizer optimizer = null, Dictionary<string, object> optimizer_params = null,
    IScoreEndCallback[] eval_end_callback = null, IBatchEndCallback[] eval_batch_end_callback = null,
    Initializer initializer = null, NDArrayDict arg_params = null, NDArrayDict aux_params = null,
    bool allow_missing = false, bool force_rebind = false, bool force_init = false, int begin_epoch = 0,
    int? num_epoch = null, EvalMetric validation_metric = null, Monitor monitor = null,
    Func<DataBatch, NDArrayDict> sparse_row_id_fn = null)
{
    object val;
    object name;
    if (optimizer == null)
        optimizer = new SGD();

    Debug.Assert(num_epoch != null, "please specify number of epochs");

    this.Bind(data_shapes: train_data.ProvideData, label_shapes: train_data.ProvideLabel, for_training: true, force_rebind: force_rebind);

    if (monitor != null)
        this.InstallMonitor(monitor);

    this.InitParams(initializer: initializer, arg_params: arg_params, aux_params: aux_params, allow_missing: allow_missing, force_init: force_init);
    this.InitOptimizer(kvstore: kvstore, optimizer: optimizer, optimizer_params: optimizer_params);

    // eval_metric arrives as a string; create the EvalMetric once and fall back to it
    // for validation (the original assigned the string to the EvalMetric variable).
    var eval_metric_func = EvalMetric.Create(eval_metric, null);
    if (validation_metric == null)
        validation_metric = eval_metric_func;

    //###############################################################################
    // training loop
    //###############################################################################
    foreach (var epoch in Enumerable.Range(begin_epoch, num_epoch.Value - begin_epoch))
    {
        var tic = DateTime.Now;
        eval_metric_func.Reset();
        var nbatch = 0;
        var data_iter = train_data;
        var end_of_batch = false;
        var next_data_batch = data_iter.Next();
        Dictionary<string, float> eval_name_vals = new Dictionary<string, float>();
        while (!end_of_batch)
        {
            var data_batch = next_data_batch;
            if (monitor != null)
                monitor.Tic();

            this.ForwardBackward(data_batch);
            this.Update();
            UpdateMetric(eval_metric_func, data_batch.Label);

            try
            {
                // pre-fetch the next batch
                next_data_batch = data_iter.Next();
                this.Prepare(next_data_batch, sparse_row_id_fn: sparse_row_id_fn);
            }
            catch (StopIteration)
            {
                end_of_batch = true;
            }

            if (monitor != null)
                monitor.TocPrint();

            if (end_of_batch)
                eval_name_vals = eval_metric_func.GetGlobalNameValue();

            if (batch_end_callback != null)
            {
                foreach (var callback in batch_end_callback)
                    callback.Invoke(epoch: epoch, nbatch: nbatch, eval_metric: eval_metric_func);
            }

            nbatch += 1;
        }

        // one epoch of training is finished
        foreach (var item in eval_name_vals)
        {
            name = item.Key;
            val = item.Value;
            Logger.Info($"Epoch[{epoch}] Train-{name}={val}");
        }

        var toc = DateTime.Now;
        Logger.Info($"Epoch[{epoch}] Time cost={(toc - tic).TotalSeconds}");

        // sync aux params across devices
        (arg_params, aux_params) = this.GetParams();
        this.SetParams(arg_params, aux_params);

        if (epoch_end_callback != null)
        {
            foreach (var callback in epoch_end_callback)
                callback.Invoke(epoch, this.Symbol, arg_params, aux_params);
        }

        //----------------------------------------
        // evaluation on validation set
        if (eval_data != null)
        {
            var res = this.Score(eval_data, validation_metric, score_end_callback: eval_end_callback, batch_end_callback: eval_batch_end_callback, epoch: epoch);
            //TODO: pull this into default
            foreach (var item in res)
            {
                name = item.Key;
                val = item.Value;
                Logger.Info($"Epoch[{epoch}] Validation-{name}={val}");
            }
        }

        // end of one epoch: reset the data iterator for the next epoch
        train_data.Reset();
    }
}
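// Usage sketch for Fit (illustrative: 'mod', 'trainIter', and 'valIter' are
// hypothetical names for a Module and prepared data iterators).
public static void FitExample(Module mod, DataIter trainIter, DataIter valIter)
{
    mod.Fit(trainIter,
            eval_data: valIter,
            eval_metric: "acc",    // resolved via EvalMetric.Create
            optimizer: new SGD(),  // the default when omitted
            num_epoch: 10);        // required: Fit asserts num_epoch != null
}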
public List<NDArrayList> Predict(DataIter eval_data, int? num_batch = null, bool merge_batches = true, bool reset = true, bool always_output_list = true, Func<DataBatch, NDArrayDict> sparse_row_id_fn = null)
{
    if (!Binded || !ParamsInitialized)
        throw new Exception("Module not bound or parameters not initialized");

    if (reset)
        eval_data.Reset();

    var output_list = new List<NDArrayList>();
    var output_list2 = new NDArrayList();
    // Iterate until the data iterator is exhausted (the original looped on End() instead of !End()).
    while (!eval_data.End())
    {
        if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
            break;

        var eval_batch = eval_data.Next();
        Prepare(eval_batch, sparse_row_id_fn);
        Forward(eval_batch, false);

        // Trim the padded rows and copy, since the executor reuses its output buffers.
        var pad = eval_batch.Pad.Value;
        var outputs = new NDArrayList();
        foreach (var list in GetOutputs())
        {
            foreach (var @out in list)
                outputs.Add(@out[$"0:{@out.Shape[0] - pad}"].Copy());
        }

        output_list.Add(outputs.ToArray());
    }

    if (output_list.Count == 0)
        return output_list;

    if (merge_batches)
    {
        var num_outputs = output_list[0].Length;
        foreach (var @out in output_list)
        {
            if (@out.Length != num_outputs)
                throw new Exception("Cannot merge batches, as num of outputs is not the same " +
                                    "in mini-batches. Maybe bucketing is used?");
        }

        // Concatenate the i-th output across all mini-batches (the original
        // concatenated the outputs within each batch, which breaks for multi-output nets).
        for (int i = 0; i < num_outputs; i++)
        {
            var ith = new NDArrayList();
            foreach (var @out in output_list)
                ith.Add(@out[i]);
            output_list2.Add(nd.Concat(ith));
        }

        return new List<NDArrayList> { output_list2.ToArray() };
    }

    return output_list;
}
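// Usage sketch for Predict (illustrative names): with merge_batches=true the i-th
// output of every mini-batch is concatenated, so a single-output network yields one
// array covering the whole dataset (minus the last batch's padding).
public static void PredictExample(Module module, DataIter testIter)
{
    var merged = module.Predict(testIter, merge_batches: true);
    Console.WriteLine($"Merged prediction shape: {merged[0][0].Shape}");
}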
public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
{
    string labelName = "label";
    var label = Symbol.Variable(labelName);

    // Data shape is (batchSize, InputShape...); the label is a flat batch vector.
    List<uint> inputShape = new List<uint>();
    inputShape.Add(batchSize);
    inputShape.AddRange(InputShape);
    args["X"] = new NDArray(new Shape(inputShape.ToArray()));
    args[labelName] = new NDArray(new Shape(batchSize));
    Model.InferArgsMap(mx.Device, args, args);

    // Parameters use a registered initializer when one exists, otherwise Glorot uniform.
    var defaultInitializer = new Initializers.GlorotUniform();
    foreach (var arg in args)
    {
        if (ParamInitializers.ContainsKey(arg.Key))
            ParamInitializers[arg.Key].Generate(arg.Value);
        else
            defaultInitializer.Generate(arg.Value);
    }

    using (var exec = Model.SimpleBind(mx.Device, args))
    {
        var argNames = Model.ListArguments();

        // Start training
        var sw = new Stopwatch();
        for (var iter = 1; iter <= epochs; iter++)
        {
            uint samples = 0;
            train.BatchSize = batchSize;
            train.Reset();
            Metric.Reset();
            TrainMetric.Reset();
            sw.Restart();

            while (train.IterNext())
            {
                samples += batchSize;
                var dataBatch = train.Next();

                // Set data and label
                dataBatch.Data[0].CopyTo(args["X"]);
                dataBatch.Label[0].CopyTo(args[labelName]);

                // Compute gradients
                exec.Forward(true);
                exec.Backward();
                TrainMetric.Update(args[labelName], exec.Output);

                // Update parameters, skipping the data and label slots
                for (var i = 0; i < argNames.Count; ++i)
                {
                    if (argNames[i] == "X" || argNames[i] == labelName)
                        continue;

                    ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i], null);
                }
            }

            sw.Stop();

            if (validation != null)
            {
                validation.BatchSize = batchSize;
                validation.Reset();
                while (validation.IterNext())
                {
                    var dataBatch = validation.Next();
                    dataBatch.Data[0].CopyTo(args["X"]);
                    dataBatch.Label[0].CopyTo(args[labelName]);
                    NDArray.WaitAll();

                    // Forward pass is enough as no gradient is needed when evaluating
                    exec.Forward(false);
                    Metric.Update(args[labelName], exec.Output);
                }
            }

            var duration = sw.ElapsedMilliseconds == 0 ? 1 : sw.ElapsedMilliseconds;
            if (validation == null)
                Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec Train_Metric: {TrainMetric.Get()}");
            else
                Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec, Train_Metric: {TrainMetric.Get()}, Val_Metric: {Metric.Get()}");
        }
    }

    //MXNet.MXNotifyShutdown();
}
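// Usage sketch (illustrative: 'model' stands for whatever type declares this Fit
// overload, and the iterators are assumed to be prepared elsewhere):
//
//   model.Fit(trainIter, epochs: 10, batchSize: 32, validation: valIter);
//
// Each epoch logs throughput plus TrainMetric, and Metric for validation when given.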
public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
{
    var args = new SortedDictionary<string, NDArray>();
    string labelName = "label";
    var label = Symbol.Variable(labelName);
    args["X"] = new NDArray(new Shape(batchSize, (uint)InputShape[0]));
    args[labelName] = new NDArray(new Shape(batchSize, (uint)OutputShape.Size));
    CompiledModel.InferArgsMap(GlobalParam.Device, args, args);

    var initializer = new SiaDNN.Initializers.GlorotUniform();
    foreach (var arg in args)
        initializer.Operator(arg.Key, arg.Value);

    // Gradients are summed over the mini-batch; rescaling turns them into a mean.
    ModelOptimizer.SetParam("rescale_grad", 1.0 / batchSize);

    using (var exec = CompiledModel.SimpleBind(GlobalParam.Device, args))
    {
        var argNames = CompiledModel.ListArguments();

        // Start training
        var sw = new Stopwatch();
        for (var iter = 0; iter < epochs; ++iter)
        {
            uint samples = 0;
            train.BatchSize = batchSize;
            train.Reset();
            Metric.Reset(); // clear accumulated state from the previous epoch
            sw.Restart();

            while (train.Next())
            {
                samples += batchSize;
                var dataBatch = train.GetDataBatch();

                // Set data and label
                dataBatch.Data.CopyTo(args["X"]);
                dataBatch.Label.CopyTo(args[labelName]);

                // Compute gradients
                exec.Forward(true);
                exec.Backward();

                // Update parameters, skipping the data and label slots
                for (var i = 0; i < argNames.Count; ++i)
                {
                    if (argNames[i] == "X" || argNames[i] == labelName)
                        continue;

                    ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i]);
                }

                Metric.Update(dataBatch.Label, exec.Outputs[0]);
            }

            sw.Stop();

            // Capture the train metric before the same object is reused for validation
            // (the original logged one value for both).
            var trainMetric = Metric.Get();
            if (validation != null)
            {
                Metric.Reset();
                validation.BatchSize = batchSize;
                validation.Reset();
                while (validation.Next())
                {
                    var dataBatch = validation.GetDataBatch();
                    dataBatch.Data.CopyTo(args["X"]);
                    dataBatch.Label.CopyTo(args[labelName]);

                    // Forward pass is enough as no gradient is needed when evaluating
                    exec.Forward(false);
                    Metric.Update(dataBatch.Label, exec.Outputs[0]);
                }
            }

            var duration = sw.ElapsedMilliseconds / 1000.0;
            if (validation == null)
                Logging.LG($"Epoch: {iter} {samples / duration} samples/sec Train_Metric: {trainMetric}");
            else
                Logging.LG($"Epoch: {iter} {samples / duration} samples/sec, Train_Metric: {trainMetric}, Val_Metric: {Metric.Get()}");
        }
    }

    MXNet.MXNotifyShutdown();
}
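// Note on rescale_grad (a sketch of the reasoning, not additional API): the backward
// pass produces gradients summed over the mini-batch, so SetParam("rescale_grad",
// 1.0 / batchSize) converts the sum into a per-sample mean before each optimizer
// update, keeping the effective learning rate independent of batch size. Example:
//
//   ModelOptimizer.SetParam("rescale_grad", 1.0 / 64);  // for batchSize = 64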