Example #1
0
        /// <summary>
        /// Trains <paramref name="symbol"/> data-parallel across the devices in <paramref name="ctx"/> for epochs
        /// <paramref name="begin_epoch"/> (inclusive) to <paramref name="end_epoch"/> (exclusive), optionally
        /// evaluating on <paramref name="eval_data"/> after every epoch.
        /// </summary>
        /// <param name="symbol">Network to train; also passed to the epoch-end callbacks.</param>
        /// <param name="ctx">Devices to train on; its length is passed to the local updater as the device count.</param>
        /// <param name="arg_params">Initial argument parameters. Overwritten with the trained values after every epoch
        /// when <paramref name="epoch_end_callback"/> is set, and always after the last epoch.</param>
        /// <param name="aux_params">Initial auxiliary states; copied back together with <paramref name="arg_params"/>.</param>
        /// <param name="epoch_size">Number of batches per epoch. When null, an epoch is exactly one pass over
        /// <paramref name="train_data"/>. Otherwise an epoch may span several passes (the iterator is reset whenever
        /// it is exhausted) or stop part-way through one (the iterator position carries over to the next epoch).</param>
        /// <param name="optimizer">Installed on <paramref name="kvstore"/> when <paramref name="update_on_kvstore"/> is
        /// true, otherwise wrapped in a local updater.</param>
        /// <param name="kvstore">Optional key-value store used for parameter synchronisation. Required when
        /// <paramref name="update_on_kvstore"/> is true.</param>
        /// <param name="update_on_kvstore">Whether parameter updates are performed by <paramref name="kvstore"/>.</param>
        /// <param name="eval_metric">Metric updated on every training and evaluation batch. Required when at least one
        /// epoch is run.</param>
        /// <param name="sym_gen">Forwarded to <c>DataParallelExecutorManager</c> (presumably for bucketing — confirm).</param>
        /// <exception cref="ArgumentNullException"><paramref name="train_data"/> is null, <paramref name="eval_metric"/>
        /// is null while epochs remain, or <paramref name="kvstore"/> is null while <paramref name="update_on_kvstore"/>
        /// is true.</exception>
        /// <exception cref="InvalidOperationException"><paramref name="epoch_size"/> is set but
        /// <paramref name="train_data"/> yields no batches after a reset.</exception>
        internal static void TrainMultiDevice(Symbol symbol, Context[] ctx, string[] arg_names, string[] param_names, string[] aux_names, NDArrayDict arg_params, NDArrayDict aux_params, int begin_epoch, int end_epoch, int?epoch_size, Optimizer optimizer, KVStore kvstore, bool update_on_kvstore, DataIter train_data, DataIter eval_data = null, EvalMetric eval_metric = null, IEpochEndCallback epoch_end_callback = null, IBatchEndCallback batch_end_callback = null, int[] work_load_list = null, Monitor monitor = null, IEvalEndCallback eval_end_callback = null, IEvalBatchEndCallback eval_batch_end_callback = null, Func <int, Symbol> sym_gen = null)
        {
            // Fail fast with a descriptive exception instead of a NullReferenceException deep inside training.
            if (train_data == null)
            {
                throw new ArgumentNullException(nameof(train_data));
            }

            if (eval_metric == null && begin_epoch < end_epoch)
            {
                throw new ArgumentNullException(nameof(eval_metric));
            }

            if (update_on_kvstore && kvstore == null)
            {
                throw new ArgumentNullException(nameof(kvstore), "A KVStore is required when update_on_kvstore is true.");
            }

            var executor_manager = new DataParallelExecutorManager(symbol: symbol,
                                                                   ctx: ctx,
                                                                   train_data: train_data,
                                                                   arg_names: arg_names,
                                                                   param_names: param_names,
                                                                   aux_names: aux_names,
                                                                   work_load_list: work_load_list,
                                                                   sym_gen: sym_gen);

            if (monitor != null)
            {
                executor_manager.InstallMonitor(monitor);
            }

            executor_manager.SetParams(arg_params, aux_params);

            // Either the kvstore runs the optimizer itself, or updates are applied locally through an updater.
            Updater updater = null;
            if (update_on_kvstore)
            {
                kvstore.SetOptimizer(optimizer);
            }
            else
            {
                updater = Optimizer.GetUpdater(optimizer);
            }

            if (kvstore != null)
            {
                InitializeKVStore(kvstore: kvstore,
                                  param_arrays: new List<NDArrayList>() { executor_manager.ParamArrays },
                                  arg_params: arg_params,
                                  param_names: executor_manager.param_names,
                                  update_on_kvstore: update_on_kvstore);
            }

            train_data.Reset();
            for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
            {
                // Stopwatch is monotonic, unlike DateTime.Now which jumps on clock/DST changes.
                var epoch_timer = System.Diagnostics.Stopwatch.StartNew();
                eval_metric.Reset();
                int nbatch = 0;
                int pass = 0;
                while (true)
                {
                    bool do_reset = true;
                    int batches_this_pass = 0;
                    while (!train_data.End())
                    {
                        var data_batch = train_data.Next();
                        executor_manager.LoadDataBatch(data_batch);
                        if (monitor != null)
                        {
                            monitor.Tic();
                        }

                        executor_manager.Forward(true);
                        executor_manager.Backward();
                        TrainMultiDeviceUpdateParams(executor_manager, updater, kvstore, update_on_kvstore, ctx.Length);

                        if (monitor != null)
                        {
                            monitor.TocPrint();
                        }

                        executor_manager.UpdateMetric(eval_metric, data_batch.Label);
                        nbatch++;
                        batches_this_pass++;
                        if (batch_end_callback != null)
                        {
                            MultipleCallbacks(batch_end_callback, epoch, nbatch, eval_metric);
                        }

                        if (epoch_size.HasValue && nbatch >= epoch_size.Value)
                        {
                            // Epoch ends mid-pass: keep the iterator position so the next epoch continues from here.
                            do_reset = false;
                            break;
                        }
                    }

                    if (do_reset)
                    {
                        Logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        train_data.Reset();
                    }

                    // Without epoch_size an epoch is exactly one pass; with it, keep going (across resets)
                    // until epoch_size batches have been processed.
                    if (!epoch_size.HasValue || nbatch >= epoch_size.Value)
                    {
                        break;
                    }

                    // An empty first pass is legitimate (the previous epoch may have stopped exactly at the end of
                    // the data), but an empty pass right after a reset means epoch_size can never be reached.
                    if (batches_this_pass == 0 && pass > 0)
                    {
                        throw new InvalidOperationException(
                            $"Epoch[{epoch}] training data iterator produced no batches after a reset; cannot reach epoch_size={epoch_size.Value}.");
                    }

                    pass++;
                }

                epoch_timer.Stop();
                Logger.Info($"Epoch[{epoch}] Time cost={epoch_timer.Elapsed.TotalSeconds}");

                // Copying parameters off the devices is only needed when a callback will see them or training ends.
                if (epoch_end_callback != null || epoch + 1 == end_epoch)
                {
                    executor_manager.CopyTo(arg_params, aux_params);
                }

                MultipleCallbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params);

                if (eval_data != null)
                {
                    TrainMultiDeviceEvaluate(executor_manager, eval_data, eval_metric, epoch, eval_batch_end_callback, eval_end_callback);
                }
            }
        }

        /// <summary>
        /// Applies the gradients of the batch just back-propagated by <paramref name="executor_manager"/>:
        /// through <paramref name="kvstore"/> (NCCL path when its type contains "nccl") when
        /// <paramref name="update_on_kvstore"/> is true, otherwise locally with <paramref name="updater"/>.
        /// </summary>
        private static void TrainMultiDeviceUpdateParams(DataParallelExecutorManager executor_manager, Updater updater, KVStore kvstore, bool update_on_kvstore, int num_device)
        {
            var param_arrays = new List<NDArrayList>() { executor_manager.ParamArrays };
            var grad_arrays = new List<NDArrayList>() { executor_manager.GradArrays };

            if (!update_on_kvstore)
            {
                UpdateParams(param_arrays, grad_arrays, updater, num_device, kvstore, executor_manager.param_names);
            }
            else if (kvstore.Type.Contains("nccl"))
            {
                UpdateParamsOnKVStoreNCCL(param_arrays, grad_arrays, kvstore, executor_manager.param_names);
            }
            else
            {
                UpdateParamsOnKVStore(param_arrays, grad_arrays, kvstore, executor_manager.param_names);
            }
        }

        /// <summary>
        /// Runs one forward-only pass over <paramref name="eval_data"/>. It resets and then accumulates
        /// <paramref name="eval_metric"/> and invokes the evaluation callbacks. The iterator is rewound before and
        /// after the pass.
        /// </summary>
        private static void TrainMultiDeviceEvaluate(DataParallelExecutorManager executor_manager, DataIter eval_data, EvalMetric eval_metric, int epoch, IEvalBatchEndCallback eval_batch_end_callback, IEvalEndCallback eval_end_callback)
        {
            eval_metric.Reset();
            eval_data.Reset();
            int batch_index = 0;
            while (!eval_data.End())
            {
                var eval_batch = eval_data.Next();
                executor_manager.LoadDataBatch(eval_batch);
                executor_manager.Forward();
                executor_manager.UpdateMetric(eval_metric, eval_batch.Label);
                if (eval_batch_end_callback != null)
                {
                    // Zero-based index of the batch just evaluated.
                    MultipleCallbacks(eval_batch_end_callback, epoch, batch_index, eval_metric);
                }

                batch_index++;
            }

            if (eval_end_callback != null)
            {
                MultipleCallbacks(eval_end_callback, epoch, eval_metric);
            }

            // Leave the evaluation iterator rewound so it is reusable after training.
            eval_data.Reset();
        }