/// <summary>
/// Updates the weights locally (not on the kvstore) by invoking <paramref name="updater"/> for
/// every device copy of every parameter.
/// </summary>
/// <remarks>
/// When <paramref name="kvstore"/> is supplied, the per-device gradients of a parameter are pushed to
/// it and the value the store holds for that key is pulled back into the same gradient arrays.
/// With no optimizer installed on the store this is presumably the merged (summed) gradient.
/// The pull must target the gradient arrays: pulling into the weight arrays would overwrite the
/// weights with gradient values before the local updater runs.
/// </remarks>
/// <param name="param_arrays">Per-parameter lists of weight arrays, one entry per device.</param>
/// <param name="grad_arrays">
/// Per-parameter lists of gradient arrays, parallel to <paramref name="param_arrays"/>.
/// A parameter whose first gradient entry is null is skipped.
/// </param>
/// <param name="updater">Invoked as <c>updater(key, gradient, weight)</c> for each device copy.</param>
/// <param name="num_device">
/// Number of devices; combined with the parameter index so each (parameter, device) pair gets a
/// distinct updater key.
/// </param>
/// <param name="kvstore">Optional store used to merge gradients across devices before updating.</param>
private static void _update_params(
    List<List<NDArray>> param_arrays,
    List<List<NDArray>> grad_arrays,
    Action<int, NDArray, NDArray> updater,
    int num_device,
    KVStore kvstore = null)
{
    for (int index = 0; index < param_arrays.Count; index++)
    {
        var arg_list = param_arrays[index];
        var grad_list = grad_arrays[index];
        if (grad_list[0] == null)
        {
            continue;
        }

        if (kvstore != null)
        {
            // push gradient, priority is negative index
            kvstore.Push(index, grad_list, priority: -index);
            // pull the merged gradient back into the gradient arrays (NOT the weights);
            // the local updater below applies it to each weight copy
            kvstore.Pull(index, grad_list, priority: -index);
        }

        for (int k = 0; k < arg_list.Count; k++)
        {
            var w = arg_list[k];
            var g = grad_list[k];
            // distinct key per (parameter, device) so the optimizer keeps separate state per copy
            updater(index * num_device + k, g, w);
        }
    }
}
/// <summary>
/// Registers each parameter with the key-value store under its positional index, seeding it with
/// the value found in <paramref name="arg_params"/> for the matching entry of <paramref name="param_names"/>.
/// </summary>
/// <param name="kvstore">Store to initialise.</param>
/// <param name="param_arrays">Per-parameter lists of device copies; the list position is the store key.</param>
/// <param name="arg_params">Initial parameter values, looked up by name.</param>
/// <param name="param_names">Parameter names, indexed in the same order as <paramref name="param_arrays"/>.</param>
/// <param name="update_on_kvstore">
/// When true, the freshly initialised value is also pulled into every device copy.
/// </param>
private static void _initialize_kvstore(KVStore kvstore, List<List<NDArray>> param_arrays, Dictionary<string, NDArray> arg_params, List<string> param_names, bool update_on_kvstore)
{
    var key = 0;
    foreach (var deviceCopies in param_arrays)
    {
        var initialValue = arg_params[param_names[key]];
        kvstore.Init(key, initialValue);

        if (update_on_kvstore)
        {
            // lower key -> higher (less negative) priority
            kvstore.Pull(key, deviceCopies, priority: -key);
        }

        key++;
    }
}
/// <summary>
/// Lets the key-value store perform the weight update: each parameter's per-device gradients are
/// pushed to the store and the resulting weights are pulled back into every device copy.
/// </summary>
/// <param name="param_arrays">Per-parameter lists of weight arrays (one per device) that receive the pulled weights.</param>
/// <param name="grad_arrays">
/// Per-parameter lists of gradient arrays, parallel to <paramref name="param_arrays"/>;
/// a parameter whose first gradient entry is null is skipped.
/// </param>
/// <param name="kvstore">Store holding the weights (and, presumably, the optimizer).</param>
private static void _update_params_on_kvstore(List<List<NDArray>> param_arrays, List<List<NDArray>> grad_arrays, KVStore kvstore)
{
    for (var key = 0; key < param_arrays.Count; key++)
    {
        var gradients = grad_arrays[key];
        if (gradients[0] != null)
        {
            var weights = param_arrays[key];
            var priority = -key;

            kvstore.Push(key, gradients, priority: priority);
            kvstore.Pull(key, weights, priority: priority);
        }
    }
}
/// <summary>
/// Creates the key-value store used for multi-device training and decides whether weight updates
/// should run on it.
/// </summary>
/// <param name="kvstore">Requested store type name, or null for no store.</param>
/// <param name="count">Number of devices.</param>
/// <param name="arg_params">Model parameters; only consulted to auto-select a type for <c>"local"</c>.</param>
/// <returns>
/// A tuple of (store or null, update-on-store flag). No store is created when <paramref name="kvstore"/>
/// is null, or when there is a single device and the name does not contain <c>"dist"</c>.
/// <c>"local"</c> becomes <c>"local_update_cpu"</c> when the largest parameter (product of its shape)
/// is below 1024*1024*16, otherwise <c>"local_allreduce_cpu"</c>. The flag is true only when a store
/// exists and its type does not contain <c>"local_allreduce"</c>.
/// </returns>
private static Tuple<KVStore, bool> _create_kvstore(string kvstore, int count, Dictionary<string, NDArray> arg_params)
{
    KVStore store = null;

    var wantsStore = kvstore != null && (count != 1 || kvstore.Contains("dist"));
    if (wantsStore)
    {
        var storeType = kvstore;
        if (storeType == "local")
        {
            // pick a local strategy based on the biggest parameter
            var largest = arg_params.Values.Select(v => Util.Prod(v.Get_shape())).Max();
            storeType = largest < 1024 * 1024 * 16 ? "local_update_cpu" : "local_allreduce_cpu";
        }

        store = new KVStore(storeType);
    }

    var updateOnStore = store != null && !store.type.Contains("local_allreduce");
    return Tuple.Create(store, updateOnStore);
}
/// <summary>
/// Trains <paramref name="symbol"/> with data parallelism across the devices in <paramref name="ctx"/>
/// for epochs <paramref name="begin_epoch"/> (inclusive) to <paramref name="end_epoch"/> (exclusive).
/// </summary>
/// <remarks>
/// For every batch: load it into the executors, run forward/backward, update the parameters (on the
/// kvstore when <paramref name="update_on_kvstore"/> is true, otherwise locally via the optimizer's
/// updater), update <paramref name="eval_metric"/>, and invoke <paramref name="batch_end_callback"/>.
/// When <paramref name="epoch_size"/> is null an epoch is one full pass over <paramref name="train_data"/>;
/// otherwise an epoch ends after <paramref name="epoch_size"/> batches, resetting the iterator as needed.
/// NOTE(review): <paramref name="eval_data"/>, <paramref name="epoch_end_callback"/> and
/// <paramref name="eval_batch_end_callback"/> are accepted but not used by this method.
/// </remarks>
/// <exception cref="ArgumentException">
/// <paramref name="update_on_kvstore"/> is true but <paramref name="kvstore"/> is null.
/// </exception>
private static void _train_multi_device(Symbol symbol, List<Context> ctx, List<string> arg_names, List<string> param_names, List<string> aux_names, Dictionary<string, NDArray> arg_params, Dictionary<string, NDArray> aux_params, int begin_epoch, int end_epoch, int? epoch_size, Optimizer optimizer, IDataIter train_data, IDataIter eval_data, EvalMetric eval_metric, List<Action> epoch_end_callback, List<Action<BatchEndParam>> batch_end_callback, KVStore kvstore, bool update_on_kvstore, ILog logger, List<int> work_load_list, Monitor monitor, Action eval_batch_end_callback, SymbolGenerate sym_gen)
{
    // Fail fast with a clear message instead of a NullReferenceException at set_optimizer.
    if (update_on_kvstore && kvstore == null)
    {
        throw new ArgumentException("update_on_kvstore requires a non-null kvstore.", nameof(kvstore));
    }

    if (logger == null)
    {
        logger = LogManager.GetLogger("");
    }

    var executor_manager = new DataParallelExecutorManager(symbol: symbol,
        sym_gen: sym_gen,
        ctx: ctx,
        train_data: train_data,
        param_names: param_names,
        arg_names: arg_names,
        aux_names: aux_names,
        work_load_list: work_load_list,
        logger: logger);
    if (monitor != null)
    {
        executor_manager.install_monitor(monitor);
    }

    executor_manager.set_params(arg_params, aux_params);

    // Local updater is only needed when the kvstore does not run the optimizer itself.
    Action<int, NDArray, NDArray> updater = null;
    if (!update_on_kvstore)
    {
        updater = Optimizer.get_updater(optimizer);
    }

    if (kvstore != null)
    {
        _initialize_kvstore(kvstore: kvstore,
            param_arrays: executor_manager.param_arrays,
            arg_params: arg_params,
            param_names: executor_manager.param_names,
            update_on_kvstore: update_on_kvstore);
    }

    if (update_on_kvstore)
    {
        kvstore.set_optimizer(optimizer);
    }

    // Start training from the beginning of the data.
    train_data.Reset();

    // Number epochs from begin_epoch so logs and callbacks report the real epoch when resuming.
    for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
    {
        // Training phase
        var tic = Stopwatch.StartNew();
        eval_metric.Reset();
        var nbatch = 0;

        // Iterate over training data; may take several passes when epoch_size exceeds one pass.
        while (true)
        {
            var do_reset = true;
            var batches_this_pass = 0;
            foreach (var data_batch in train_data)
            {
                batches_this_pass++;
                executor_manager.load_data_batch(data_batch);

                monitor?.Tic();

                executor_manager.Forward(is_train: true);
                executor_manager.Backward();

                if (update_on_kvstore)
                {
                    _update_params_on_kvstore(executor_manager.param_arrays,
                        executor_manager.grad_arrays,
                        kvstore);
                }
                else
                {
                    _update_params(executor_manager.param_arrays,
                        executor_manager.grad_arrays,
                        updater: updater,
                        num_device: ctx.Count,
                        kvstore: kvstore);
                }

                monitor?.toc_print();

                // evaluate at end, so we can lazy copy
                executor_manager.update_metric(eval_metric, data_batch.label);

                nbatch += 1;

                // batch callback (for print purpose)
                if (batch_end_callback != null)
                {
                    var batch_end_params = new BatchEndParam(epoch: epoch,
                        nbatch: nbatch,
                        eval_metric: eval_metric,
                        locals: Thread.CurrentThread.CurrentCulture);
                    foreach (var call in batch_end_callback)
                    {
                        call(batch_end_params);
                    }
                }

                // this epoch is done possibly earlier
                if (epoch_size != null && nbatch >= epoch_size)
                {
                    do_reset = false;
                    break;
                }
            }

            if (do_reset)
            {
                logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                train_data.Reset();
            }

            // this epoch is done
            if (epoch_size == null || nbatch >= epoch_size)
            {
                break;
            }

            // An iterator that yields nothing can never reach epoch_size; avoid spinning forever.
            if (batches_this_pass == 0)
            {
                logger.Warn($"Epoch[{epoch}] Training data produced no batches; ending epoch before epoch_size={epoch_size} was reached");
                break;
            }
        }

        // Report fractional seconds (ElapsedMilliseconds/1000 truncated to whole seconds).
        var seconds = tic.Elapsed.TotalSeconds.ToString("F3", System.Globalization.CultureInfo.InvariantCulture);
        logger.Info($"Epoch[{epoch}] Time cost={seconds}");
    }
}