/// <summary>
/// Manages a group of executors that train a symbol in data-parallel fashion across devices.
/// </summary>
/// <param name="symbol">Network symbol to train.</param>
/// <param name="symGen">Optional symbol generator for bucketing; null disables bucketing.</param>
/// <param name="ctx">Devices to train on; one executor slice is created per context.</param>
/// <param name="trainData">Training data iterator; supplies the batch size and default bucket key.</param>
/// <param name="paramNames">Names of trainable parameters.</param>
/// <param name="argNames">Names of all symbol arguments.</param>
/// <param name="auxNames">Names of auxiliary states.</param>
/// <param name="workLoadList">Relative work load per device; null means an even split.</param>
/// <param name="logger">Logger; a default logger is used when null.</param>
public DataParallelExecutorManager(Symbol symbol, SymbolGenerate symGen, IList <Context> ctx, IDataIter trainData, IList <string> paramNames, IList <string> argNames, IList <string> auxNames, IList <int> workLoadList, ILog logger)
{
    // Fall back to a default logger when the caller did not supply one.
    if (logger == null)
    {
        logger = LogManager.GetLogger("");
    }
    this._logger = logger;

    var numDevice = ctx.Count;
    // BUGFIX: contexts were joined with an empty separator, producing an unreadable
    // log line such as "gpu(0)gpu(1)"; use a comma-separated list instead.
    logger.Info("Start training with " + string.Join(",", ctx));

    // Default to an even work load across all devices.
    if (workLoadList == null)
    {
        workLoadList = Enumerable.Repeat(1, numDevice).ToList();
    }
    Util.Assert(workLoadList.Count == numDevice, "Invalid settings for work load. ");

    // Partition each batch into per-device slices proportional to the work load.
    var slices = SplitInputSlice(trainData.BatchSize, workLoadList);
    this._slices = slices;

    this._argNames = argNames;
    this.ParamNames = paramNames;
    this._auxNames = auxNames;
    this._ctx = ctx;

    this._execgrp = new DataParallelExecutorGroup(symbol, this._argNames, this.ParamNames, this._ctx, this._slices, trainData);

    this._symbol = symbol;
    this._symGen = symGen;
    this._currExecgrp = null; // this is set when data is loaded

    // Bucketing mode: cache one executor group per bucket key, seeded with the default bucket.
    if (this._symGen != null)
    {
        this._execgrpBucket = new Dictionary <string, DataParallelExecutorGroup>()
        {
            { trainData.DefaultBucketKey, this._execgrp }
        };
    }
}
/// <summary>
/// Manages a group of executors that train a symbol in data-parallel fashion across devices
/// (snake_case variant of the same manager).
/// </summary>
/// <param name="symbol">Network symbol to train.</param>
/// <param name="sym_gen">Optional symbol generator for bucketing; null disables bucketing.</param>
/// <param name="ctx">Devices to train on; one executor slice is created per context.</param>
/// <param name="train_data">Training data iterator; supplies the batch size and default bucket key.</param>
/// <param name="param_names">Names of trainable parameters.</param>
/// <param name="arg_names">Names of all symbol arguments.</param>
/// <param name="aux_names">Names of auxiliary states.</param>
/// <param name="work_load_list">Relative work load per device; null means an even split.</param>
/// <param name="logger">Logger; a default logger is used when null.</param>
public DataParallelExecutorManager(Symbol symbol, SymbolGenerate sym_gen, List <Context> ctx, IDataIter train_data, List <string> param_names, List <string> arg_names, List <string> aux_names, List <int> work_load_list, ILog logger)
{
    // Fall back to a default logger when the caller did not supply one.
    if (logger == null)
    {
        logger = LogManager.GetLogger("");
    }
    this._logger = logger;

    var num_device = ctx.Count;
    // BUGFIX: contexts were joined with an empty separator, producing an unreadable
    // log line such as "gpu(0)gpu(1)"; use a comma-separated list instead.
    logger.Info("Start training with " + string.Join(",", ctx));

    // Default to an even work load across all devices.
    if (work_load_list == null)
    {
        work_load_list = Enumerable.Repeat(1, num_device).ToList();
    }
    Util.Assert(work_load_list.Count == num_device, "Invalid settings for work load. ");

    // Partition each batch into per-device slices proportional to the work load.
    var slices = _split_input_slice(train_data.batch_size, work_load_list);
    this._slices = slices;

    this._arg_names = arg_names;
    this.param_names = param_names;
    this._aux_names = aux_names;
    this._ctx = ctx;

    this._execgrp = new DataParallelExecutorGroup(symbol, this._arg_names, this.param_names, this._ctx, this._slices, train_data);

    this._symbol = symbol;
    this._sym_gen = sym_gen;
    this._curr_execgrp = null; // this is set when data is loaded

    // Bucketing mode: cache one executor group per bucket key, seeded with the default bucket.
    if (this._sym_gen != null)
    {
        this._execgrp_bucket = new Dictionary <string, DataParallelExecutorGroup>()
        {
            { train_data.default_bucket_key, this._execgrp }
        };
    }
}
/// <summary>
/// Runs a forward pass over every batch of <paramref name="inputX"/> and returns the
/// concatenated network outputs, with padding rows of the final batch stripped.
/// </summary>
/// <param name="inputX">Data iterator to predict over.</param>
/// <param name="numBatch">Optional cap on the number of batches to process; null iterates to exhaustion.</param>
/// <param name="returnData">When true, input data and labels are also sliced and collected per batch.</param>
/// <param name="reset">When true, the iterator is reset before prediction.</param>
/// <returns>One concatenated <see cref="SingleNArray"/> per network output.</returns>
public List <SingleNArray> Predict(IDataIter inputX, int?numBatch = null, bool returnData = false, bool reset = true)
{
    if (reset)
    {
        inputX.Reset();
    }

    var dataShapes = inputX.ProvideData;
    var dataNames = dataShapes.Select(s => s.Key).ToList();
    InitPredictor(dataShapes);

    var batchSize = inputX.BatchSize;
    var dataArrays = dataNames.Select(name => this._predExec.ArgDict[name]).ToList();
    // One accumulator list per network output.
    var outputList = this._predExec.Outputs.Select(s => new List <SingleNArray>()).ToList();

    List <List <SingleNArray> > dataList = null;
    List <List <SingleNArray> > labelList = null;
    if (returnData)
    {
        dataList = inputX.ProvideData.Select(s => new List <SingleNArray>()).ToList();
        labelList = inputX.ProvideLabel.Select(s => new List <SingleNArray>()).ToList();
    }

    int i = 0;
    foreach (var batch in inputX)
    {
        ExecutorManager.LoadData(batch, dataArrays);
        this._predExec.Forward(isTrain: false);

        // The last batch may be padded up to batchSize; keep only the real rows.
        var padded = batch.Pad;
        var realSize = batchSize - padded;

        foreach (var vitem in outputList.Zip(this._predExec.Outputs, Tuple.Create))
        {
            vitem.Item1.Add(vitem.Item2.Slice(0, (uint)realSize).AsNumerics());
        }

        if (returnData)
        {
            for (int j = 0; j < batch.Data.Count; j++)
            {
                var x = batch.Data[j];
                dataList[j].Add(x.Slice(0, (uint)realSize).AsNumerics());
            }
            // BUGFIX: this loop previously ran to batch.Data.Count while indexing
            // batch.Label, which is out of range whenever the label count differs
            // from the data count. Bound it by the label count.
            for (int j = 0; j < batch.Label.Count; j++)
            {
                var x = batch.Label[j];
                labelList[j].Add(x.Slice(0, (uint)realSize).AsNumerics());
            }
        }

        i += 1;
        if (numBatch != null && i == numBatch.Value)
        {
            break;
        }
    }

    var outputs = outputList.Select(s => SingleNArray.Concatenate(0, s.ToArray())).ToList();
    // NOTE(review): the original code built concatenated data/label here into unused
    // locals via deferred LINQ queries that were never enumerated, so nothing was ever
    // computed or returned. The dead code has been removed; if callers need the collected
    // data/labels, extend the API (e.g. via out parameters) rather than resurrecting it.
    return(outputs);
}
/// <summary>
/// Internal data-parallel training loop: runs forward/backward over all devices,
/// updates parameters (locally or via the kvstore), fires batch/epoch callbacks,
/// and optionally evaluates on <paramref name="evalData"/> after each epoch.
/// </summary>
/// <param name="symbol">Network symbol being trained.</param>
/// <param name="ctx">Devices to train on.</param>
/// <param name="argNames">All argument names of the symbol.</param>
/// <param name="paramNames">Trainable parameter names.</param>
/// <param name="auxNames">Auxiliary state names.</param>
/// <param name="argParams">Parameter arrays; synced back at epoch end / training end.</param>
/// <param name="auxParams">Auxiliary state arrays; synced back at epoch end / training end.</param>
/// <param name="beginEpoch">First epoch number (supports resumed training).</param>
/// <param name="endEpoch">Exclusive upper bound on the epoch number.</param>
/// <param name="epochSize">Batches per epoch; null means one full pass of the iterator.</param>
/// <param name="optimizer">Optimizer used for parameter updates.</param>
/// <param name="trainData">Training data iterator.</param>
/// <param name="evalData">Optional evaluation data iterator.</param>
/// <param name="evalMetric">Metric accumulated over training and evaluation batches.</param>
/// <param name="epochEndCallback">Callbacks fired after each epoch.</param>
/// <param name="batchEndCallback">Callbacks fired after each training batch.</param>
/// <param name="kvstore">Optional key-value store for distributed updates.</param>
/// <param name="updateOnKvstore">True when the kvstore applies the optimizer itself.</param>
/// <param name="logger">Logger; a default logger is used when null.</param>
/// <param name="workLoadList">Relative per-device work load.</param>
/// <param name="monitor">Optional monitor for per-batch statistics.</param>
/// <param name="evalBatchEndCallback">Callbacks fired after each evaluation batch.</param>
/// <param name="symGen">Optional symbol generator for bucketing.</param>
private static void TrainMultiDevice(Symbol symbol, IList <Context> ctx, IList <string> argNames, IList <string> paramNames, IList <string> auxNames, Dictionary <string, NdArray> argParams, Dictionary <string, NdArray> auxParams, int beginEpoch, int endEpoch, int?epochSize, Optimizer optimizer, IDataIter trainData, IDataIter evalData, EvalMetric evalMetric, IList <EpochEndDelegate> epochEndCallback, IList <BatchEndDelegate> batchEndCallback, KvStore kvstore, bool updateOnKvstore, ILog logger, IList <int> workLoadList, Monitor monitor, IList <BatchEndDelegate> evalBatchEndCallback, SymbolGenerate symGen)
{
    if (logger == null)
    {
        logger = LogManager.GetLogger("");
    }

    var executorManager = new DataParallelExecutorManager(symbol: symbol, symGen: symGen, ctx: ctx, trainData: trainData, paramNames: paramNames, argNames: argNames, auxNames: auxNames, workLoadList: workLoadList, logger: logger);
    if (monitor != null)
    {
        executorManager.InstallMonitor(monitor);
    }
    executorManager.SetParams(argParams, auxParams);

    // A local updater is only needed when the kvstore does not apply the optimizer itself.
    Action <int, NdArray, NdArray> updater = null;
    if (!updateOnKvstore)
    {
        updater = Optimizer.GetUpdater(optimizer);
    }
    if (kvstore != null)
    {
        InitializeKvstore(kvstore: kvstore, paramArrays: executorManager.ParamArrays, argParams: argParams, paramNames: executorManager.ParamNames, updateOnKvstore: updateOnKvstore);
    }
    if (updateOnKvstore)
    {
        kvstore?.SetOptimizer(optimizer);
    }

    // Now start training.
    // BUGFIX: the loop previously counted epoch = 0 .. (endEpoch - beginEpoch), so resumed
    // training (beginEpoch > 0) reported wrong epoch numbers to logs and callbacks, and
    // the last-epoch check below (epoch + 1 == endEpoch) could never fire. Count from
    // beginEpoch, which the conditions below were written to expect.
    for (int epoch = beginEpoch; epoch < endEpoch; epoch++)
    {
        // Training phase
        Stopwatch toc = new Stopwatch();
        toc.Start();
        evalMetric.Reset();
        var nbatch = 0;

        // Iterate over training data; the iterator is reset and re-walked until epochSize
        // batches have been consumed, or exactly one pass is made when epochSize is null.
        while (true)
        {
            var doReset = true;
            foreach (var dataBatch in trainData)
            {
                executorManager.LoadDataBatch(dataBatch);
                monitor?.Tic();
                executorManager.Forward(isTrain: true);
                executorManager.Backward();
                if (updateOnKvstore)
                {
                    UpdateParamsOnKvstore(executorManager.ParamArrays, executorManager.GradArrays, kvstore);
                }
                else
                {
                    UpdateParams(executorManager.ParamArrays, executorManager.GradArrays, updater: updater, numDevice: ctx.Count, kvstore: kvstore);
                }
                monitor?.TocPrint();

                // evaluate at end, so we can lazy copy
                executorManager.UpdateMetric(evalMetric, dataBatch.Label);
                nbatch += 1;

                //batch callback (for print purpose)
                if (batchEndCallback != null)
                {
                    var batchEndParams = new BatchEndParam(epoch: epoch, nbatch: nbatch, evalMetric: evalMetric, locals: Thread.CurrentThread.CurrentCulture);
                    foreach (var call in batchEndCallback)
                    {
                        call(batchEndParams);
                    }
                }

                if (epochSize != null && nbatch >= epochSize)
                {
                    doReset = false;
                    break;
                }
            }

            if (doReset)
            {
                logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                trainData.Reset();
            }
            if (epochSize == null || nbatch >= epochSize)
            {
                break;
            }
        }

        // BUGFIX: ElapsedMilliseconds / 1000 was integer division, truncating the
        // fractional seconds the ".000" format specifier is meant to display.
        logger.Info($"Epoch[{epoch}] Time cost={(toc.ElapsedMilliseconds / 1000.0):.000}");

        // Sync parameters back so epoch callbacks (and the caller, after the last epoch)
        // observe the updated values.
        if (epochEndCallback != null || epoch + 1 == endEpoch)
        {
            executorManager.copy_to(argParams, auxParams);
        }
        if (epochEndCallback != null)
        {
            EpochEndParam epochEndParam = new EpochEndParam(epoch, symbol, argParams, auxParams);
            foreach (var callitem in epochEndCallback)
            {
                callitem(epochEndParam);
            }
        }

        // evaluation
        if (evalData != null)
        {
            evalMetric.Reset();
            evalData.Reset();
            int i = 0;
            foreach (var eval_batch in evalData)
            {
                executorManager.LoadDataBatch(eval_batch);
                executorManager.Forward(isTrain: false);
                executorManager.UpdateMetric(evalMetric, eval_batch.Label);
                if (evalBatchEndCallback != null)
                {
                    var batchEndParams = new BatchEndParam(epoch: epoch, nbatch: i, evalMetric: evalMetric, locals: Thread.CurrentThread.CurrentCulture);
                    foreach (var call in evalBatchEndCallback)
                    {
                        call(batchEndParams);
                    }
                }
                i++;
            }

            var nameValue = evalMetric.get_name_value();
            foreach (var item in nameValue)
            {
                logger.Info($"Epoch[{epoch}] Validation-{item.Name}={item.Value:0.000}");
            }
            evalData.Reset();
        }
    }
}
/// <summary>
/// Trains the model on <paramref name="trainData"/>, optionally evaluating on
/// <paramref name="evalData"/> each epoch, by delegating to <c>TrainMultiDevice</c>.
/// </summary>
/// <param name="trainData">Training data iterator.</param>
/// <param name="evalData">Optional evaluation data iterator.</param>
/// <param name="evalMetric">Metric to track; defaults to accuracy ("acc") when null.</param>
/// <param name="epochEndCallback">Callbacks fired after each epoch.</param>
/// <param name="batchEndCallback">Callbacks fired after each training batch.</param>
/// <param name="kvstoreInput">KvStore type string, e.g. "local".</param>
/// <param name="logger">Logger; the training loop falls back to a default when null.</param>
/// <param name="workLoadList">Relative per-device work load; null means even.</param>
/// <param name="monitor">Optional monitor for per-batch statistics.</param>
/// <param name="evalBatchEndCallback">Callbacks fired after each evaluation batch.</param>
public void Fit(IDataIter trainData, IDataIter evalData, EvalMetric evalMetric = null, IList <EpochEndDelegate> epochEndCallback = null, IList <BatchEndDelegate> batchEndCallback = null, string kvstoreInput = "local", ILog logger = null, IList <int> workLoadList = null, Monitor monitor = null, IList <BatchEndDelegate> evalBatchEndCallback = null )
{
    var data = trainData;

    // Bucketing: regenerate the symbol for the iterator's default bucket before training.
    if (this._symGen != null)
    {
        this._symbol = this._symGen(data.DefaultBucketKey);
        this.CheckArguments();
    }
    this._kwargs["sym"] = this._symbol;

    // Derive argument / parameter / auxiliary names from the combined data+label shapes.
    var shapeDict = data.ProvideData.Concat(data.ProvideLabel).ToDictionary(pair => pair.Key, pair => pair.Value);
    var nameTriple = this.InitParams(shapeDict);
    var argNames = nameTriple.Item1;
    var paramNames = nameTriple.Item2;
    var auxNames = nameTriple.Item3;

    if (evalMetric == null)
    {
        evalMetric = "acc"; // default metric: accuracy
    }

    //create kvstore
    var kvstoreInfo = CreateKvstore(kvstoreInput, _ctx.Count, ArgParams);
    var kvstore = kvstoreInfo.Item1;
    var updateOnKvstore = kvstoreInfo.Item2;

    // Map flat parameter indices back to parameter names for the optimizer.
    Dictionary <int, string> paramIdx2Name;
    if (updateOnKvstore)
    {
        // One entry per parameter when the kvstore performs the update.
        paramIdx2Name = paramNames.Select((name, idx) => new { idx, name }).ToDictionary(p => p.idx, p => p.name);
    }
    else
    {
        // One entry per (parameter, device) pair for local updates.
        paramIdx2Name = Enumerable.Range(0, paramNames.Count)
            .SelectMany(i => Enumerable.Range(0, _ctx.Count).Select(k => new { Key = i * _ctx.Count + k, Name = paramNames[i] }))
            .ToDictionary(p => p.Key, p => p.Name);
    }
    _kwargs["param_idx2name"] = paramIdx2Name;

    //(TODO)init optmizer
    TrainMultiDevice(this._symbol, this._ctx, argNames, paramNames, auxNames,
        this.ArgParams, this.AuxParams,
        beginEpoch: this._beginEpoch, endEpoch: this._numEpoch, epochSize: this._epochSize,
        optimizer: _optimizer, trainData: data, evalData: evalData, evalMetric: evalMetric,
        epochEndCallback: epochEndCallback, batchEndCallback: batchEndCallback,
        kvstore: kvstore, updateOnKvstore: updateOnKvstore, logger: logger,
        workLoadList: workLoadList, monitor: monitor,
        evalBatchEndCallback: evalBatchEndCallback, symGen: this._symGen);
}
/// <summary>
/// Internal data-parallel training loop (snake_case variant): runs forward/backward over
/// all devices, updates parameters (locally or via the kvstore), and fires batch callbacks.
/// </summary>
/// <param name="symbol">Network symbol being trained.</param>
/// <param name="ctx">Devices to train on.</param>
/// <param name="arg_names">All argument names of the symbol.</param>
/// <param name="param_names">Trainable parameter names.</param>
/// <param name="aux_names">Auxiliary state names.</param>
/// <param name="arg_params">Parameter arrays.</param>
/// <param name="aux_params">Auxiliary state arrays.</param>
/// <param name="begin_epoch">First epoch number (supports resumed training).</param>
/// <param name="end_epoch">Exclusive upper bound on the epoch number.</param>
/// <param name="epoch_size">Batches per epoch; null means one full pass of the iterator.</param>
/// <param name="optimizer">Optimizer used for parameter updates.</param>
/// <param name="train_data">Training data iterator.</param>
/// <param name="eval_data">Evaluation iterator — accepted but NOT used in this variant; see NOTE below.</param>
/// <param name="eval_metric">Metric accumulated over training batches.</param>
/// <param name="epoch_end_callback">Accepted but NOT invoked in this variant; see NOTE below.</param>
/// <param name="batch_end_callback">Callbacks fired after each training batch.</param>
/// <param name="kvstore">Optional key-value store for distributed updates.</param>
/// <param name="update_on_kvstore">True when the kvstore applies the optimizer itself.</param>
/// <param name="logger">Logger; a default logger is used when null.</param>
/// <param name="work_load_list">Relative per-device work load.</param>
/// <param name="monitor">Optional monitor for per-batch statistics.</param>
/// <param name="eval_batch_end_callback">Accepted but NOT invoked in this variant; see NOTE below.</param>
/// <param name="sym_gen">Optional symbol generator for bucketing.</param>
private static void _train_multi_device(Symbol symbol, List <Context> ctx, List <string> arg_names, List <string> param_names, List <string> aux_names, Dictionary <string, NDArray> arg_params, Dictionary <string, NDArray> aux_params, int begin_epoch, int end_epoch, int?epoch_size, Optimizer optimizer, IDataIter train_data, IDataIter eval_data, EvalMetric eval_metric, List <Action> epoch_end_callback, List <Action <BatchEndParam> > batch_end_callback, KVStore kvstore, bool update_on_kvstore, ILog logger, List <int> work_load_list, Monitor monitor, Action eval_batch_end_callback, SymbolGenerate sym_gen)
{
    if (logger == null)
    {
        logger = LogManager.GetLogger("");
    }

    var executor_manager = new DataParallelExecutorManager(symbol: symbol, sym_gen: sym_gen, ctx: ctx, train_data: train_data, param_names: param_names, arg_names: arg_names, aux_names: aux_names, work_load_list: work_load_list, logger: logger);
    if (monitor != null)
    {
        executor_manager.install_monitor(monitor);
    }
    executor_manager.set_params(arg_params, aux_params);

    // A local updater is only needed when the kvstore does not apply the optimizer itself.
    Action <int, NDArray, NDArray> updater = null;
    if (!update_on_kvstore)
    {
        updater = Optimizer.get_updater(optimizer);
    }
    if (kvstore != null)
    {
        _initialize_kvstore(kvstore: kvstore, param_arrays: executor_manager.param_arrays, arg_params: arg_params, param_names: executor_manager.param_names, update_on_kvstore: update_on_kvstore);
    }
    if (update_on_kvstore)
    {
        // BUGFIX: guarded with ?. to match the PascalCase twin; the unconditional call
        // threw a NullReferenceException when update_on_kvstore was true with no kvstore.
        kvstore?.set_optimizer(optimizer);
    }

    //Now start training
    // BUGFIX: the loop previously counted epoch = 0 .. (end_epoch - begin_epoch), so
    // resumed training (begin_epoch > 0) reported wrong epoch numbers to logs and
    // callbacks. Count from begin_epoch instead; the iteration count is unchanged.
    for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
    {
        // Training phase
        Stopwatch toc = new Stopwatch();
        toc.Start();
        eval_metric.Reset();
        var nbatch = 0;

        // Iterate over training data; the iterator is reset and re-walked until epoch_size
        // batches have been consumed, or exactly one pass is made when epoch_size is null.
        while (true)
        {
            var do_reset = true;
            foreach (var data_batch in train_data)
            {
                executor_manager.load_data_batch(data_batch);
                monitor?.Tic();
                executor_manager.Forward(is_train: true);
                executor_manager.Backward();
                if (update_on_kvstore)
                {
                    _update_params_on_kvstore(executor_manager.param_arrays, executor_manager.grad_arrays, kvstore);
                }
                else
                {
                    _update_params(executor_manager.param_arrays, executor_manager.grad_arrays, updater: updater, num_device: ctx.Count, kvstore: kvstore);
                }
                monitor?.toc_print();

                // evaluate at end, so we can lazy copy
                executor_manager.update_metric(eval_metric, data_batch.label);
                nbatch += 1;

                //batch callback (for print purpose)
                if (batch_end_callback != null)
                {
                    var batch_end_params = new BatchEndParam(epoch: epoch, nbatch: nbatch, eval_metric: eval_metric, locals: Thread.CurrentThread.CurrentCulture);
                    foreach (var call in batch_end_callback)
                    {
                        call(batch_end_params);
                    }
                }

                if (epoch_size != null && nbatch >= epoch_size)
                {
                    do_reset = false;
                    break;
                }
            }

            if (do_reset)
            {
                logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                train_data.Reset();
            }
            if (epoch_size == null || nbatch >= epoch_size)
            {
                break;
            }
        }

        // BUGFIX: ElapsedMilliseconds / 1000 was integer division, truncating the
        // fractional seconds the ".000" format specifier is meant to display.
        logger.Info($"Epoch[{epoch}] Time cost={(toc.ElapsedMilliseconds / 1000.0):.000}");

        // NOTE(review): unlike the PascalCase twin, this variant never syncs parameters
        // back, never invokes epoch_end_callback, and never evaluates on eval_data /
        // eval_batch_end_callback even though all are accepted — TODO confirm whether
        // this is intentionally incomplete.
    }
}
/// <summary>
/// Trains the model on <paramref name="train_data"/>, optionally evaluating on
/// <paramref name="eval_data"/>, by delegating to <c>_train_multi_device</c>.
/// </summary>
/// <param name="train_data">Training data iterator.</param>
/// <param name="eval_data">Optional evaluation data iterator.</param>
/// <param name="eval_metric">Metric to track; defaults to accuracy ("acc") when null.</param>
/// <param name="epoch_end_callback">Callbacks fired after each epoch.</param>
/// <param name="batch_end_callback">Callbacks fired after each training batch.</param>
/// <param name="kvstore_input">KvStore type string, e.g. "local".</param>
/// <param name="logger">Logger; the training loop falls back to a default when null.</param>
/// <param name="work_load_list">Relative per-device work load; null means even.</param>
/// <param name="monitor">Optional monitor for per-batch statistics.</param>
/// <param name="eval_batch_end_callback">Callback fired after each evaluation batch.</param>
public void Fit(IDataIter train_data, IDataIter eval_data, metric.EvalMetric eval_metric = null, List <Action> epoch_end_callback = null, List <Action <BatchEndParam> > batch_end_callback = null, string kvstore_input = "local", ILog logger = null, List <int> work_load_list = null, Monitor monitor = null, Action eval_batch_end_callback = null )
{
    var data = train_data;

    // Bucketing: regenerate the symbol for the iterator's default bucket before training.
    if (this._sym_gen != null)
    {
        this._symbol = this._sym_gen(data.default_bucket_key);
        this._check_arguments();
    }
    this._kwargs["sym"] = this._symbol;

    // Derive argument / parameter / auxiliary names from the combined data+label shapes.
    var shape_dict = data.provide_data.Concat(data.provide_label).ToDictionary(pair => pair.Key, pair => pair.Value);
    var name_triple = this._init_params(shape_dict);
    var arg_names = name_triple.Item1;
    var param_names = name_triple.Item2;
    var aux_names = name_triple.Item3;

    if (eval_metric == null)
    {
        eval_metric = "acc"; // default metric: accuracy
    }

    //create kvstore
    var kvstore_info = _create_kvstore(kvstore_input, _ctx.Count, _arg_params);
    var kvstore = kvstore_info.Item1;
    var update_on_kvstore = kvstore_info.Item2;

    // Map flat parameter indices back to parameter names for the optimizer.
    Dictionary <int, string> param_idx2_name;
    if (update_on_kvstore)
    {
        // One entry per parameter when the kvstore performs the update.
        param_idx2_name = param_names.Select((name, idx) => new { idx, name }).ToDictionary(p => p.idx, p => p.name);
    }
    else
    {
        // One entry per (parameter, device) pair for local updates.
        param_idx2_name = Enumerable.Range(0, param_names.Count)
            .SelectMany(i => Enumerable.Range(0, _ctx.Count).Select(k => new { Key = i * _ctx.Count + k, Name = param_names[i] }))
            .ToDictionary(p => p.Key, p => p.Name);
    }
    _kwargs["param_idx2name"] = param_idx2_name;

    //(TODO)init optmizer
    _train_multi_device(this._symbol, this._ctx, arg_names, param_names, aux_names,
        this._arg_params, this._aux_params,
        begin_epoch: this._begin_epoch, end_epoch: this._num_epoch, epoch_size: this._epoch_size,
        optimizer: _optimizer, train_data: data, eval_data: eval_data, eval_metric: eval_metric,
        epoch_end_callback: epoch_end_callback, batch_end_callback: batch_end_callback,
        kvstore: kvstore, update_on_kvstore: update_on_kvstore, logger: logger,
        work_load_list: work_load_list, monitor: monitor,
        eval_batch_end_callback: eval_batch_end_callback, sym_gen: this._sym_gen);
}