public Dictionary<string, float> Score(DataIter eval_data, EvalMetric eval_metric, int? num_batch = null,
    IBatchEndCallback[] batch_end_callback = null, IScoreEndCallback[] score_end_callback = null,
    bool reset = true, int epoch = 0, Func<DataBatch, NDArrayDict> sparse_row_id_fn = null)
{
    if (!Binded || !ParamsInitialized)
        throw new Exception("Module must be bound and parameters initialized");

    // Honor the reset flag so repeated calls start from the beginning of the iterator.
    if (reset)
        eval_data.Reset();

    eval_metric.Reset();
    var actual_num_batch = 0;

    // Iterate until the iterator is exhausted, or num_batch batches have been scored.
    while (!eval_data.End())
    {
        if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
            break;

        var eval_batch = eval_data.Next();
        Prepare(eval_batch, sparse_row_id_fn);
        Forward(eval_batch, false);
        UpdateMetric(eval_metric, eval_batch.Label, true);

        if (batch_end_callback != null)
        {
            foreach (var callback in batch_end_callback)
                callback.Invoke(epoch, eval_data.Cursor, eval_metric);
        }

        actual_num_batch++;
    }

    if (score_end_callback != null)
    {
        foreach (var callback in score_end_callback)
            callback.Invoke(epoch, actual_num_batch, eval_metric, new FuncArgs());
    }

    return eval_metric.GetNameValue();
}
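// Usage sketch (hedged): `mod` is assumed to be a Module that has already been bound
// with initialized parameters, `valIter` any DataIter over a validation set, and
// `Accuracy` an EvalMetric from the metrics namespace — all three are assumptions,
// not defined in this file. Score folds every batch into the metric and returns its
// name/value pairs:
//
//   var metric = new Accuracy();
//   var results = mod.Score(valIter, metric, num_batch: 50);
//   foreach (var pair in results)
//       Console.WriteLine($"{pair.Key}: {pair.Value}");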
public IEnumerable<(NDArrayList, int, DataBatch)> IterPredict(DataIter eval_data, int? num_batch = null,
    bool reset = true, int epoch = 0, Func<DataBatch, NDArrayDict> sparse_row_id_fn = null)
{
    if (!Binded || !ParamsInitialized)
        throw new Exception("Module must be bound and parameters initialized");

    if (reset)
        eval_data.Reset();

    while (!eval_data.End())
    {
        if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
            break;

        var eval_batch = eval_data.Next();
        Prepare(eval_batch, sparse_row_id_fn);
        Forward(eval_batch, false);

        // Strip the padded samples appended to the last (partial) batch.
        var pad = eval_batch.Pad.Value;
        var outputs = new NDArrayList();
        foreach (var list in GetOutputs())
        {
            foreach (var @out in list)
                outputs.Add(@out[$"0:{@out.Shape[0] - pad}"]);
        }

        yield return (outputs.ToArray(), eval_data.Cursor, eval_batch);
    }
}
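// Usage sketch (hedged): `mod` and `testIter` are a bound Module and a DataIter,
// both hypothetical names. IterPredict yields one (outputs, batch index, batch)
// tuple per mini-batch, so predictions can be streamed without holding all of them
// in memory at once:
//
//   foreach (var (outputs, cursor, batch) in mod.IterPredict(testIter))
//       Console.WriteLine($"batch {cursor}: first output shape {outputs[0].Shape}");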
internal static void TrainMultiDevice(Symbol symbol, Context[] ctx, string[] arg_names, string[] param_names,
    string[] aux_names, NDArrayDict arg_params, NDArrayDict aux_params, int begin_epoch, int end_epoch,
    int? epoch_size, Optimizer optimizer, KVStore kvstore, bool update_on_kvstore, DataIter train_data,
    DataIter eval_data = null, EvalMetric eval_metric = null, IEpochEndCallback epoch_end_callback = null,
    IBatchEndCallback batch_end_callback = null, int[] work_load_list = null, Monitor monitor = null,
    IEvalEndCallback eval_end_callback = null, IEvalBatchEndCallback eval_batch_end_callback = null,
    Func<int, Symbol> sym_gen = null)
{
    var executor_manager = new DataParallelExecutorManager(symbol: symbol,
                                                           ctx: ctx,
                                                           train_data: train_data,
                                                           arg_names: arg_names,
                                                           param_names: param_names,
                                                           aux_names: aux_names,
                                                           work_load_list: work_load_list,
                                                           sym_gen: sym_gen);

    if (monitor != null)
        executor_manager.InstallMonitor(monitor);

    executor_manager.SetParams(arg_params, aux_params);

    // With update_on_kvstore the optimizer runs inside the kvstore;
    // otherwise each device applies a local updater.
    Updater updater = null;
    if (!update_on_kvstore)
        updater = Optimizer.GetUpdater(optimizer);
    else
        kvstore.SetOptimizer(optimizer);

    if (kvstore != null)
        InitializeKVStore(kvstore: kvstore,
                          param_arrays: new List<NDArrayList> { executor_manager.ParamArrays },
                          arg_params: arg_params,
                          param_names: executor_manager.param_names,
                          update_on_kvstore: update_on_kvstore);

    train_data.Reset();
    for (var epoch = begin_epoch; epoch < end_epoch; epoch++)
    {
        var tic = DateTime.Now;
        eval_metric.Reset();
        var nbatch = 0;

        while (true)
        {
            var do_reset = true;
            while (!train_data.End())
            {
                var data_batch = train_data.Next();
                executor_manager.LoadDataBatch(data_batch);

                if (monitor != null)
                    monitor.Tic();

                executor_manager.Forward(true);
                executor_manager.Backward();

                if (update_on_kvstore)
                {
                    if (kvstore.Type.Contains("nccl"))
                        UpdateParamsOnKVStoreNCCL(new List<NDArrayList> { executor_manager.ParamArrays },
                            new List<NDArrayList> { executor_manager.GradArrays },
                            kvstore, executor_manager.param_names);
                    else
                        UpdateParamsOnKVStore(new List<NDArrayList> { executor_manager.ParamArrays },
                            new List<NDArrayList> { executor_manager.GradArrays },
                            kvstore, executor_manager.param_names);
                }
                else
                {
                    UpdateParams(new List<NDArrayList> { executor_manager.ParamArrays },
                        new List<NDArrayList> { executor_manager.GradArrays },
                        updater, ctx.Length, kvstore, executor_manager.param_names);
                }

                if (monitor != null)
                    monitor.TocPrint();

                executor_manager.UpdateMetric(eval_metric, data_batch.Label);

                nbatch++;
                if (batch_end_callback != null)
                    MultipleCallbacks(batch_end_callback, epoch, nbatch, eval_metric);

                // The epoch may end early when epoch_size caps the batch count.
                if (epoch_size.HasValue && nbatch >= epoch_size.Value)
                {
                    do_reset = false;
                    break;
                }
            }

            if (do_reset)
            {
                Logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                train_data.Reset();
            }

            // The epoch is done once the iterator is exhausted or epoch_size is reached.
            if (!epoch_size.HasValue || nbatch >= epoch_size.Value)
                break;
        }

        var toc = DateTime.Now;
        Logger.Info($"Epoch[{epoch}] Time cost={(toc - tic).TotalSeconds}");

        if (epoch_end_callback != null || epoch + 1 == end_epoch)
            executor_manager.CopyTo(arg_params, aux_params);

        MultipleCallbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params);

        if (eval_data != null)
        {
            eval_metric.Reset();
            eval_data.Reset();
            var total_num_batch = 0;
            var i = 0;
            while (!eval_data.End())
            {
                var eval_batch = eval_data.Next();
                executor_manager.LoadDataBatch(eval_batch);
                executor_manager.Forward(); // inference pass
                executor_manager.UpdateMetric(eval_metric, eval_batch.Label);
                if (eval_batch_end_callback != null)
                    MultipleCallbacks(eval_batch_end_callback, epoch, i, eval_metric);

                total_num_batch++;
                i++;
            }

            if (eval_end_callback != null)
                MultipleCallbacks(eval_end_callback, epoch, eval_metric);
        }
    }
}
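// Usage sketch (hedged): this method is internal and is normally reached through a
// higher-level Fit call rather than invoked directly. Assuming the caller has already
// prepared the contexts (`ctxs`), the name/parameter arrays, an Optimizer (`opt`), a
// KVStore (`kv`) and the data iterators — all hypothetical names — a direct call
// would look like:
//
//   TrainMultiDevice(net, ctxs, arg_names, param_names, aux_names,
//                    arg_params, aux_params, begin_epoch: 0, end_epoch: 10,
//                    epoch_size: null, opt, kv, update_on_kvstore: true,
//                    trainIter, eval_data: valIter, eval_metric: metric);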
public List<NDArrayList> Predict(DataIter eval_data, int? num_batch = null, bool merge_batches = true,
    bool reset = true, bool always_output_list = true, Func<DataBatch, NDArrayDict> sparse_row_id_fn = null)
{
    if (!Binded || !ParamsInitialized)
        throw new Exception("Module must be bound and parameters initialized");

    if (reset)
        eval_data.Reset();

    var output_list = new List<NDArrayList>();
    var output_list2 = new NDArrayList();

    while (!eval_data.End())
    {
        if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
            break;

        var eval_batch = eval_data.Next();
        Prepare(eval_batch, sparse_row_id_fn);
        Forward(eval_batch, false);

        // Strip the padded samples appended to the last (partial) batch;
        // copy so the slices survive the next forward pass.
        var pad = eval_batch.Pad.Value;
        var outputs = new NDArrayList();
        foreach (var list in GetOutputs())
        {
            foreach (var @out in list)
                outputs.Add(@out[$"0:{@out.Shape[0] - pad}"].Copy());
        }

        output_list.Add(outputs.ToArray());
    }

    if (output_list.Count == 0)
        return output_list;

    if (merge_batches)
    {
        var num_outputs = output_list[0].Length;
        foreach (var @out in output_list)
        {
            if (@out.Length != num_outputs)
                throw new Exception("Cannot merge batches, as num of outputs is not the same " +
                                    "in mini-batches. Maybe bucketing is used?");
        }

        // Merge the i-th output of every mini-batch into one NDArray along the
        // batch axis, so the result holds one array per network output.
        for (var i = 0; i < num_outputs; i++)
        {
            var outs = new NDArrayList();
            foreach (var @out in output_list)
                outs.Add(@out[i]);

            output_list2.Add(nd.Concat(outs, dim: 0));
        }

        return new List<NDArrayList> { output_list2.ToArray() };
    }

    return output_list;
}
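// Usage sketch (hedged): `mod` and `testIter` are a bound Module and a DataIter,
// both hypothetical names. With merge_batches = true (the default) the result holds
// one NDArray per network output, concatenated across all mini-batches:
//
//   var merged = mod.Predict(testIter);          // List<NDArrayList> with one entry
//   var probabilities = merged[0][0];            // first (often only) output
//
//   var perBatch = mod.Predict(testIter, merge_batches: false);  // one entry per batch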