Пример #1
0
 /// <summary>
 /// Forwards metric accumulation for the given batch labels to the current executor group.
 /// </summary>
 /// <param name="metric">Metric to update.</param>
 /// <param name="labels">Labels of the current batch.</param>
 public void UpdateMetric(EvalMetric metric, List<NdArray> labels) =>
     _currExecgrp.UpdateMetric(metric, labels);
Пример #2
0
        /// <summary>
        /// Trains the module on <paramref name="train_data"/> for every epoch in
        /// [<paramref name="begin_epoch"/>, <paramref name="num_epoch"/>), optionally scoring
        /// <paramref name="eval_data"/> at the end of each epoch.
        /// </summary>
        /// <param name="train_data">Training iterator; it is reset at the end of every epoch.</param>
        /// <param name="eval_data">Optional validation iterator passed to <c>Score</c> after every epoch.</param>
        /// <param name="eval_metric">Name of the training metric, resolved with <c>EvalMetric.Create</c>.</param>
        /// <param name="epoch_end_callback">Invoked after every epoch with the symbol and the synced parameters.</param>
        /// <param name="batch_end_callback">Invoked after every training batch with the training metric being accumulated.</param>
        /// <param name="kvstore">KVStore type forwarded to <c>InitOptimizer</c>.</param>
        /// <param name="optimizer">Optimizer to use; a new <c>SGD</c> when null.</param>
        /// <param name="optimizer_params">Extra optimizer settings forwarded to <c>InitOptimizer</c>.</param>
        /// <param name="eval_end_callback">Forwarded to <c>Score</c> as its score-end callbacks.</param>
        /// <param name="eval_batch_end_callback">Forwarded to <c>Score</c> as its batch-end callbacks.</param>
        /// <param name="initializer">Forwarded to <c>InitParams</c>.</param>
        /// <param name="arg_params">Initial argument parameters forwarded to <c>InitParams</c>.</param>
        /// <param name="aux_params">Initial auxiliary parameters forwarded to <c>InitParams</c>.</param>
        /// <param name="allow_missing">Forwarded to <c>InitParams</c>.</param>
        /// <param name="force_rebind">Forwarded to <c>Bind</c>.</param>
        /// <param name="force_init">Forwarded to <c>InitParams</c>.</param>
        /// <param name="begin_epoch">First epoch index (inclusive).</param>
        /// <param name="num_epoch">End epoch index (exclusive). Required.</param>
        /// <param name="validation_metric">
        /// Metric used on <paramref name="eval_data"/>; defaults to a separate instance created from
        /// <paramref name="eval_metric"/>.
        /// </param>
        /// <param name="monitor">Optional monitor installed on the module and ticked around every batch.</param>
        /// <param name="sparse_row_id_fn">Forwarded to <c>Prepare</c> for every prefetched batch.</param>
        /// <exception cref="ArgumentNullException"><paramref name="num_epoch"/> is null.</exception>
        public void Fit(DataIter train_data, DataIter eval_data = null, string eval_metric = "acc",
                        IEpochEndCallback[] epoch_end_callback  = null, IBatchEndCallback[] batch_end_callback = null,
                        string kvstore      = "local",
                        Optimizer optimizer = null, Dictionary <string, object> optimizer_params = null,
                        IScoreEndCallback[] eval_end_callback       = null,
                        IBatchEndCallback[] eval_batch_end_callback = null, Initializer initializer = null,
                        NDArrayDict arg_params = null,
                        NDArrayDict aux_params = null, bool allow_missing = false, bool force_rebind = false,
                        bool force_init        = false, int begin_epoch   = 0, int?num_epoch = null, EvalMetric validation_metric = null,
                        Monitor monitor        = null, Func <DataBatch, NDArrayDict> sparse_row_id_fn = null)
        {
            // Validate before binding anything. Debug.Assert is compiled out of release builds,
            // which let num_epoch.Value fail only after Bind/InitParams had already run.
            if (num_epoch == null)
            {
                throw new ArgumentNullException(nameof(num_epoch), "please specify number of epochs");
            }

            if (optimizer == null)
            {
                optimizer = new SGD();
            }

            this.Bind(data_shapes: train_data.ProvideData, label_shapes: train_data.ProvideLabel, for_training: true, force_rebind: force_rebind);
            if (monitor != null)
            {
                this.InstallMonitor(monitor);
            }

            this.InitParams(initializer: initializer, arg_params: arg_params, aux_params: aux_params, allow_missing: allow_missing, force_init: force_init);
            this.InitOptimizer(kvstore: kvstore, optimizer: optimizer, optimizer_params: optimizer_params);

            // Validation uses its own metric instance, so scoring never resets or pollutes the
            // training statistics accumulated in eval_metric_func.
            if (validation_metric == null)
            {
                validation_metric = EvalMetric.Create(eval_metric, null);
            }

            var eval_metric_func = EvalMetric.Create(eval_metric, null);

            //###############################################################################
            // training loop
            //###############################################################################
            // A plain for-loop yields no epochs when num_epoch <= begin_epoch, instead of throwing
            // from Enumerable.Range with a negative count.
            for (var epoch = begin_epoch; epoch < num_epoch.Value; epoch++)
            {
                var timer = Stopwatch.StartNew();
                eval_metric_func.Reset();
                var nbatch          = 0;
                var end_of_batch    = false;
                var eval_name_vals  = new Dictionary<string, float>();

                // Batches are fetched one step ahead: the following batch is handed to Prepare()
                // right after the current one has been trained on.
                var next_data_batch = train_data.Next();
                while (!end_of_batch)
                {
                    var data_batch = next_data_batch;
                    if (monitor != null)
                    {
                        monitor.Tic();
                    }

                    this.ForwardBackward(data_batch);
                    this.Update();

                    UpdateMetric(eval_metric_func, data_batch.Label);

                    try
                    {
                        // pre fetch next batch
                        next_data_batch = train_data.Next();
                        this.Prepare(next_data_batch, sparse_row_id_fn: sparse_row_id_fn);
                    }
                    catch (StopIteration)
                    {
                        end_of_batch = true;
                    }

                    if (monitor != null)
                    {
                        monitor.TocPrint();
                    }

                    if (end_of_batch)
                    {
                        eval_name_vals = eval_metric_func.GetGlobalNameValue();
                    }

                    if (batch_end_callback != null)
                    {
                        foreach (var callback in batch_end_callback)
                        {
                            // Hand over the metric object being accumulated, not the metric name:
                            // batch-end callbacks read (and may reset) its running values.
                            callback.Invoke(epoch: epoch, nbatch: nbatch, eval_metric: eval_metric_func);
                        }
                    }

                    nbatch += 1;
                }

                // one epoch of training is finished
                foreach (var item in eval_name_vals)
                {
                    Logger.Info($"Epoch[{epoch}] Train-{item.Key}={item.Value}");
                }

                Logger.Info($"Epoch[{epoch}] Time cost={timer.Elapsed.TotalSeconds}");

                // sync aux params across devices
                (arg_params, aux_params) = this.GetParams();
                this.SetParams(arg_params, aux_params);
                if (epoch_end_callback != null)
                {
                    foreach (var callback in epoch_end_callback)
                    {
                        callback.Invoke(epoch, this.Symbol, arg_params, aux_params);
                    }
                }

                //----------------------------------------
                // evaluation on validation set
                if (eval_data != null)
                {
                    var res = this.Score(eval_data, validation_metric, score_end_callback: eval_end_callback, batch_end_callback: eval_batch_end_callback, epoch: epoch);
                    //TODO: pull this into default
                    foreach (var item in res)
                    {
                        Logger.Info($"Epoch[{epoch}] Validation-{item.Key}={item.Value}");
                    }
                }

                // end of 1 epoch, reset the data-iter for another epoch
                train_data.Reset();
            }
        }
Пример #3
0
 /// <summary>
 /// Accumulates evaluation statistics for the given batch labels into <paramref name="eval_metric"/>.
 /// </summary>
 /// <remarks>
 /// NOTE(review): presumably compares the labels against the module's latest outputs — confirm in the concrete overrides.
 /// </remarks>
 /// <param name="eval_metric">Metric to update.</param>
 /// <param name="labels">Labels of the current batch.</param>
 /// <param name="pre_sliced">Presumably indicates the labels are already split per device — confirm against implementations.</param>
 public abstract void UpdateMetric(EvalMetric eval_metric, NDArrayList labels, bool pre_sliced = false);
Пример #4
0
        /// <summary>
        /// Batch-end callback that logs training throughput every <c>_frequent</c> batches. When a
        /// metric is supplied, it also logs the metric's current name/value pairs.
        /// </summary>
        /// <param name="epoch">Current epoch index.</param>
        /// <param name="nbatch">Index of the batch that just finished within the epoch.</param>
        /// <param name="eval_metric">Metric to report; when null only the speed is logged.</param>
        /// <param name="locals">Not used by this callback.</param>
        public void Invoke(int epoch, int nbatch, EvalMetric eval_metric, FuncArgs locals = null)
        {
            var count = nbatch;

            // A batch index lower than the previous one means a new epoch began:
            // restart the timing window.
            if (last_count > count)
            {
                init = false;
            }

            last_count = count;

            if (!init)
            {
                // First call of a timing window: just start the clock. This branch used to hang
                // off the frequency check, so init could never become true and nothing was logged.
                init = true;
                tic  = DateTime.Now.Ticks;
                return;
            }

            if (count % _frequent != 0)
            {
                return;
            }

            // tic holds DateTime ticks (100 ns units); convert to seconds so the reported figure
            // really is samples per second. Floating-point division never throws
            // DivideByZeroException, so the zero-elapsed case is guarded explicitly.
            var elapsedSeconds = (DateTime.Now.Ticks - tic) / (double)TimeSpan.TicksPerSecond;
            var speed = elapsedSeconds > 0
                ? (float)Math.Round(_frequent * (double)_batch_size / elapsedSeconds)
                : float.PositiveInfinity;

            string msg;
            if (eval_metric != null)
            {
                var name_value = eval_metric.GetNameValue();
                if (_auto_reset)
                {
                    // Values were captured above; reset so the next window starts fresh.
                    eval_metric.ResetLocal();
                    msg = string.Format("Epoch[{0}] Batch [{1}-{2}]\tSpeed: {3} samples/sec", epoch,
                                        count - _frequent, count, speed);
                }
                else
                {
                    msg = string.Format("Epoch[{0}] Batch [0-{1}]\tSpeed: {2} samples/sec", epoch, count,
                                        speed);
                }

                foreach (var item in name_value)
                {
                    msg += string.Format("\t {0}={1}", item.Key, item.Value);
                }
            }
            else
            {
                msg = string.Format("Iter[{0}] Batch [{1}]\tSpeed: {2} samples/sec", epoch, count, speed);
            }

            Logger.Log(msg);
            tic = DateTime.Now.Ticks;
        }
Пример #5
0
 /// <summary>
 /// Fit overload that takes the optimizer by name. It is not implemented yet: every call throws.
 /// </summary>
 /// <exception cref="NotImplementedException">Always.</exception>
 public void Fit(DataIter train_data, DataIter eval_data = null, string eval_metric = "acc",
                 IEpochEndCallback[] epoch_end_callback  = null, IBatchEndCallback[] batch_end_callback = null,
                 string kvstore   = "local",
                 string optimizer = "sgd", Dictionary<string, object> optimizer_params = null,
                 IBatchEndCallback[] eval_end_callback       = null,
                 IBatchEndCallback[] eval_batch_end_callback = null, Initializer initializer = null,
                 NDArrayDict arg_params = null,
                 NDArrayDict aux_params = null, bool allow_missing = false, bool force_rebind = false,
                 bool force_init        = false, int begin_epoch   = 0, int? num_epoch = null, EvalMetric validation_metric = null,
                 Monitor monitor        = null, Func<DataBatch, NDArrayDict> sparse_row_id_fn = null) =>
     throw new NotImplementedException();
Пример #6
0
        /// <summary>
        /// Data-parallel training loop: trains <paramref name="symbol"/> on <paramref name="ctx"/> for
        /// epochs [<paramref name="begin_epoch"/>, <paramref name="end_epoch"/>). After each epoch it
        /// optionally evaluates on <paramref name="eval_data"/>.
        /// </summary>
        /// <param name="epoch_size">
        /// Number of batches per epoch. When null, an epoch is exactly one pass over
        /// <paramref name="train_data"/>. Otherwise the iterator is cycled (reset on exhaustion) until
        /// that many batches have been processed.
        /// </param>
        /// <param name="update_on_kvstore">
        /// When true the optimizer is installed on <paramref name="kvstore"/> and updates are pushed
        /// through it; otherwise a local updater is used.
        /// </param>
        /// <param name="eval_metric">Metric reset and updated during training and evaluation; must not be null.</param>
        internal static void TrainMultiDevice(Symbol symbol, Context[] ctx, string[] arg_names, string[] param_names, string[] aux_names,
                                              NDArrayDict arg_params, NDArrayDict aux_params, int begin_epoch, int end_epoch, int?epoch_size,
                                              Optimizer optimizer, KVStore kvstore, bool update_on_kvstore, DataIter train_data,
                                              DataIter eval_data = null, EvalMetric eval_metric = null,
                                              IEpochEndCallback epoch_end_callback = null, IBatchEndCallback batch_end_callback = null,
                                              int[] work_load_list = null, Monitor monitor = null,
                                              IEvalEndCallback eval_end_callback = null, IEvalBatchEndCallback eval_batch_end_callback = null,
                                              Func <int, Symbol> sym_gen = null)
        {
            var executor_manager = new DataParallelExecutorManager(symbol: symbol,
                                                                   ctx: ctx,
                                                                   train_data: train_data,
                                                                   arg_names: arg_names,
                                                                   param_names: param_names,
                                                                   aux_names: aux_names,
                                                                   work_load_list: work_load_list,
                                                                   sym_gen: sym_gen);

            if (monitor != null)
            {
                executor_manager.InstallMonitor(monitor);
            }

            executor_manager.SetParams(arg_params, aux_params);
            Updater updater = null;

            if (!update_on_kvstore)
            {
                updater = Optimizer.GetUpdater(optimizer);
            }
            else
            {
                kvstore.SetOptimizer(optimizer);
            }

            if (kvstore != null)
            {
                InitializeKVStore(kvstore: kvstore,
                                  param_arrays: new List <NDArrayList>()
                                  {
                                      executor_manager.ParamArrays
                                  },
                                  arg_params: arg_params,
                                  param_names: executor_manager.param_names,
                                  update_on_kvstore: update_on_kvstore);
            }

            train_data.Reset();
            for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
            {
                var tic = DateTime.Now;
                eval_metric.Reset();
                int nbatch = 0;
                while (true)
                {
                    bool do_reset = true;
                    while (!train_data.End())
                    {
                        var data_batch = train_data.Next();
                        executor_manager.LoadDataBatch(data_batch);
                        if (monitor != null)
                        {
                            monitor.Tic();
                        }

                        executor_manager.Forward(true);
                        executor_manager.Backward();

                        var param_arrays = new List <NDArrayList>() { executor_manager.ParamArrays };
                        var grad_arrays  = new List <NDArrayList>() { executor_manager.GradArrays };
                        if (!update_on_kvstore)
                        {
                            UpdateParams(param_arrays, grad_arrays, updater, ctx.Length, kvstore, executor_manager.param_names);
                        }
                        else if (kvstore.Type.Contains("nccl"))
                        {
                            UpdateParamsOnKVStoreNCCL(param_arrays, grad_arrays, kvstore, executor_manager.param_names);
                        }
                        else
                        {
                            UpdateParamsOnKVStore(param_arrays, grad_arrays, kvstore, executor_manager.param_names);
                        }

                        if (monitor != null)
                        {
                            monitor.TocPrint();
                        }

                        executor_manager.UpdateMetric(eval_metric, data_batch.Label);
                        nbatch++;
                        if (batch_end_callback != null)
                        {
                            MultipleCallbacks(batch_end_callback, epoch, nbatch, eval_metric);
                        }

                        // This epoch is done early; keep the iterator position for the next epoch.
                        if (epoch_size.HasValue && nbatch >= epoch_size.Value)
                        {
                            do_reset = false;
                            break;
                        }
                    }

                    if (do_reset)
                    {
                        Logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        train_data.Reset();
                    }

                    // The epoch ends after a single pass when epoch_size is unset, or once epoch_size
                    // batches were processed; otherwise keep cycling over the data. (Previously both
                    // branches broke when epoch_size was set, and nothing broke when it was null.)
                    if (!epoch_size.HasValue || nbatch >= epoch_size.Value)
                    {
                        break;
                    }
                }

                var toc = DateTime.Now;
                Logger.Info($"Epoch[{epoch}] Time cost={(toc - tic).TotalSeconds}");

                if (epoch_end_callback != null || epoch + 1 == end_epoch)
                {
                    executor_manager.CopyTo(arg_params, aux_params);
                }

                if (epoch_end_callback != null)
                {
                    MultipleCallbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params);
                }

                if (eval_data != null)
                {
                    eval_metric.Reset();
                    eval_data.Reset();
                    int total_num_batch = 0;
                    while (!eval_data.End())
                    {
                        var eval_batch = eval_data.Next();
                        executor_manager.LoadDataBatch(eval_batch);
                        executor_manager.Forward();
                        executor_manager.UpdateMetric(eval_metric, eval_batch.Label);
                        if (eval_batch_end_callback != null)
                        {
                            // Zero-based index of the batch just evaluated (previously always 0).
                            MultipleCallbacks(eval_batch_end_callback, epoch, total_num_batch, eval_metric);
                        }

                        total_num_batch++;
                    }

                    if (eval_end_callback != null)
                    {
                        MultipleCallbacks(eval_end_callback, epoch, eval_metric);
                    }
                }
            }
        }
Пример #7
0
 /// <summary>
 /// Delegates metric accumulation for the given labels to the module's executor group.
 /// </summary>
 /// <param name="eval_metric">Metric to update.</param>
 /// <param name="labels">Labels of the current batch.</param>
 /// <param name="pre_sliced">Passed through unchanged to the executor group.</param>
 public override void UpdateMetric(EvalMetric eval_metric, NDArrayList labels, bool pre_sliced = false) =>
     _exec_group.UpdateMetric(eval_metric, labels, pre_sliced);
Пример #8
0
 /// <summary>
 /// Forwards metric accumulation for the given labels to the current executor group.
 /// </summary>
 /// <param name="metric">Metric to update.</param>
 /// <param name="labels">Labels of the current batch.</param>
 public void update_metric(EvalMetric metric, List<NdArray> labels) =>
     _curr_execgrp.update_metric(metric, labels);
 /// <summary>
 /// Forwards metric accumulation for the given labels to the current executor group.
 /// </summary>
 /// <param name="eval_metric">Metric to update.</param>
 /// <param name="labels">Labels of the current batch.</param>
 /// <param name="pre_sliced">Passed through unchanged to the executor group.</param>
 public void UpdateMetric(EvalMetric eval_metric, NDArrayList labels, bool pre_sliced = false) =>
     curr_execgrp.UpdateMetric(eval_metric, labels, pre_sliced);
Пример #10
0
        /// <summary>
        /// Data-parallel training loop: trains <paramref name="symbol"/> on every context in
        /// <paramref name="ctx"/> for epochs [<paramref name="beginEpoch"/>, <paramref name="endEpoch"/>).
        /// After each epoch it optionally evaluates on <paramref name="evalData"/>.
        /// </summary>
        /// <param name="epochSize">
        /// Number of batches per epoch. When null, an epoch is one pass over <paramref name="trainData"/>.
        /// Otherwise the iterator is cycled (reset on exhaustion) until that many batches were processed.
        /// </param>
        /// <param name="updateOnKvstore">
        /// When true the optimizer is installed on <paramref name="kvstore"/> and updates go through it;
        /// otherwise a local updater from <c>Optimizer.GetUpdater</c> is used.
        /// </param>
        /// <param name="logger">Logger to use; a default one is created when null.</param>
        /// <param name="evalMetric">Metric reset and updated during training and evaluation; must not be null.</param>
        private static void TrainMultiDevice(Symbol symbol,
                                             IList <Context> ctx,
                                             IList <string> argNames,
                                             IList <string> paramNames,
                                             IList <string> auxNames,
                                             Dictionary <string, NdArray> argParams,
                                             Dictionary <string, NdArray> auxParams,
                                             int beginEpoch,
                                             int endEpoch,
                                             int?epochSize,
                                             Optimizer optimizer,
                                             IDataIter trainData,
                                             IDataIter evalData,
                                             EvalMetric evalMetric,
                                             IList <EpochEndDelegate> epochEndCallback,
                                             IList <BatchEndDelegate> batchEndCallback,
                                             KvStore kvstore, bool updateOnKvstore,
                                             ILog logger,
                                             IList <int> workLoadList,
                                             Monitor monitor,
                                             IList <BatchEndDelegate> evalBatchEndCallback,
                                             SymbolGenerate symGen)
        {
            if (logger == null)
            {
                logger = LogManager.GetLogger("");
            }
            var executorManager = new DataParallelExecutorManager(symbol: symbol,
                                                                  symGen: symGen,
                                                                  ctx: ctx,
                                                                  trainData: trainData,
                                                                  paramNames: paramNames,
                                                                  argNames: argNames,
                                                                  auxNames: auxNames,
                                                                  workLoadList: workLoadList,
                                                                  logger: logger);

            if (monitor != null)
            {
                executorManager.InstallMonitor(monitor);
            }
            executorManager.SetParams(argParams, auxParams);

            Action <int, NdArray, NdArray> updater = null;

            if (!updateOnKvstore)
            {
                updater = Optimizer.GetUpdater(optimizer);
            }
            if (kvstore != null)
            {
                InitializeKvstore(kvstore: kvstore,
                                  paramArrays: executorManager.ParamArrays,
                                  argParams: argParams,
                                  paramNames: executorManager.ParamNames,
                                  updateOnKvstore: updateOnKvstore);
            }

            if (updateOnKvstore)
            {
                kvstore?.SetOptimizer(optimizer);
            }

            // Now start training. Epoch numbers are absolute (beginEpoch-based) so that logs,
            // callbacks and the final-epoch check below all refer to the same epoch index.
            for (int epoch = beginEpoch; epoch < endEpoch; epoch++)
            {
                // Training phase
                var epochTimer = Stopwatch.StartNew();
                evalMetric.Reset();
                var nbatch = 0;

                while (true)
                {
                    var doReset = true;
                    foreach (var dataBatch in trainData)
                    {
                        executorManager.LoadDataBatch(dataBatch);

                        monitor?.Tic();

                        executorManager.Forward(isTrain: true);
                        executorManager.Backward();

                        if (updateOnKvstore)
                        {
                            UpdateParamsOnKvstore(
                                executorManager.ParamArrays,
                                executorManager.GradArrays,
                                kvstore);
                        }
                        else
                        {
                            UpdateParams(executorManager.ParamArrays,
                                         executorManager.GradArrays,
                                         updater: updater,
                                         numDevice: ctx.Count,
                                         kvstore: kvstore);
                        }
                        monitor?.TocPrint();
                        // evaluate at end, so we can lazy copy
                        executorManager.UpdateMetric(evalMetric, dataBatch.Label);

                        nbatch += 1;

                        // batch callback (for print purpose)
                        if (batchEndCallback != null)
                        {
                            var batchEndParams = new BatchEndParam(epoch: epoch,
                                                                   nbatch: nbatch,
                                                                   evalMetric: evalMetric,
                                                                   locals: Thread.CurrentThread.CurrentCulture);

                            foreach (var call in batchEndCallback)
                            {
                                call(batchEndParams);
                            }
                        }

                        // This epoch is done early; keep the iterator position for the next epoch.
                        if (epochSize != null && nbatch >= epochSize)
                        {
                            doReset = false;
                            break;
                        }
                    }

                    if (doReset)
                    {
                        logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        trainData.Reset();
                    }

                    if (epochSize == null || nbatch >= epochSize)
                    {
                        break;
                    }
                }

                // Elapsed.TotalSeconds keeps the fractional part that ElapsedMilliseconds/1000
                // (integer division) used to truncate away.
                logger.Info($"Epoch[{epoch}] Time cost={epochTimer.Elapsed.TotalSeconds:0.000}");

                if (epochEndCallback != null || epoch + 1 == endEpoch)
                {
                    executorManager.copy_to(argParams, auxParams);
                }

                if (epochEndCallback != null)
                {
                    EpochEndParam epochEndParam = new EpochEndParam(epoch, symbol, argParams, auxParams);

                    foreach (var callitem in epochEndCallback)
                    {
                        callitem(epochEndParam);
                    }
                }

                // evaluation
                if (evalData != null)
                {
                    evalMetric.Reset();
                    evalData.Reset();
                    int i = 0;
                    foreach (var eval_batch in evalData)
                    {
                        executorManager.LoadDataBatch(eval_batch);
                        executorManager.Forward(isTrain: false);
                        executorManager.UpdateMetric(evalMetric, eval_batch.Label);

                        if (evalBatchEndCallback != null)
                        {
                            var batchEndParams = new BatchEndParam(epoch: epoch,
                                                                   nbatch: i,
                                                                   evalMetric: evalMetric,
                                                                   locals: Thread.CurrentThread.CurrentCulture);
                            foreach (var call in evalBatchEndCallback)
                            {
                                call(batchEndParams);
                            }
                        }

                        i++;
                    }
                    var nameValue = evalMetric.get_name_value();
                    foreach (var item in nameValue)
                    {
                        logger.Info($"Epoch[{epoch}] Validation-{item.Name}={item.Value:0.000}");
                    }
                    evalData.Reset();
                }
            }
        }
Пример #11
0
        /// <summary>
        /// Trains the model on <paramref name="trainData"/>. It first resolves the symbol, through the
        /// bucketing symbol generator when one is configured. It then initialises parameters from the
        /// data and label shapes and creates the kvstore. Finally it records the parameter
        /// index-to-name map in <c>_kwargs</c> and runs <c>TrainMultiDevice</c>.
        /// </summary>
        /// <param name="trainData">Training data iterator.</param>
        /// <param name="evalData">Validation iterator forwarded to the training loop.</param>
        /// <param name="evalMetric">Metric to track; <c>"acc"</c> when null.</param>
        /// <param name="epochEndCallback">Forwarded to the training loop.</param>
        /// <param name="batchEndCallback">Forwarded to the training loop.</param>
        /// <param name="kvstoreInput">KVStore type passed to <c>CreateKvstore</c>.</param>
        /// <param name="logger">Forwarded to the training loop.</param>
        /// <param name="workLoadList">Forwarded to the training loop.</param>
        /// <param name="monitor">Forwarded to the training loop.</param>
        /// <param name="evalBatchEndCallback">Forwarded to the training loop.</param>
        public void Fit(IDataIter trainData,
                        IDataIter evalData,
                        EvalMetric evalMetric = null,
                        IList <EpochEndDelegate> epochEndCallback = null,
                        IList <BatchEndDelegate> batchEndCallback = null,
                        string kvstoreInput      = "local",
                        ILog logger              = null,
                        IList <int> workLoadList = null, Monitor monitor = null,
                        IList <BatchEndDelegate> evalBatchEndCallback = null
                        )
        {
            // Bucketing models pick their symbol for the iterator's default bucket.
            if (_symGen != null)
            {
                _symbol = _symGen(trainData.DefaultBucketKey);
                CheckArguments();
            }
            _kwargs["sym"] = _symbol;

            var inputShapes = trainData.ProvideData
                                       .Concat(trainData.ProvideLabel)
                                       .ToDictionary(x => x.Key, y => y.Value);
            var paramInfo  = InitParams(inputShapes);
            var argNames   = paramInfo.Item1;
            var paramNames = paramInfo.Item2;
            var auxNames   = paramInfo.Item3;

            if (evalMetric == null)
            {
                evalMetric = "acc";
            }

            // Create the kvstore and find out where parameter updates happen.
            var kvstoreInfo     = CreateKvstore(kvstoreInput, _ctx.Count, ArgParams);
            var kvstore         = kvstoreInfo.Item1;
            var updateOnKvstore = kvstoreInfo.Item2;

            // With updates on the kvstore each parameter has a single slot; otherwise every
            // parameter occupies one consecutive slot per device context.
            var slotsPerParam = updateOnKvstore ? 1 : _ctx.Count;
            var paramIdx2Name = new Dictionary <int, string>();
            for (int p = 0; p < paramNames.Count; p++)
            {
                for (int d = 0; d < slotsPerParam; d++)
                {
                    paramIdx2Name[p * slotsPerParam + d] = paramNames[p];
                }
            }
            _kwargs["param_idx2name"] = paramIdx2Name;

            // TODO: initialise the optimizer here.

            TrainMultiDevice(_symbol, _ctx, argNames, paramNames, auxNames,
                             ArgParams, AuxParams,
                             beginEpoch: _beginEpoch, endEpoch: _numEpoch,
                             epochSize: _epochSize,
                             optimizer: _optimizer,
                             trainData: trainData, evalData: evalData,
                             evalMetric: evalMetric,
                             epochEndCallback: epochEndCallback,
                             batchEndCallback: batchEndCallback,
                             kvstore: kvstore, updateOnKvstore: updateOnKvstore,
                             logger: logger, workLoadList: workLoadList, monitor: monitor,
                             evalBatchEndCallback: evalBatchEndCallback,
                             symGen: _symGen);
        }
Пример #12
0
        /// <summary>
        /// Data-parallel training loop: trains <paramref name="symbol"/> on every context in
        /// <paramref name="ctx"/> for epochs [<paramref name="begin_epoch"/>, <paramref name="end_epoch"/>).
        /// </summary>
        /// <remarks>
        /// NOTE(review): only the training phase is implemented. <paramref name="eval_data"/>,
        /// <paramref name="epoch_end_callback"/> and <paramref name="eval_batch_end_callback"/> are
        /// accepted but never used, and parameters are never copied back into
        /// <paramref name="arg_params"/>/<paramref name="aux_params"/> here.
        /// </remarks>
        /// <param name="epoch_size">
        /// Number of batches per epoch. When null, an epoch is one pass over <paramref name="train_data"/>.
        /// Otherwise the iterator is cycled (reset on exhaustion) until that many batches were processed.
        /// </param>
        /// <param name="logger">Logger to use; a default one is created when null.</param>
        private static void _train_multi_device(Symbol symbol, List <Context> ctx, List <string> arg_names,
                                                List <string> param_names, List <string> aux_names, Dictionary <string, NDArray> arg_params,
                                                Dictionary <string, NDArray> aux_params, int begin_epoch, int end_epoch, int?epoch_size, Optimizer optimizer,
                                                IDataIter train_data, IDataIter eval_data, EvalMetric eval_metric, List <Action> epoch_end_callback,
                                                List <Action <BatchEndParam> > batch_end_callback, KVStore kvstore, bool update_on_kvstore, ILog logger, List <int> work_load_list,
                                                Monitor monitor, Action eval_batch_end_callback, SymbolGenerate sym_gen)
        {
            if (logger == null)
            {
                logger = LogManager.GetLogger("");
            }
            var executor_manager = new DataParallelExecutorManager(symbol: symbol,
                                                                   sym_gen: sym_gen,
                                                                   ctx: ctx,
                                                                   train_data: train_data,
                                                                   param_names: param_names,
                                                                   arg_names: arg_names,
                                                                   aux_names: aux_names,
                                                                   work_load_list: work_load_list,
                                                                   logger: logger);

            if (monitor != null)
            {
                executor_manager.install_monitor(monitor);
            }
            executor_manager.set_params(arg_params, aux_params);

            Action <int, NDArray, NDArray> updater = null;

            if (!update_on_kvstore)
            {
                updater = Optimizer.get_updater(optimizer);
            }
            if (kvstore != null)
            {
                _initialize_kvstore(kvstore: kvstore,
                                    param_arrays: executor_manager.param_arrays,
                                    arg_params: arg_params,
                                    param_names: executor_manager.param_names,
                                    update_on_kvstore: update_on_kvstore);
            }

            if (update_on_kvstore)
            {
                kvstore.set_optimizer(optimizer);
            }

            // Now start training. Epoch numbers are absolute (begin_epoch-based), so log lines and
            // batch callbacks report the real epoch index instead of restarting from 0.
            for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
            {
                // Training phase
                var epoch_timer = Stopwatch.StartNew();
                eval_metric.Reset();
                var nbatch = 0;

                while (true)
                {
                    var do_reset = true;
                    foreach (var data_batch in train_data)
                    {
                        executor_manager.load_data_batch(data_batch);

                        monitor?.Tic();

                        executor_manager.Forward(is_train: true);
                        executor_manager.Backward();

                        if (update_on_kvstore)
                        {
                            _update_params_on_kvstore(
                                executor_manager.param_arrays,
                                executor_manager.grad_arrays,
                                kvstore);
                        }
                        else
                        {
                            _update_params(executor_manager.param_arrays,
                                           executor_manager.grad_arrays,
                                           updater: updater,
                                           num_device: ctx.Count,
                                           kvstore: kvstore);
                        }
                        monitor?.toc_print();
                        // evaluate at end, so we can lazy copy
                        executor_manager.update_metric(eval_metric, data_batch.label);

                        nbatch += 1;

                        // batch callback (for print purpose)
                        if (batch_end_callback != null)
                        {
                            var batch_end_params = new BatchEndParam(epoch: epoch,
                                                                     nbatch: nbatch,
                                                                     eval_metric: eval_metric,
                                                                     locals: Thread.CurrentThread.CurrentCulture);

                            foreach (var call in batch_end_callback)
                            {
                                call(batch_end_params);
                            }
                        }

                        // This epoch is done early; keep the iterator position for the next epoch.
                        if (epoch_size != null && nbatch >= epoch_size)
                        {
                            do_reset = false;
                            break;
                        }
                    }

                    if (do_reset)
                    {
                        logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        train_data.Reset();
                    }

                    if (epoch_size == null || nbatch >= epoch_size)
                    {
                        break;
                    }
                }

                // Elapsed.TotalSeconds keeps the fractional part that ElapsedMilliseconds/1000
                // (integer division) used to truncate away.
                logger.Info($"Epoch[{epoch}] Time cost={epoch_timer.Elapsed.TotalSeconds:0.000}");
            }
        }