Example No. 1
        private FeedForward(Symbol symbol = null,
                            SymbolGenerate symbolGenerate = null,
                            List <Context> ctx            = null,
                            int numEpoch            = 0,
                            int? epochSize          = null,
                            Optimizer optimizer     = null,
                            Initializer initializer = null,
                            Dictionary <string, NdArray> argParams = null,
                            Dictionary <string, NdArray> auxParams = null,
                            bool allowExtraParams = false,
                            int beginEpoch        = 0)
        {
            this._symbol = symbol;
            this._symGen = symbolGenerate;

            if (initializer == null)
            {
                initializer = new Uniform(0.01f);
            }
            this._initializer      = initializer;
            this.ArgParams         = argParams;
            this.AuxParams         = auxParams;
            this._allowExtraParams = allowExtraParams;

            this._argumentChecked = false;
            if (this._symGen == null)
            {
                this.CheckArguments();
            }

            if (ctx == null)
            {
                ctx = new List <Context>()
                {
                    new Context(DeviceType.KCpu, 0)
                };
            }

            this._ctx = ctx;
            // training parameters
            this._numEpoch = numEpoch;

            this._kwargs = new Dictionary <string, object>();

            if (optimizer == null)
            {
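                // default optimizer name; assigning the string "ccsgd" presumably
                // relies on an implicit string-to-Optimizer conversion in the port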
                optimizer = "ccsgd";
            }

            this._optimizer = optimizer;
            // internal helper state;
            this._predExec   = null;
            this._beginEpoch = beginEpoch;
            this._epochSize  = epochSize;
        }
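A minimal usage sketch (not part of the snippet above): calling the public overload shown in Example No. 4 and letting the omitted arguments fall back to the defaults this private constructor fills in, i.e. a single CPU Context, a Uniform(0.01f) initializer and the "ccsgd" optimizer. The symbolGenerate value is a hypothetical placeholder.

        // Hypothetical call; "mySymbolGenerate" stands in for an assumed SymbolGenerate delegate.
        var model = new FeedForward(symbolGenerate: mySymbolGenerate,
                                    numEpoch: 10);
        // ctx, initializer and optimizer were omitted, so the private constructor
        // supplies Context(DeviceType.KCpu, 0), Uniform(0.01f) and "ccsgd".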
Example No. 2
        public DataParallelExecutorManager(Symbol symbol,
                                           SymbolGenerate symGen,
                                           IList <Context> ctx,
                                           IDataIter trainData,
                                           IList <string> paramNames,
                                           IList <string> argNames,
                                           IList <string> auxNames,
                                           IList <int> workLoadList,
                                           ILog logger)
        {
            if (logger == null)
            {
                logger = LogManager.GetLogger("");
            }
            this._logger = logger;

            var numDevice = ctx.Count;

            logger.Info("Start training with " + string.Join("", ctx));

            if (workLoadList == null)
            {
                workLoadList = Enumerable.Repeat(1, numDevice).ToList();
            }
            Util.Assert(workLoadList.Count == numDevice, "Invalid settings for work load. ");

            var slices = SplitInputSlice(trainData.BatchSize, workLoadList);

            this._slices = slices;

            this._argNames  = argNames;
            this.ParamNames = paramNames;
            this._auxNames  = auxNames;
            this._ctx       = ctx;

            this._execgrp = new DataParallelExecutorGroup(symbol, this._argNames, this.ParamNames, this._ctx,
                                                          this._slices, trainData);

            this._symbol = symbol;

            this._symGen      = symGen;
            this._currExecgrp = null;
            // this is set when data is loaded
            if (this._symGen != null)
            {
                this._execgrpBucket = new Dictionary <string, DataParallelExecutorGroup>()
                {
                    { trainData.DefaultBucketKey, this._execgrp }
                };
            }
        }
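The constructor above splits each training batch across the devices in ctx in proportion to workLoadList (via SplitInputSlice, whose body is not shown here). Below is a self-contained sketch of such a proportional split, written against the assumption that SplitInputSlice behaves like MXNet's _split_input_slice helper.

using System;
using System.Collections.Generic;
using System.Linq;

// Illustrative only: a stand-alone guess at what SplitInputSlice computes,
// i.e. half-open [begin, end) ranges proportional to each device's work load.
static class SplitInputSliceSketch
{
    public static List<(int Begin, int End)> SplitInputSlice(int batchSize, IList<int> workLoadList)
    {
        var total  = workLoadList.Sum();
        var slices = new List<(int Begin, int End)>();
        var begin  = 0;
        for (var i = 0; i < workLoadList.Count; i++)
        {
            // cumulative share, rounded so the last slice ends exactly at batchSize
            var end = (int)Math.Round(batchSize * (double)workLoadList.Take(i + 1).Sum() / total);
            slices.Add((begin, end));
            begin = end;
        }
        return slices;
    }

    static void Main()
    {
        // a batch of 100 split over two devices with work loads 1 and 3
        foreach (var (b, e) in SplitInputSlice(100, new[] { 1, 3 }))
        {
            Console.WriteLine($"[{b}, {e})");   // prints [0, 25) then [25, 100)
        }
    }
}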
Example No. 3
        public DataParallelExecutorManager(Symbol symbol, SymbolGenerate sym_gen, List <Context> ctx,
                                           IDataIter train_data, List <string> param_names, List <string> arg_names,
                                           List <string> aux_names, List <int> work_load_list, ILog logger)
        {
            if (logger == null)
            {
                logger = LogManager.GetLogger("");
            }
            this._logger = logger;

            var num_device = ctx.Count;

            logger.Info("Start training with " + string.Join("", ctx));

            if (work_load_list == null)
            {
                work_load_list = Enumerable.Repeat(1, num_device).ToList();
            }
            Util.Assert(work_load_list.Count == num_device, "Invalid settings for work load. ");

            var slices = _split_input_slice(train_data.batch_size, work_load_list);

            this._slices = slices;

            this._arg_names  = arg_names;
            this.param_names = param_names;
            this._aux_names  = aux_names;
            this._ctx        = ctx;

            this._execgrp = new DataParallelExecutorGroup(symbol, this._arg_names, this.param_names, this._ctx,
                                                          this._slices, train_data);

            this._symbol = symbol;

            this._sym_gen      = sym_gen;
            this._curr_execgrp = null;
            // this is set when data is loaded
            if (this._sym_gen != null)
            {
                this._execgrp_bucket = new Dictionary <string, DataParallelExecutorGroup>()
                {
                    { train_data.default_bucket_key, this._execgrp }
                };
            }
        }
Example No. 4
 public FeedForward(SymbolGenerate symbolGenerate = null,
                    List <Context> ctx            = null,
                    int numEpoch            = 0,
                    int? epochSize          = null,
                    Optimizer optimizer     = null,
                    Initializer initializer = null,
                    Dictionary <string, NdArray> argParams = null,
                    Dictionary <string, NdArray> auxParams = null,
                    bool allowExtraParams = false,
                    int beginEpoch        = 0)
     : this(null,
            symbolGenerate,
            ctx,
            numEpoch,
            epochSize,
            optimizer,
            initializer,
            argParams,
            auxParams,
            allowExtraParams,
            beginEpoch)
 {
 }
Example No. 5
 public FeedForward(SymbolGenerate symbol_generate = null,
                    List <Context> ctx             = null,
                    int num_epoch           = 0,
                    int? epoch_size         = null,
                    Optimizer optimizer     = null,
                    Initializer initializer = null,
                    Dictionary <string, NDArray> arg_params = null,
                    Dictionary <string, NDArray> aux_params = null,
                    bool allow_extra_params = false,
                    int begin_epoch         = 0)
     : this(null,
            symbol_generate,
            ctx,
            num_epoch,
            epoch_size,
            optimizer,
            initializer,
            arg_params,
            aux_params,
            allow_extra_params,
            begin_epoch)
 {
 }
Example No. 6
        private static void TrainMultiDevice(Symbol symbol,
                                             IList <Context> ctx,
                                             IList <string> argNames,
                                             IList <string> paramNames,
                                             IList <string> auxNames,
                                             Dictionary <string, NdArray> argParams,
                                             Dictionary <string, NdArray> auxParams,
                                             int beginEpoch,
                                             int endEpoch,
                                             int? epochSize,
                                             Optimizer optimizer,
                                             IDataIter trainData,
                                             IDataIter evalData,
                                             EvalMetric evalMetric,
                                             IList <EpochEndDelegate> epochEndCallback,
                                             IList <BatchEndDelegate> batchEndCallback,
                                             KvStore kvstore, bool updateOnKvstore,
                                             ILog logger,
                                             IList <int> workLoadList,
                                             Monitor monitor,
                                             IList <BatchEndDelegate> evalBatchEndCallback,
                                             SymbolGenerate symGen)
        {
            if (logger == null)
            {
                logger = LogManager.GetLogger("");
            }
            var executorManager = new DataParallelExecutorManager(symbol: symbol,
                                                                  symGen: symGen,
                                                                  ctx: ctx,
                                                                  trainData: trainData,
                                                                  paramNames: paramNames,
                                                                  argNames: argNames,
                                                                  auxNames: auxNames,
                                                                  workLoadList: workLoadList,
                                                                  logger: logger);


            if (monitor != null)
            {
                executorManager.InstallMonitor(monitor);
            }
            executorManager.SetParams(argParams, auxParams);

            Action <int, NdArray, NdArray> updater = null;

            if (!updateOnKvstore)
            {
                updater = Optimizer.GetUpdater(optimizer);
            }
            if (kvstore != null)
            {
                InitializeKvstore(kvstore: kvstore,
                                  paramArrays: executorManager.ParamArrays,
                                  argParams: argParams,
                                  paramNames: executorManager.ParamNames,
                                  updateOnKvstore: updateOnKvstore);
            }

            if (updateOnKvstore)
            {
                kvstore?.SetOptimizer(optimizer);
            }

            //Now start training
            for (int epoch = 0; epoch < endEpoch - beginEpoch; epoch++)
            {
                // Training phase
                Stopwatch toc = new Stopwatch();
                toc.Start();
                evalMetric.Reset();
                var nbatch = 0;
                // Iterate over training data.

                while (true)
                {
                    var doReset = true;
                    foreach (var dataBatch in trainData)
                    {
                        executorManager.LoadDataBatch(dataBatch);

                        monitor?.Tic();


                        executorManager.Forward(isTrain: true);
                        executorManager.Backward();



                        if (updateOnKvstore)
                        {
                            UpdateParamsOnKvstore(
                                executorManager.ParamArrays,
                                executorManager.GradArrays,
                                kvstore);
                        }
                        else
                        {
                            UpdateParams(executorManager.ParamArrays,
                                         executorManager.GradArrays,
                                         updater: updater,
                                         numDevice: ctx.Count,
                                         kvstore: kvstore);
                        }
                        monitor?.TocPrint();
                        // evaluate at end, so we can lazy copy
                        executorManager.UpdateMetric(evalMetric, dataBatch.Label);

                        nbatch += 1;
                        //batch callback (for print purpose)

                        if (batchEndCallback != null)
                        {
                            var batchEndParams = new BatchEndParam(epoch: epoch,
                                                                   nbatch: nbatch,
                                                                   evalMetric: evalMetric,
                                                                   locals: Thread.CurrentThread.CurrentCulture);

                            foreach (var call in batchEndCallback)
                            {
                                call(batchEndParams);
                            }
                        }
                        if (epochSize != null && nbatch >= epochSize)
                        {
                            doReset = false;
                            break;
                        }
                    }

                    if (doReset)
                    {
                        logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        trainData.Reset();
                    }

                    if (epochSize == null || nbatch >= epochSize)
                    {
                        break;
                    }
                }


                logger.Info($"Epoch[{epoch}] Time cost={(toc.ElapsedMilliseconds/1000):.000}");

                if (epochEndCallback != null || epoch + 1 == endEpoch)
                {
                    executorManager.copy_to(argParams, auxParams);
                }


                if (epochEndCallback != null)
                {
                    EpochEndParam epochEndParam = new EpochEndParam(epoch, symbol, argParams, auxParams);

                    foreach (var callitem in epochEndCallback)
                    {
                        callitem(epochEndParam);
                    }
                }

                // evaluation
                if (evalData != null)
                {
                    evalMetric.Reset();
                    evalData.Reset();
                    int i = 0;
                    foreach (var eval_batch in evalData)
                    {
                        executorManager.LoadDataBatch(eval_batch);
                        executorManager.Forward(isTrain: false);
                        executorManager.UpdateMetric(evalMetric, eval_batch.Label);

                        if (evalBatchEndCallback != null)
                        {
                            var batchEndParams = new BatchEndParam(epoch: epoch,
                                                                   nbatch: i,
                                                                   evalMetric: evalMetric,
                                                                   locals: Thread.CurrentThread.CurrentCulture);
                            foreach (var call in evalBatchEndCallback)
                            {
                                call(batchEndParams);
                            }
                        }

                        i++;
                    }
                    var nameValue = evalMetric.get_name_value();
                    foreach (var item in nameValue)
                    {
                        logger.Info($"Epoch[{epoch}] Validation-{item.Name}={item.Value:0.000}");
                    }
                    evalData.Reset();
                }
            }
        }
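A short sketch of the kind of callback this loop invokes at the end of every batch. It assumes BatchEndDelegate is a delegate taking a single BatchEndParam; the Epoch and Nbatch property names on BatchEndParam are assumptions, since the snippet above only shows how the object is constructed.

        // Hypothetical progress callback; the BatchEndParam property names are assumed.
        BatchEndDelegate logProgress = p =>
        {
            Console.WriteLine($"Epoch[{p.Epoch}] Batch[{p.Nbatch}]");
        };
        IList<BatchEndDelegate> batchEndCallback = new List<BatchEndDelegate> { logProgress };
        // TrainMultiDevice then calls each delegate in batchEndCallback once per processed batch.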
Example No. 7
        private static void _train_multi_device(Symbol symbol, List <Context> ctx, List <string> arg_names,
                                                List <string> param_names, List <string> aux_names, Dictionary <string, NDArray> arg_params,
                                                Dictionary <string, NDArray> aux_params, int begin_epoch, int end_epoch, int? epoch_size, Optimizer optimizer,
                                                IDataIter train_data, IDataIter eval_data, EvalMetric eval_metric, List <Action> epoch_end_callback,
                                                List <Action<BatchEndParam>> batch_end_callback, KVStore kvstore, bool update_on_kvstore, ILog logger, List <int> work_load_list,
                                                Monitor monitor, Action eval_batch_end_callback, SymbolGenerate sym_gen)
        {
            if (logger == null)
            {
                logger = LogManager.GetLogger("");
            }
            var executor_manager = new DataParallelExecutorManager(symbol: symbol,
                                                                   sym_gen: sym_gen,
                                                                   ctx: ctx,
                                                                   train_data: train_data,
                                                                   param_names: param_names,
                                                                   arg_names: arg_names,
                                                                   aux_names: aux_names,
                                                                   work_load_list: work_load_list,
                                                                   logger: logger);


            if (monitor != null)
            {
                executor_manager.install_monitor(monitor);
            }
            executor_manager.set_params(arg_params, aux_params);

            Action <int, NDArray, NDArray> updater = null;

            if (!update_on_kvstore)
            {
                updater = Optimizer.get_updater(optimizer);
            }
            if (kvstore != null)
            {
                _initialize_kvstore(kvstore: kvstore,
                                    param_arrays: executor_manager.param_arrays,
                                    arg_params: arg_params,
                                    param_names: executor_manager.param_names,
                                    update_on_kvstore: update_on_kvstore);
            }

            if (update_on_kvstore)
            {
                kvstore.set_optimizer(optimizer);
            }

            //Now start training
            for (int epoch = 0; epoch < end_epoch - begin_epoch; epoch++)
            {
                // Training phase
                Stopwatch toc = new Stopwatch();
                toc.Start();
                eval_metric.Reset();
                var nbatch = 0;
                // Iterate over training data.

                while (true)
                {
                    var do_reset = true;
                    foreach (var data_batch in train_data)
                    {
                        executor_manager.load_data_batch(data_batch);

                        monitor?.Tic();


                        executor_manager.Forward(is_train: true);
                        executor_manager.Backward();



                        if (update_on_kvstore)
                        {
                            _update_params_on_kvstore(
                                executor_manager.param_arrays,
                                executor_manager.grad_arrays,
                                kvstore);
                        }
                        else
                        {
                            _update_params(executor_manager.param_arrays,
                                           executor_manager.grad_arrays,
                                           updater: updater,
                                           num_device: ctx.Count,
                                           kvstore: kvstore);
                        }
                        monitor?.toc_print();
                        // evaluate at end, so we can lazy copy
                        executor_manager.update_metric(eval_metric, data_batch.label);

                        nbatch += 1;
                        //batch callback (for print purpose)

                        if (batch_end_callback != null)
                        {
                            var batch_end_params = new BatchEndParam(epoch: epoch,
                                                                     nbatch: nbatch,
                                                                     eval_metric: eval_metric,
                                                                     locals: Thread.CurrentThread.CurrentCulture);

                            foreach (var call in batch_end_callback)
                            {
                                call(batch_end_params);
                            }
                        }
                        if (epoch_size != null && nbatch >= epoch_size)
                        {
                            do_reset = false;
                            break;
                        }
                    }

                    if (do_reset)
                    {
                        logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        train_data.Reset();
                    }

                    if (epoch_size == null || nbatch >= epoch_size)
                    {
                        break;
                    }
                }


                logger.Info($"Epoch[{epoch}] Time cost={(toc.ElapsedMilliseconds/1000):.000}");
            }
        }