示例#1
0
        /// <summary>
        /// Applies <paramref name="updater"/> to every (gradient, weight) pair, i.e. performs
        /// the parameter update locally rather than on the kvstore.
        /// </summary>
        /// <param name="paramArrays">For each parameter index, the list of weight arrays.</param>
        /// <param name="gradArrays">
        /// For each parameter index, the list of gradient arrays. An entry whose first
        /// gradient is <c>null</c> is skipped entirely.
        /// </param>
        /// <param name="updater">Callback invoked as <c>updater(index, gradient, weight)</c>.</param>
        /// <param name="numDevice">
        /// Stride used to build a distinct updater index for each inner-list position
        /// (<c>index * numDevice + k</c>); presumably the number of devices.
        /// </param>
        /// <param name="kvstore">
        /// Optional store; when given, gradients are pushed to it and pulled back before updating.
        /// </param>
        private static void UpdateParams(
            IList<List<NdArray>> paramArrays,
            IList<List<NdArray>> gradArrays,
            Action<int, NdArray, NdArray> updater,
            int numDevice,
            KvStore kvstore = null)
        {
            for (int index = 0; index < paramArrays.Count; index++)
            {
                var argList = paramArrays[index];
                var gradList = gradArrays[index];
                if (gradList[0] == null)
                {
                    continue;
                }

                if (kvstore != null)
                {
                    // Push gradient; priority is the negative index.
                    kvstore.Push(index, gradList, priority: -index);

                    // Pull the stored value back into the gradient arrays, NOT the weights:
                    // the weights are updated by `updater` below, so overwriting them here
                    // would discard the current parameters. This mirrors MXNet's reference
                    // `_update_params`, which pulls the summed gradients to the same locations.
                    kvstore.Pull(index, gradList, priority: -index);
                }

                for (int k = 0; k < argList.Count; k++)
                {
                    // A distinct index per inner-list position lets the updater keep
                    // separate state for the same parameter at different positions.
                    updater(index * numDevice + k, gradList[k], argList[k]);
                }
            }
        }
示例#2
0
        /// <summary>
        /// Registers every parameter with <paramref name="kvstore"/> under its positional
        /// index and, when updates happen on the kvstore, pulls the stored value back into
        /// the corresponding arrays of <paramref name="paramArrays"/>.
        /// </summary>
        /// <param name="kvstore">Store to initialize; must not be null.</param>
        /// <param name="paramArrays">For each parameter index, the arrays that receive the pulled value.</param>
        /// <param name="argParams">Initial values, looked up by parameter name.</param>
        /// <param name="paramNames">Parameter names, indexed in step with <paramref name="paramArrays"/>.</param>
        /// <param name="updateOnKvstore">When true, the initialized value is pulled after each Init.</param>
        private static void InitializeKvstore(KvStore kvstore,
                                              IList<List<NdArray>> paramArrays,
                                              Dictionary<string, NdArray> argParams,
                                              IList<string> paramNames,
                                              bool updateOnKvstore)
        {
            int key = 0;
            while (key < paramArrays.Count)
            {
                List<NdArray> targets = paramArrays[key];
                string name = paramNames[key];
                NdArray initialValue = argParams[name];

                kvstore.Init(key, initialValue);

                if (updateOnKvstore)
                {
                    // Priority is the negative index, as in the update helpers.
                    kvstore.Pull(key, targets, priority: -key);
                }

                key++;
            }
        }
示例#3
0
 /// <summary>
 /// Performs the parameter update through <paramref name="kvstore"/>: for each parameter
 /// index the gradients are pushed and the stored value is then pulled into the weight arrays.
 /// Entries whose first gradient is <c>null</c> are skipped.
 /// </summary>
 /// <param name="paramArrays">For each parameter index, the weight arrays that receive the pull.</param>
 /// <param name="gradArrays">For each parameter index, the gradient arrays that are pushed.</param>
 /// <param name="kvstore">Store used for the push/pull round trip; must not be null.</param>
 private static void UpdateParamsOnKvstore(
     IList<List<NdArray>> paramArrays,
     IList<List<NdArray>> gradArrays,
     KvStore kvstore)
 {
     int key = 0;
     while (key < paramArrays.Count)
     {
         List<NdArray> weights = paramArrays[key];
         List<NdArray> gradients = gradArrays[key];
         int priority = -key;
         key++;

         if (gradients[0] == null)
         {
             continue;
         }

         // Push the gradients, then pull the weights back; `priority` is the negative index.
         kvstore.Push(-priority, gradients, priority: priority);
         kvstore.Pull(-priority, weights, priority: priority);
     }
 }
示例#4
0
        /// <summary>
        /// Builds the <see cref="KvStore"/> described by <paramref name="kvstore"/> (or none)
        /// and reports whether parameter updates should run on it.
        /// </summary>
        /// <param name="kvstore">
        /// Store type name, or <c>null</c> for no store. <c>"local"</c> is replaced by
        /// <c>"local_update_cpu"</c> or <c>"local_allreduce_cpu"</c> depending on the largest parameter.
        /// </param>
        /// <param name="count">
        /// When this is 1 and the type name does not contain "dist", no store is created
        /// (presumably the device count).
        /// </param>
        /// <param name="argParams">Parameters whose shapes drive the "local" auto-selection.</param>
        /// <returns>
        /// The store (possibly <c>null</c>) paired with a flag that is true only when a store
        /// exists and its type does not contain "local_allreduce".
        /// </returns>
        private static Tuple<KvStore, bool> CreateKvstore(
            string kvstore, int count, Dictionary<string, NdArray> argParams)
        {
            KvStore store = null;

            bool storeWanted = kvstore != null && (count != 1 || kvstore.Contains("dist"));
            if (storeWanted)
            {
                if (kvstore == "local")
                {
                    // Automatically select a proper local store from the largest parameter size.
                    var largest = argParams.Select(entry => Util.Prod(entry.Value.GetShape())).Max();
                    kvstore = largest < 1024 * 1024 * 16
                        ? "local_update_cpu"
                        : "local_allreduce_cpu";
                }

                store = new KvStore(kvstore);
            }

            if (store == null)
            {
                return Tuple.Create(store, false);
            }

            return Tuple.Create(store, !store.Type.Contains("local_allreduce"));
        }
示例#5
0
        /// <summary>
        /// Runs the training loop for epochs <paramref name="beginEpoch"/> (inclusive) to
        /// <paramref name="endEpoch"/> (exclusive) using a <c>DataParallelExecutorManager</c>,
        /// optionally evaluating on <paramref name="evalData"/> after every epoch.
        /// </summary>
        /// <param name="symbol">Network symbol handed to the executor manager and epoch-end callbacks.</param>
        /// <param name="ctx">Contexts to train on; <c>ctx.Count</c> is passed as the device count for local updates.</param>
        /// <param name="argNames">Argument names forwarded to the executor manager.</param>
        /// <param name="paramNames">Parameter names forwarded to the executor manager.</param>
        /// <param name="auxNames">Auxiliary state names forwarded to the executor manager.</param>
        /// <param name="argParams">Initial parameters; receives the trained values when they are copied back.</param>
        /// <param name="auxParams">Initial auxiliary states; receives the trained values when they are copied back.</param>
        /// <param name="beginEpoch">First epoch number (inclusive).</param>
        /// <param name="endEpoch">Last epoch number (exclusive).</param>
        /// <param name="epochSize">
        /// When non-null, an epoch ends after this many batches and the training iterator is
        /// only reset when it runs out of data; when null, an epoch is one pass over the data.
        /// </param>
        /// <param name="optimizer">Optimizer used either locally (via its updater) or on the kvstore.</param>
        /// <param name="trainData">Training data iterator.</param>
        /// <param name="evalData">Optional validation data iterator.</param>
        /// <param name="evalMetric">Metric that is reset and updated during training and evaluation.</param>
        /// <param name="epochEndCallback">Optional callbacks invoked after each epoch.</param>
        /// <param name="batchEndCallback">Optional callbacks invoked after each training batch.</param>
        /// <param name="kvstore">Optional kvstore used for gradient/weight exchange.</param>
        /// <param name="updateOnKvstore">When true, the update runs on the kvstore instead of locally.</param>
        /// <param name="logger">Logger; a default one is obtained when null.</param>
        /// <param name="workLoadList">Work-load list forwarded to the executor manager.</param>
        /// <param name="monitor">Optional monitor installed on the executors and ticked around each batch.</param>
        /// <param name="evalBatchEndCallback">Optional callbacks invoked after each evaluation batch.</param>
        /// <param name="symGen">Symbol generator forwarded to the executor manager.</param>
        private static void TrainMultiDevice(Symbol symbol,
                                             IList<Context> ctx,
                                             IList<string> argNames,
                                             IList<string> paramNames,
                                             IList<string> auxNames,
                                             Dictionary<string, NdArray> argParams,
                                             Dictionary<string, NdArray> auxParams,
                                             int beginEpoch,
                                             int endEpoch,
                                             int? epochSize,
                                             Optimizer optimizer,
                                             IDataIter trainData,
                                             IDataIter evalData,
                                             EvalMetric evalMetric,
                                             IList<EpochEndDelegate> epochEndCallback,
                                             IList<BatchEndDelegate> batchEndCallback,
                                             KvStore kvstore, bool updateOnKvstore,
                                             ILog logger,
                                             IList<int> workLoadList,
                                             Monitor monitor,
                                             IList<BatchEndDelegate> evalBatchEndCallback,
                                             SymbolGenerate symGen)
        {
            if (logger == null)
            {
                logger = LogManager.GetLogger("");
            }

            var executorManager = new DataParallelExecutorManager(symbol: symbol,
                                                                  symGen: symGen,
                                                                  ctx: ctx,
                                                                  trainData: trainData,
                                                                  paramNames: paramNames,
                                                                  argNames: argNames,
                                                                  auxNames: auxNames,
                                                                  workLoadList: workLoadList,
                                                                  logger: logger);

            if (monitor != null)
            {
                executorManager.InstallMonitor(monitor);
            }

            executorManager.SetParams(argParams, auxParams);

            // A local updater is only needed when the update does not run on the kvstore.
            Action<int, NdArray, NdArray> updater = null;
            if (!updateOnKvstore)
            {
                updater = Optimizer.GetUpdater(optimizer);
            }

            if (kvstore != null)
            {
                InitializeKvstore(kvstore: kvstore,
                                  paramArrays: executorManager.ParamArrays,
                                  argParams: argParams,
                                  paramNames: executorManager.ParamNames,
                                  updateOnKvstore: updateOnKvstore);
            }

            if (updateOnKvstore)
            {
                kvstore?.SetOptimizer(optimizer);
            }

            // Now start training. `epoch` is the absolute epoch number so that logging,
            // callbacks and the "last epoch" check below all agree when beginEpoch > 0.
            for (int epoch = beginEpoch; epoch < endEpoch; epoch++)
            {
                // Training phase
                Stopwatch toc = Stopwatch.StartNew();
                evalMetric.Reset();
                var nbatch = 0;

                // Iterate over training data. The outer loop re-enters the iterator after a
                // reset until `epochSize` batches have been processed (or once if it is null).
                while (true)
                {
                    var doReset = true;
                    foreach (var dataBatch in trainData)
                    {
                        executorManager.LoadDataBatch(dataBatch);

                        monitor?.Tic();

                        executorManager.Forward(isTrain: true);
                        executorManager.Backward();

                        if (updateOnKvstore)
                        {
                            UpdateParamsOnKvstore(
                                executorManager.ParamArrays,
                                executorManager.GradArrays,
                                kvstore);
                        }
                        else
                        {
                            UpdateParams(executorManager.ParamArrays,
                                         executorManager.GradArrays,
                                         updater: updater,
                                         numDevice: ctx.Count,
                                         kvstore: kvstore);
                        }

                        monitor?.TocPrint();

                        // Evaluate at end, so we can lazy copy.
                        executorManager.UpdateMetric(evalMetric, dataBatch.Label);

                        nbatch += 1;

                        // Batch callback (for print purpose).
                        if (batchEndCallback != null)
                        {
                            var batchEndParams = new BatchEndParam(epoch: epoch,
                                                                   nbatch: nbatch,
                                                                   evalMetric: evalMetric,
                                                                   locals: Thread.CurrentThread.CurrentCulture);

                            foreach (var call in batchEndCallback)
                            {
                                call(batchEndParams);
                            }
                        }

                        // The epoch ended by batch count, not by exhausting the iterator:
                        // leave the iterator where it is for the next epoch.
                        if (epochSize != null && nbatch >= epochSize)
                        {
                            doReset = false;
                            break;
                        }
                    }

                    if (doReset)
                    {
                        logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        trainData.Reset();
                    }

                    if (epochSize == null || nbatch >= epochSize)
                    {
                        break;
                    }
                }

                // Report elapsed wall time in seconds with millisecond precision.
                logger.Info($"Epoch[{epoch}] Time cost={toc.Elapsed.TotalSeconds:0.000}");

                // Copy the trained values back for the callbacks, and always after the final epoch.
                if (epochEndCallback != null || epoch + 1 == endEpoch)
                {
                    executorManager.copy_to(argParams, auxParams);
                }

                if (epochEndCallback != null)
                {
                    EpochEndParam epochEndParam = new EpochEndParam(epoch, symbol, argParams, auxParams);

                    foreach (var callitem in epochEndCallback)
                    {
                        callitem(epochEndParam);
                    }
                }

                // Evaluation
                if (evalData != null)
                {
                    evalMetric.Reset();
                    evalData.Reset();
                    int i = 0;
                    foreach (var eval_batch in evalData)
                    {
                        executorManager.LoadDataBatch(eval_batch);
                        executorManager.Forward(isTrain: false);
                        executorManager.UpdateMetric(evalMetric, eval_batch.Label);

                        if (evalBatchEndCallback != null)
                        {
                            var batchEndParams = new BatchEndParam(epoch: epoch,
                                                                   nbatch: i,
                                                                   evalMetric: evalMetric,
                                                                   locals: Thread.CurrentThread.CurrentCulture);
                            foreach (var call in evalBatchEndCallback)
                            {
                                call(batchEndParams);
                            }
                        }

                        i++;
                    }

                    var nameValue = evalMetric.get_name_value();
                    foreach (var item in nameValue)
                    {
                        logger.Info($"Epoch[{epoch}] Validation-{item.Name}={item.Value:0.000}");
                    }

                    evalData.Reset();
                }
            }
        }