/// <summary>
/// Trains <paramref name="symbol"/> data-parallel across the devices in <paramref name="ctx"/> for every
/// epoch in the half-open range [<paramref name="begin_epoch"/>, <paramref name="end_epoch"/>), optionally
/// evaluating on <paramref name="eval_data"/> after each epoch.
/// </summary>
/// <param name="symbol">Network symbol; handed to the executor manager and to the epoch-end callback.</param>
/// <param name="ctx">Devices to train on; its length is passed to <c>UpdateParams</c> as the device count.</param>
/// <param name="arg_names">Argument names forwarded to the executor manager.</param>
/// <param name="param_names">Parameter names forwarded to the executor manager.</param>
/// <param name="aux_names">Auxiliary-state names forwarded to the executor manager.</param>
/// <param name="arg_params">Initial argument parameters; trained values are copied back into it after an
/// epoch when <paramref name="epoch_end_callback"/> is set, and always after the final epoch.</param>
/// <param name="aux_params">Initial auxiliary states; copied back under the same conditions as <paramref name="arg_params"/>.</param>
/// <param name="begin_epoch">First epoch index (inclusive).</param>
/// <param name="end_epoch">Last epoch index (exclusive).</param>
/// <param name="epoch_size">Number of batches per epoch. When null, an epoch is one full pass over
/// <paramref name="train_data"/>; otherwise the iterator is wrapped around as needed until this many
/// batches have been processed.</param>
/// <param name="optimizer">Optimizer applied either on the kvstore or through a local updater.</param>
/// <param name="kvstore">Key-value store; when null, kvstore initialization is skipped and null is passed
/// to <c>UpdateParams</c>. Must be non-null when <paramref name="update_on_kvstore"/> is true.</param>
/// <param name="update_on_kvstore">When true the optimizer is installed on the kvstore and updates run there;
/// otherwise a local updater is created from <paramref name="optimizer"/>.</param>
/// <param name="train_data">Training data iterator.</param>
/// <param name="eval_data">Optional evaluation data iterator, run once after each epoch.</param>
/// <param name="eval_metric">Metric updated on every training and evaluation batch; required.</param>
/// <param name="epoch_end_callback">Optional callback invoked after each epoch with the current parameters.</param>
/// <param name="batch_end_callback">Optional callback invoked after each training batch with the number of
/// batches processed so far in the epoch.</param>
/// <param name="work_load_list">Optional per-device work load, forwarded to the executor manager.</param>
/// <param name="monitor">Optional monitor installed on the executors and ticked around each training batch.</param>
/// <param name="eval_end_callback">Optional callback invoked after each evaluation pass.</param>
/// <param name="eval_batch_end_callback">Optional callback invoked after each evaluation batch with the
/// zero-based evaluation batch index.</param>
/// <param name="sym_gen">Optional symbol generator forwarded to the executor manager.</param>
/// <exception cref="ArgumentNullException"><paramref name="eval_metric"/> is null.</exception>
internal static void TrainMultiDevice(Symbol symbol, Context[] ctx, string[] arg_names, string[] param_names, string[] aux_names,
    NDArrayDict arg_params, NDArrayDict aux_params, int begin_epoch, int end_epoch, int? epoch_size,
    Optimizer optimizer, KVStore kvstore, bool update_on_kvstore, DataIter train_data,
    DataIter eval_data = null, EvalMetric eval_metric = null, IEpochEndCallback epoch_end_callback = null,
    IBatchEndCallback batch_end_callback = null, int[] work_load_list = null, Monitor monitor = null,
    IEvalEndCallback eval_end_callback = null, IEvalBatchEndCallback eval_batch_end_callback = null,
    Func<int, Symbol> sym_gen = null)
{
    // The metric is reset and updated unconditionally below; fail fast instead of an NRE mid-setup.
    if (eval_metric == null)
    {
        throw new ArgumentNullException(nameof(eval_metric));
    }

    var executor_manager = new DataParallelExecutorManager(symbol: symbol, ctx: ctx, train_data: train_data,
        arg_names: arg_names, param_names: param_names, aux_names: aux_names,
        work_load_list: work_load_list, sym_gen: sym_gen);

    if (monitor != null)
    {
        executor_manager.InstallMonitor(monitor);
    }

    executor_manager.SetParams(arg_params, aux_params);

    // Exactly one of these applies the optimizer: a local updater, or the kvstore itself.
    Updater updater = null;
    if (!update_on_kvstore)
    {
        updater = Optimizer.GetUpdater(optimizer);
    }
    else
    {
        kvstore.SetOptimizer(optimizer);
    }

    if (kvstore != null)
    {
        InitializeKVStore(kvstore: kvstore,
            param_arrays: new List<NDArrayList>() { executor_manager.ParamArrays },
            arg_params: arg_params,
            param_names: executor_manager.param_names,
            update_on_kvstore: update_on_kvstore);
    }

    train_data.Reset();
    for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
    {
        var stopwatch = System.Diagnostics.Stopwatch.StartNew();
        eval_metric.Reset();
        int nbatch = 0;

        // One epoch = one full pass (epoch_size == null) or exactly epoch_size batches,
        // which may span several passes over the iterator.
        while (true)
        {
            bool do_reset = true;
            while (!train_data.End())
            {
                var data_batch = train_data.Next();
                executor_manager.LoadDataBatch(data_batch);

                if (monitor != null)
                {
                    monitor.Tic();
                }

                executor_manager.Forward(true);
                executor_manager.Backward();

                var param_arrays = new List<NDArrayList>() { executor_manager.ParamArrays };
                var grad_arrays = new List<NDArrayList>() { executor_manager.GradArrays };

                if (update_on_kvstore)
                {
                    if (kvstore.Type.Contains("nccl"))
                    {
                        UpdateParamsOnKVStoreNCCL(param_arrays, grad_arrays, kvstore, executor_manager.param_names);
                    }
                    else
                    {
                        UpdateParamsOnKVStore(param_arrays, grad_arrays, kvstore, executor_manager.param_names);
                    }
                }
                else
                {
                    UpdateParams(param_arrays, grad_arrays, updater, ctx.Length, kvstore, executor_manager.param_names);
                }

                if (monitor != null)
                {
                    monitor.TocPrint();
                }

                executor_manager.UpdateMetric(eval_metric, data_batch.Label);
                nbatch++;

                if (batch_end_callback != null)
                {
                    MultipleCallbacks(batch_end_callback, epoch, nbatch, eval_metric);
                }

                // Quota reached before the iterator ran out: leave it un-reset so the
                // next epoch continues from the current position.
                if (epoch_size.HasValue && nbatch >= epoch_size.Value)
                {
                    do_reset = false;
                    break;
                }
            }

            if (do_reset)
            {
                Logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                train_data.Reset();
            }

            // Epoch is complete after a single pass when no size is fixed, or once the quota is met;
            // otherwise keep consuming the freshly reset iterator.
            if (!epoch_size.HasValue || nbatch >= epoch_size.Value)
            {
                break;
            }
        }

        Logger.Info($"Epoch[{epoch}] Time cost={stopwatch.Elapsed.TotalSeconds}");

        // Copy trained values back only when an epoch-end callback will receive them or this is the last epoch.
        if (epoch_end_callback != null || epoch + 1 == end_epoch)
        {
            executor_manager.CopyTo(arg_params, aux_params);
        }

        if (epoch_end_callback != null)
        {
            MultipleCallbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params);
        }

        if (eval_data != null)
        {
            eval_metric.Reset();
            eval_data.Reset();
            int total_num_batch = 0;
            while (!eval_data.End())
            {
                var eval_batch = eval_data.Next();
                executor_manager.LoadDataBatch(eval_batch);
                executor_manager.Forward();
                executor_manager.UpdateMetric(eval_metric, eval_batch.Label);

                // total_num_batch is still the zero-based index of the batch just processed.
                if (eval_batch_end_callback != null)
                {
                    MultipleCallbacks(eval_batch_end_callback, epoch, total_num_batch, eval_metric);
                }

                total_num_batch++;
            }

            if (eval_end_callback != null)
            {
                MultipleCallbacks(eval_end_callback, epoch, eval_metric);
            }
        }
    }
}