示例#1
0
        /// <summary>
        /// Runs <paramref name="updater"/> once for every (gradient, weight) pair of every parameter,
        /// optionally routing the gradients through <paramref name="kvstore"/> first.
        /// </summary>
        /// <param name="param_arrays">Outer list indexed by parameter; inner list holds that parameter's weight arrays (presumably one per device — confirm against caller).</param>
        /// <param name="grad_arrays">Gradient arrays laid out like <paramref name="param_arrays"/>. A parameter whose first gradient is null is skipped.</param>
        /// <param name="updater">Callback invoked as (state index, gradient, weight).</param>
        /// <param name="num_device">Multiplier used to derive a distinct updater index for each inner-list slot.</param>
        /// <param name="kvstore">Optional store; when non-null, gradients are pushed and then pulled back before the update.</param>
        private static void _update_params(
            List<List<NDArray>> param_arrays,
            List<List<NDArray>> grad_arrays,
            Action<int, NDArray, NDArray> updater,
            int num_device, KVStore kvstore = null)
        {
            for (int index = 0; index < param_arrays.Count; index++)
            {
                var arg_list = param_arrays[index];
                var grad_list = grad_arrays[index];

                // No gradient recorded for this parameter: nothing to update.
                if (grad_list[0] == null)
                {
                    continue;
                }

                if (kvstore != null)
                {
                    // Push gradient, priority is negative index.
                    kvstore.Push(index, grad_list, priority: -index);
                    // Pull the result back into the gradient buffers, NOT the weights:
                    // the updater below runs locally, so the store is assumed to hold the
                    // merged gradient here (MXNet KVStore semantics when no optimizer is
                    // installed on the store). Pulling into arg_list would overwrite the
                    // weights with gradient values.
                    kvstore.Pull(index, grad_list, priority: -index);
                }

                for (int k = 0; k < arg_list.Count; k++)
                {
                    // index * num_device + k gives each inner-list slot its own updater index.
                    updater(index * num_device + k, grad_list[k], arg_list[k]);
                }
            }
        }
示例#2
0
        /// <summary>
        /// Initializes one key per parameter in <paramref name="kvstore"/> and, when updates are
        /// performed on the store, pulls each key back into the matching entry of <paramref name="param_arrays"/>.
        /// </summary>
        /// <param name="kvstore">Store to initialize; must not be null.</param>
        /// <param name="param_arrays">Outer list indexed by parameter; the inner list receives the pulled value.</param>
        /// <param name="arg_params">Initial values, looked up by parameter name.</param>
        /// <param name="param_names">Name of the parameter at each index of <paramref name="param_arrays"/>.</param>
        /// <param name="update_on_kvstore">When true, each key is pulled right after it is initialized.</param>
        private static void _initialize_kvstore(KVStore kvstore,
                                                List<List<NDArray>> param_arrays,
                                                Dictionary<string, NDArray> arg_params,
                                                List<string> param_names,
                                                bool update_on_kvstore)
        {
            int key = 0;
            while (key < param_arrays.Count)
            {
                var targets = param_arrays[key];
                var name = param_names[key];

                // The integer position doubles as the store key.
                kvstore.Init(key, arg_params[name]);

                if (update_on_kvstore)
                {
                    // Priority is the negated key, as in the other kvstore helpers.
                    kvstore.Pull(key, targets, priority: -key);
                }

                key++;
            }
        }
示例#3
0
 /// <summary>
 /// For every parameter that has a gradient, pushes the gradients to <paramref name="kvstore"/>
 /// and pulls the store's value back into the parameter arrays.
 /// </summary>
 /// <param name="param_arrays">Outer list indexed by parameter; inner list receives the pulled values.</param>
 /// <param name="grad_arrays">Gradient arrays laid out like <paramref name="param_arrays"/>. A parameter whose first gradient is null is skipped.</param>
 /// <param name="kvstore">Store that receives the gradients; must not be null.</param>
 private static void _update_params_on_kvstore(
     List<List<NDArray>> param_arrays,
     List<List<NDArray>> grad_arrays,
     KVStore kvstore)
 {
     int total = param_arrays.Count;
     for (int key = 0; key < total; key++)
     {
         List<NDArray> weights = param_arrays[key];
         List<NDArray> grads = grad_arrays[key];

         // Parameters without a gradient are left untouched.
         if (grads[0] == null)
         {
             continue;
         }

         int priority = -key;
         // Send the gradients first ...
         kvstore.Push(key, grads, priority: priority);
         // ... then fetch the store's value into the weight arrays.
         kvstore.Pull(key, weights, priority: priority);
     }
 }
示例#4
0
        /// <summary>
        /// Creates the key-value store named by <paramref name="kvstore"/> (if one is needed) and
        /// decides whether parameter updates should be performed on it.
        /// </summary>
        /// <param name="kvstore">Store type name, or null for no store. "local" is resolved to a concrete local type based on the largest parameter size.</param>
        /// <param name="count">When this is 1 and the type name does not contain "dist", no store is created.</param>
        /// <param name="arg_params">Parameters whose shapes drive the "local" auto-selection.</param>
        /// <returns>
        /// The store (possibly null) and a flag that is true only when a store exists and its
        /// type does not contain "local_allreduce".
        /// </returns>
        private static Tuple<KVStore, bool> _create_kvstore(
            string kvstore, int count, Dictionary<string, NDArray> arg_params)
        {
            KVStore store = null;

            // A store is skipped when none was requested, or when count is 1 and the
            // requested type is not a "dist" one.
            bool storeRequired = kvstore != null && !(count == 1 && !kvstore.Contains("dist"));

            if (storeRequired)
            {
                string storeType = kvstore;
                if (storeType == "local")
                {
                    // Automatically select a proper local type from the largest parameter.
                    var largest = arg_params.Select(entry => Util.Prod(entry.Value.Get_shape())).Max();
                    storeType = largest < 1024 * 1024 * 16
                        ? "local_update_cpu"
                        : "local_allreduce_cpu";
                }

                store = new KVStore(storeType);
            }

            bool update_on_kvstore = store != null && !store.type.Contains("local_allreduce");

            return Tuple.Create(store, update_on_kvstore);
        }
示例#5
0
        /// <summary>
        /// Training loop: binds the symbol through a <c>DataParallelExecutorManager</c>, then for each epoch
        /// iterates the training data, runs forward/backward, updates the parameters and fires the
        /// batch-end callbacks.
        /// </summary>
        /// <param name="symbol">Symbol passed to the executor manager.</param>
        /// <param name="ctx">Contexts passed to the executor manager; <c>ctx.Count</c> is used as the device count for local updates.</param>
        /// <param name="arg_names">Argument names passed to the executor manager.</param>
        /// <param name="param_names">Parameter names passed to the executor manager.</param>
        /// <param name="aux_names">Auxiliary-state names passed to the executor manager.</param>
        /// <param name="arg_params">Initial parameter values; also used to initialize the kvstore.</param>
        /// <param name="aux_params">Initial auxiliary-state values.</param>
        /// <param name="begin_epoch">First epoch number (inclusive); reported in logs and batch callbacks.</param>
        /// <param name="end_epoch">Last epoch number (exclusive).</param>
        /// <param name="epoch_size">Batches per epoch, or null to make one full pass over <paramref name="train_data"/> per epoch.</param>
        /// <param name="optimizer">Optimizer installed on the kvstore, or wrapped in a local updater.</param>
        /// <param name="train_data">Training data iterator.</param>
        /// <param name="eval_data">Not used by this method.</param>
        /// <param name="eval_metric">Metric that is reset every epoch and updated after every batch.</param>
        /// <param name="epoch_end_callback">Not used by this method.</param>
        /// <param name="batch_end_callback">Callbacks invoked after each batch; may be null.</param>
        /// <param name="kvstore">Optional key-value store; must be non-null when <paramref name="update_on_kvstore"/> is true.</param>
        /// <param name="update_on_kvstore">True to let the kvstore apply the optimizer; false to run the updater locally.</param>
        /// <param name="logger">Logger; a default one is obtained when null.</param>
        /// <param name="work_load_list">Work-load list passed to the executor manager.</param>
        /// <param name="monitor">Optional monitor installed on the executors and ticked around each batch.</param>
        /// <param name="eval_batch_end_callback">Not used by this method.</param>
        /// <param name="sym_gen">Symbol generator passed to the executor manager.</param>
        private static void _train_multi_device(Symbol symbol, List<Context> ctx, List<string> arg_names,
                                                List<string> param_names, List<string> aux_names, Dictionary<string, NDArray> arg_params,
                                                Dictionary<string, NDArray> aux_params, int begin_epoch, int end_epoch, int? epoch_size, Optimizer optimizer,
                                                IDataIter train_data, IDataIter eval_data, EvalMetric eval_metric, List<Action> epoch_end_callback,
                                                List<Action<BatchEndParam>> batch_end_callback, KVStore kvstore, bool update_on_kvstore, ILog logger, List<int> work_load_list,
                                                Monitor monitor, Action eval_batch_end_callback, SymbolGenerate sym_gen)
        {
            if (logger == null)
            {
                logger = LogManager.GetLogger("");
            }

            var executor_manager = new DataParallelExecutorManager(symbol: symbol,
                                                                   sym_gen: sym_gen,
                                                                   ctx: ctx,
                                                                   train_data: train_data,
                                                                   param_names: param_names,
                                                                   arg_names: arg_names,
                                                                   aux_names: aux_names,
                                                                   work_load_list: work_load_list,
                                                                   logger: logger);

            if (monitor != null)
            {
                executor_manager.install_monitor(monitor);
            }
            executor_manager.set_params(arg_params, aux_params);

            // A local updater is only needed when the kvstore does not apply the optimizer itself.
            Action<int, NDArray, NDArray> updater = null;
            if (!update_on_kvstore)
            {
                updater = Optimizer.get_updater(optimizer);
            }

            if (kvstore != null)
            {
                _initialize_kvstore(kvstore: kvstore,
                                    param_arrays: executor_manager.param_arrays,
                                    arg_params: arg_params,
                                    param_names: executor_manager.param_names,
                                    update_on_kvstore: update_on_kvstore);
            }

            if (update_on_kvstore)
            {
                kvstore.set_optimizer(optimizer);
            }

            // Now start training. Epochs are numbered begin_epoch .. end_epoch - 1 so that logs and
            // callbacks report the real epoch number when training starts from a non-zero epoch.
            for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
            {
                // Training phase
                Stopwatch toc = Stopwatch.StartNew();
                eval_metric.Reset();
                var nbatch = 0;

                // Iterate over training data. The outer loop re-enters the iterator after a reset
                // until epoch_size batches have been seen (or exits after one pass when epoch_size is null).
                while (true)
                {
                    var do_reset = true;
                    foreach (var data_batch in train_data)
                    {
                        executor_manager.load_data_batch(data_batch);

                        monitor?.Tic();

                        executor_manager.Forward(is_train: true);
                        executor_manager.Backward();

                        if (update_on_kvstore)
                        {
                            _update_params_on_kvstore(
                                executor_manager.param_arrays,
                                executor_manager.grad_arrays,
                                kvstore);
                        }
                        else
                        {
                            _update_params(executor_manager.param_arrays,
                                           executor_manager.grad_arrays,
                                           updater: updater,
                                           num_device: ctx.Count,
                                           kvstore: kvstore);
                        }

                        monitor?.toc_print();

                        // Evaluate at end, so we can lazy copy.
                        executor_manager.update_metric(eval_metric, data_batch.label);

                        nbatch += 1;

                        // Batch callback (for print purpose).
                        if (batch_end_callback != null)
                        {
                            var batch_end_params = new BatchEndParam(epoch: epoch,
                                                                     nbatch: nbatch,
                                                                     eval_metric: eval_metric,
                                                                     locals: Thread.CurrentThread.CurrentCulture);

                            foreach (var call in batch_end_callback)
                            {
                                call(batch_end_params);
                            }
                        }

                        // The epoch is complete before the iterator is exhausted: keep its position.
                        if (epoch_size != null && nbatch >= epoch_size)
                        {
                            do_reset = false;
                            break;
                        }
                    }

                    if (do_reset)
                    {
                        logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        train_data.Reset();
                    }

                    if (epoch_size == null || nbatch >= epoch_size)
                    {
                        break;
                    }
                }

                // TotalSeconds is a double, so the fractional part survives; the previous integer
                // division of milliseconds always printed ".000".
                logger.Info($"Epoch[{epoch}] Time cost={toc.Elapsed.TotalSeconds:0.000}");
            }
        }