/// <summary>
/// Performs one local update step when the optimizer does not run on the kvstore:
/// optionally aggregates the gradients through the kvstore, then applies the updater on every device.
/// </summary>
private static void UpdateParams(
    IList<List<NdArray>> paramArrays,
    IList<List<NdArray>> gradArrays,
    Action<int, NdArray, NdArray> updater,
    int numDevice,
    KvStore kvstore = null)
{
    for (int index = 0; index < paramArrays.Count; index++)
    {
        var argList = paramArrays[index];
        var gradList = gradArrays[index];
        if (gradList[0] == null)
        {
            continue;
        }

        if (kvstore != null)
        {
            // Push gradient, priority is negative index
            kvstore.Push(index, gradList, priority: -index);
            // Pull back the summed gradients, to the same locations
            kvstore.Pull(index, gradList, priority: -index);
        }

        for (int k = 0; k < argList.Count; k++)
        {
            var w = argList[k];
            var g = gradList[k];
            // Fake a unique index so the optimizer keeps separate state per device
            updater(index * numDevice + k, g, w);
        }
    }
}
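/// <summary>
/// Registers every parameter with the kvstore under its integer key and, when the optimizer
/// runs on the kvstore, pulls the initial weights back onto every device.
/// </summary>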
private static void InitializeKvstore(
    KvStore kvstore,
    IList<List<NdArray>> paramArrays,
    Dictionary<string, NdArray> argParams,
    IList<string> paramNames,
    bool updateOnKvstore)
{
    for (int idx = 0; idx < paramArrays.Count; idx++)
    {
        var paramOnDevs = paramArrays[idx];
        kvstore.Init(idx, argParams[paramNames[idx]]);
        if (updateOnKvstore)
        {
            kvstore.Pull(idx, paramOnDevs, priority: -idx);
        }
    }
}
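/// <summary>
/// Performs one update step with the optimizer running inside the kvstore:
/// pushes the gradients of every parameter and pulls the updated weights back onto every device.
/// </summary>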
private static void UpdateParamsOnKvstore(
    IList<List<NdArray>> paramArrays,
    IList<List<NdArray>> gradArrays,
    KvStore kvstore)
{
    for (int index = 0; index < paramArrays.Count; index++)
    {
        var argList = paramArrays[index];
        var gradList = gradArrays[index];
        if (gradList[0] == null)
        {
            continue;
        }

        // Push gradient, priority is negative index
        kvstore.Push(index, gradList, priority: -index);
        // Pull back the weights
        kvstore.Pull(index, argList, priority: -index);
    }
}
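/// <summary>
/// Creates a kvstore for the given setting, resolving "local" to a concrete local type,
/// and returns it together with a flag telling whether the optimizer should run on the
/// kvstore (true) or through a local updater (false). Returns a null kvstore when a single,
/// non-distributed device does not need one.
/// </summary>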
private static Tuple<KvStore, bool> CreateKvstore(
    string kvstore,
    int count,
    Dictionary<string, NdArray> argParams)
{
    KvStore kv;
    if (kvstore == null)
    {
        kv = null;
    }
    else if (count == 1 && !kvstore.Contains("dist"))
    {
        // A single device without distributed training does not need a kvstore
        kv = null;
    }
    else
    {
        if (kvstore == "local")
        {
            // Automatically select a proper local kvstore type from the largest parameter size
            var maxSize = argParams.Select(s => Util.Prod(s.Value.GetShape())).Max();
            kvstore = maxSize < 1024 * 1024 * 16 ? "local_update_cpu" : "local_allreduce_cpu";
        }
        kv = new KvStore(kvstore);
    }

    bool updateOnKvstore = !(kv == null || kv.Type.Contains("local_allreduce"));
    return Tuple.Create(kv, updateOnKvstore);
}
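/// <summary>
/// Internal training loop over multiple devices: runs forward/backward for every batch,
/// updates the parameters either locally or on the kvstore, fires the batch and epoch
/// callbacks, and evaluates on the validation iterator at the end of each epoch.
/// </summary>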
private static void TrainMultiDevice(
    Symbol symbol,
    IList<Context> ctx,
    IList<string> argNames,
    IList<string> paramNames,
    IList<string> auxNames,
    Dictionary<string, NdArray> argParams,
    Dictionary<string, NdArray> auxParams,
    int beginEpoch,
    int endEpoch,
    int? epochSize,
    Optimizer optimizer,
    IDataIter trainData,
    IDataIter evalData,
    EvalMetric evalMetric,
    IList<EpochEndDelegate> epochEndCallback,
    IList<BatchEndDelegate> batchEndCallback,
    KvStore kvstore,
    bool updateOnKvstore,
    ILog logger,
    IList<int> workLoadList,
    Monitor monitor,
    IList<BatchEndDelegate> evalBatchEndCallback,
    SymbolGenerate symGen)
{
    if (logger == null)
    {
        logger = LogManager.GetLogger("");
    }

    var executorManager = new DataParallelExecutorManager(symbol: symbol,
        symGen: symGen,
        ctx: ctx,
        trainData: trainData,
        paramNames: paramNames,
        argNames: argNames,
        auxNames: auxNames,
        workLoadList: workLoadList,
        logger: logger);

    if (monitor != null)
    {
        executorManager.InstallMonitor(monitor);
    }

    executorManager.SetParams(argParams, auxParams);

    // A local updater is only needed when the optimizer does not run on the kvstore
    Action<int, NdArray, NdArray> updater = null;
    if (!updateOnKvstore)
    {
        updater = Optimizer.GetUpdater(optimizer);
    }

    if (kvstore != null)
    {
        InitializeKvstore(kvstore: kvstore,
            paramArrays: executorManager.ParamArrays,
            argParams: argParams,
            paramNames: executorManager.ParamNames,
            updateOnKvstore: updateOnKvstore);
    }

    if (updateOnKvstore)
    {
        kvstore?.SetOptimizer(optimizer);
    }

    // Now start training
    for (int epoch = beginEpoch; epoch < endEpoch; epoch++)
    {
        // Training phase
        Stopwatch toc = Stopwatch.StartNew();
        evalMetric.Reset();
        var nbatch = 0;

        // Iterate over training data.
        while (true)
        {
            var doReset = true;
            foreach (var dataBatch in trainData)
            {
                executorManager.LoadDataBatch(dataBatch);
                monitor?.Tic();
                executorManager.Forward(isTrain: true);
                executorManager.Backward();

                if (updateOnKvstore)
                {
                    UpdateParamsOnKvstore(
                        executorManager.ParamArrays,
                        executorManager.GradArrays,
                        kvstore);
                }
                else
                {
                    UpdateParams(executorManager.ParamArrays,
                        executorManager.GradArrays,
                        updater: updater,
                        numDevice: ctx.Count,
                        kvstore: kvstore);
                }

                monitor?.TocPrint();

                // Evaluate at end, so we can lazy copy
                executorManager.UpdateMetric(evalMetric, dataBatch.Label);
                nbatch += 1;

                // Batch callback (for print purpose)
                if (batchEndCallback != null)
                {
                    var batchEndParams = new BatchEndParam(epoch: epoch,
                        nbatch: nbatch,
                        evalMetric: evalMetric,
                        locals: Thread.CurrentThread.CurrentCulture);
                    foreach (var call in batchEndCallback)
                    {
                        call(batchEndParams);
                    }
                }

                if (epochSize != null && nbatch >= epochSize)
                {
                    // The epoch is measured in batches; stop without resetting the iterator
                    doReset = false;
                    break;
                }
            }

            if (doReset)
            {
                logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                trainData.Reset();
            }

            if (epochSize == null || nbatch >= epochSize)
            {
                break;
            }
        }

        logger.Info($"Epoch[{epoch}] Time cost={toc.ElapsedMilliseconds / 1000.0:0.000}");

        // Copy the parameters back to host memory if they will be read below
        if (epochEndCallback != null || epoch + 1 == endEpoch)
        {
            executorManager.copy_to(argParams, auxParams);
        }

        if (epochEndCallback != null)
        {
            EpochEndParam epochEndParam = new EpochEndParam(epoch, symbol, argParams, auxParams);
            foreach (var callitem in epochEndCallback)
            {
                callitem(epochEndParam);
            }
        }

        // Evaluation
        if (evalData != null)
        {
            evalMetric.Reset();
            evalData.Reset();
            int i = 0;
            foreach (var evalBatch in evalData)
            {
                executorManager.LoadDataBatch(evalBatch);
                executorManager.Forward(isTrain: false);
                executorManager.UpdateMetric(evalMetric, evalBatch.Label);

                if (evalBatchEndCallback != null)
                {
                    var batchEndParams = new BatchEndParam(epoch: epoch,
                        nbatch: i,
                        evalMetric: evalMetric,
                        locals: Thread.CurrentThread.CurrentCulture);
                    foreach (var call in evalBatchEndCallback)
                    {
                        call(batchEndParams);
                    }
                }
                i++;
            }

            var nameValue = evalMetric.get_name_value();
            foreach (var item in nameValue)
            {
                logger.Info($"Epoch[{epoch}] Validation-{item.Name}={item.Value:0.000}");
            }
            evalData.Reset();
        }
    }
}
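// Illustrative only: a rough sketch of how the helpers above are typically wired together by a
// Fit-style entry point in this class. The surrounding variables (symbol, contexts, argParams,
// trainIter, evalIter, numEpoch, ...) are hypothetical placeholders, not members of this class.
//
//   var kvPair = CreateKvstore("local", contexts.Count, argParams);
//   TrainMultiDevice(symbol, contexts, argNames, paramNames, auxNames,
//                    argParams, auxParams,
//                    beginEpoch: 0, endEpoch: numEpoch, epochSize: null,
//                    optimizer, trainIter, evalIter, evalMetric,
//                    epochEndCallback, batchEndCallback,
//                    kvstore: kvPair.Item1, updateOnKvstore: kvPair.Item2,
//                    logger: null, workLoadList: null, monitor: null,
//                    evalBatchEndCallback: null, symGen: null);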