public DataParallelExecutorManager(Symbol symbol, Context[] ctx, DataIter train_data, string[] arg_names,
                                           string[] param_names,
                                           string[] aux_names, int[] work_load_list = null, Logger logger = null, Func <int, Symbol> sym_gen = null)
        {
            num_device = ctx.Length;
            Logger.Info(string.Format("Start training with {0} devices", num_device));

            if (work_load_list == null)
            {
                work_load_list = new int[num_device];
                for (var i = 0; i < num_device; i++)
                {
                    work_load_list[i] = 1;
                }
            }
            else if (work_load_list.Length != num_device)
            {
                throw new MXNetException("Invalid setting for work load");
            }

            slices = ExecuterManager.SplitInputSlice(train_data.BatchSize, work_load_list);

            this.arg_names   = arg_names;
            this.param_names = param_names;
            this.aux_names   = aux_names;
            contexts         = ctx;
            execgrp          = new DataParallelExecutorGroup(symbol, arg_names, param_names, ctx, slices, train_data);
            this.symbol      = symbol;
            this.sym_gen     = sym_gen;
            if (sym_gen != null)
            {
                execgrp_bucket.Add(train_data.DefaultBucketKey, execgrp);
            }
        }
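
A minimal usage sketch for this constructor (hedged: net, trainIter and param_names are placeholders for a real model definition; mx.Device is the device handle used elsewhere on this page):

// Hypothetical setup; only the constructor call itself comes from the example above.
var ctx = new[] { mx.Device };
var manager = new DataParallelExecutorManager(symbol: net,
                                              ctx: ctx,
                                              train_data: trainIter,
                                              arg_names: net.ListArguments().ToArray(),
                                              param_names: param_names,
                                              aux_names: net.ListAuxiliaryStates().ToArray(),
                                              work_load_list: null); // null => a work load of 1 per device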
Example #2
 public void Fit(DataIter train_data, DataIter eval_data = null, string eval_metric = "acc",
                 IEpochEndCallback[] epoch_end_callback  = null, IBatchEndCallback[] batch_end_callback = null,
                 string kvstore   = "local",
                 string optimizer = "sgd", Dictionary <string, object> optimizer_params = null,
                 IBatchEndCallback[] eval_end_callback       = null,
                 IBatchEndCallback[] eval_batch_end_callback = null, Initializer initializer = null,
                 NDArrayDict arg_params = null,
                 NDArrayDict aux_params = null, bool allow_missing = false, bool force_rebind = false,
                 bool force_init        = false, int begin_epoch   = 0, int? num_epoch = null, EvalMetric validation_metric = null,
                 Monitor monitor        = null, Func <DataBatch, NDArrayDict> sparse_row_id_fn = null)
 {
     throw new NotImplementedException();
 }
Example #3
        public Dictionary <string, float> Score(DataIter eval_data, EvalMetric eval_metric, int? num_batch = null,
                                                IBatchEndCallback[] batch_end_callback         = null,
                                                IScoreEndCallback[] score_end_callback         = null, bool reset = true, int epoch = 0,
                                                Func <DataBatch, NDArrayDict> sparse_row_id_fn = null)
        {
            if (!Binded || !ParamsInitialized)
            {
                throw new Exception("Module must be bound and parameters initialized");
            }

            if (reset)
            {
                eval_data.Reset();
            }

            eval_metric.Reset();
            var actual_num_batch = 0;

            while (!eval_data.End())
            {
                if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
                {
                    break;
                }

                var eval_batch = eval_data.Next();
                Prepare(eval_batch, sparse_row_id_fn);
                Forward(eval_batch, false);
                UpdateMetric(eval_metric, eval_batch.Label, true);
                if (batch_end_callback != null)
                {
                    foreach (var callback in batch_end_callback)
                    {
                        callback.Invoke(epoch, eval_data.Cursor, eval_metric);
                    }
                }

                actual_num_batch++;
            }

            if (score_end_callback != null)
            {
                foreach (var callback in score_end_callback)
                {
                    callback.Invoke(epoch, actual_num_batch, eval_metric, new FuncArgs());
                }
            }

            return eval_metric.GetNameValue();
        }
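
A hedged usage sketch for Score (mod and valIter are placeholders for a bound, parameter-initialized module and a validation iterator):

// Hypothetical usage: evaluate accuracy over a validation set.
var metric = EvalMetric.Create("acc", null); // same factory used by Fit below
var results = mod.Score(valIter, metric, reset: true, epoch: 0);
foreach (var pair in results)
{
    Console.WriteLine($"{pair.Key} = {pair.Value}");
}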
Example #4
        public IEnumerable <(NDArrayList, int, DataBatch)> IterPredict(DataIter eval_data, int? num_batch = null,
                                                                       bool reset = true, int epoch = 0, Func <DataBatch, NDArrayDict> sparse_row_id_fn = null)
        {
            if (!Binded || !ParamsInitialized)
            {
                throw new Exception("Module must be bound and parameters initialized");
            }

            if (reset)
            {
                eval_data.Reset();
            }

            while (!eval_data.End())
            {
                if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
                {
                    break;
                }

                var eval_batch = eval_data.Next();
                Prepare(eval_batch, sparse_row_id_fn);
                Forward(eval_batch, false);
                var pad     = eval_batch.Pad.Value;
                var outputs = new NDArrayList();
                foreach (var list in GetOutputs())
                {
                    foreach (var @out in list)
                    {
                        outputs.Add(@out[$"0:{@out.Shape[0] - pad}"]);
                    }
                }

                yield return (outputs.ToArray(), eval_data.Cursor, eval_batch);
            }
        }
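
IterPredict yields one tuple per mini-batch, so predictions can be consumed without holding the whole dataset in memory. A hedged usage sketch (mod and testIter are placeholders):

// Hypothetical usage: stream de-padded outputs batch by batch.
foreach (var (outputs, cursor, batch) in mod.IterPredict(testIter, num_batch: 10))
{
    Console.WriteLine($"batch {cursor}: {outputs.Length} output array(s)");
}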
Example #5
        internal static void TrainMultiDevice(Symbol symbol, Context[] ctx, string[] arg_names, string[] param_names,
                                              string[] aux_names, NDArrayDict arg_params, NDArrayDict aux_params,
                                              int begin_epoch, int end_epoch, int? epoch_size, Optimizer optimizer,
                                              KVStore kvstore, bool update_on_kvstore, DataIter train_data,
                                              DataIter eval_data = null, EvalMetric eval_metric = null,
                                              IEpochEndCallback epoch_end_callback = null, IBatchEndCallback batch_end_callback = null,
                                              int[] work_load_list = null, Monitor monitor = null,
                                              IEvalEndCallback eval_end_callback = null, IEvalBatchEndCallback eval_batch_end_callback = null,
                                              Func <int, Symbol> sym_gen = null)
        {
            var executor_manager = new DataParallelExecutorManager(symbol: symbol,
                                                                   ctx: ctx,
                                                                   train_data: train_data,
                                                                   arg_names: arg_names,
                                                                   param_names: param_names,
                                                                   aux_names: aux_names,
                                                                   work_load_list: work_load_list,
                                                                   sym_gen: sym_gen);

            if (monitor != null)
            {
                executor_manager.InstallMonitor(monitor);
            }

            executor_manager.SetParams(arg_params, aux_params);
            Updater updater = null;

            if (!update_on_kvstore)
            {
                updater = Optimizer.GetUpdater(optimizer);
            }

            if (kvstore != null)
            {
                InitializeKVStore(kvstore: kvstore,
                                  param_arrays: new List <NDArrayList>()
                {
                    executor_manager.ParamArrays
                },
                                  arg_params: arg_params,
                                  param_names: executor_manager.param_names,
                                  update_on_kvstore: update_on_kvstore);
            }

            // Set the optimizer only after the kvstore holds the parameters;
            // update_on_kvstore implies a non-null kvstore.
            if (update_on_kvstore)
            {
                kvstore.SetOptimizer(optimizer);
            }

            train_data.Reset();
            for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
            {
                var tic = DateTime.Now;
                eval_metric.Reset();
                int nbatch = 0;
                while (true)
                {
                    bool do_reset = true;
                    while (!train_data.End())
                    {
                        var data_batch = train_data.Next();
                        executor_manager.LoadDataBatch(data_batch);
                        if (monitor != null)
                        {
                            monitor.Tic();
                        }

                        executor_manager.Forward(true);
                        executor_manager.Backward();
                        if (update_on_kvstore)
                        {
                            if (kvstore.Type.Contains("nccl"))
                            {
                                UpdateParamsOnKVStoreNCCL(new List <NDArrayList>()
                                {
                                    executor_manager.ParamArrays
                                }, new List <NDArrayList>()
                                {
                                    executor_manager.GradArrays
                                }, kvstore, executor_manager.param_names);
                            }
                            else
                            {
                                UpdateParamsOnKVStore(new List <NDArrayList>()
                                {
                                    executor_manager.ParamArrays
                                }, new List <NDArrayList>()
                                {
                                    executor_manager.GradArrays
                                }, kvstore, executor_manager.param_names);
                            }
                        }
                        else
                        {
                            UpdateParams(new List <NDArrayList>()
                            {
                                executor_manager.ParamArrays
                            }, new List <NDArrayList>()
                            {
                                executor_manager.GradArrays
                            }, updater, ctx.Length, kvstore, executor_manager.param_names);
                        }

                        if (monitor != null)
                        {
                            monitor.TocPrint();
                        }

                        executor_manager.UpdateMetric(eval_metric, data_batch.Label);
                        nbatch++;
                        if (batch_end_callback != null)
                        {
                            MultipleCallbacks(batch_end_callback, epoch, nbatch, eval_metric);
                        }

                        if (epoch_size.HasValue && nbatch >= epoch_size.Value)
                        {
                            do_reset = false;
                            break;
                        }
                    }

                    if (do_reset)
                    {
                        Logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        train_data.Reset();
                    }

                    // this epoch is done (possibly earlier if epoch_size was reached)
                    if (!epoch_size.HasValue || nbatch >= epoch_size.Value)
                    {
                        break;
                    }
                }

                var toc = DateTime.Now;
                Logger.Info($"Epoch[{epoch}] Time cost={(toc - tic).TotalSeconds}");

                if (epoch_end_callback != null || epoch + 1 == end_epoch)
                {
                    executor_manager.CopyTo(arg_params, aux_params);
                }

                MultipleCallbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params);

                if (eval_data != null)
                {
                    eval_metric.Reset();
                    eval_data.Reset();
                    int total_num_batch = 0;
                    while (!eval_data.End())
                    {
                        var eval_batch = eval_data.Next();
                        executor_manager.LoadDataBatch(eval_batch);
                        executor_manager.Forward();
                        executor_manager.UpdateMetric(eval_metric, eval_batch.Label);
                        if (eval_batch_end_callback != null)
                        {
                            MultipleCallbacks(eval_batch_end_callback, epoch, total_num_batch, eval_metric);
                        }

                        total_num_batch++;
                    }

                    if (eval_end_callback != null)
                    {
                        MultipleCallbacks(eval_end_callback, epoch, eval_metric);
                    }
                }
            }
        }
Example #6
        public DataParallelExecutorGroup(Symbol sym, string[] arg_names, string[] param_names, Context[] ctxlist,
                                         Slice[] slices, DataIter train_data, DataParallelExecutorGroup shared_group = null)
        {
            ExecuterManager.CheckArguments(sym);

            if (shared_group == null)
            {
                foreach (var item in ctxlist)
                {
                    shared_data_arrays.Add(new NDArrayDict());
                }
            }
            else
            {
                shared_data_arrays = shared_group.shared_data_arrays;
            }

            foreach (var item in train_data.ProvideData)
            {
                data_names.Add(item.Name);
            }

            foreach (var item in train_data.ProvideLabel)
            {
                label_names.Add(item.Name);
            }

            aux_names = sym.ListAuxiliaryStates().ToList();
            for (var i = 0; i < arg_names.Length; i++)
            {
                if (param_names.Contains(arg_names[i]))
                {
                    param_idx.Add(i);
                    this.param_names.Add(arg_names[i]);
                }
            }

            for (var i = 0; i < ctxlist.Length; i++)
            {
                var data_shapes = new Dictionary <string, Shape>();
                var data_types  = new Dictionary <string, DType>();
                var shapeData   = new List <int>();
                foreach (var item in train_data.ProvideData)
                {
                    shapeData = item.Shape.Data.ToList();
                    shapeData.RemoveAt(0);
                    shapeData.Insert(0, slices[i].End.Value - slices[i].Begin);
                    data_shapes[item.Name] = new Shape(shapeData);
                    data_types[item.Name]  = item.DataType;
                }

                foreach (var item in train_data.ProvideLabel)
                {
                    shapeData = item.Shape.Data.ToList();
                    shapeData.RemoveAt(0);
                    shapeData.Insert(0, slices[i].End.Value - slices[i].Begin);
                    data_shapes[item.Name] = new Shape(shapeData);
                    data_types[item.Name]  = item.DataType;
                }

                var shared_exec = shared_group == null ? null : shared_group.train_execs[i];
                var train_exec  = ExecuterManager.BindExec(sym, ctxlist[i], data_shapes, param_names, true,
                                                           shared_exec, shared_data_arrays[i], data_types);

                train_execs.Add(train_exec);
            }

            foreach (var name in data_names)
            {
                for (var i = 0; i < train_execs.Count; i++)
                {
                    data_arrays.Add(train_execs[i].ArgmentDictionary()[name]);
                }
            }

            foreach (var name in label_names)
            {
                for (var i = 0; i < train_execs.Count; i++)
                {
                    label_arrays.Add(train_execs[i].ArgmentDictionary()[name]);
                }
            }

            foreach (var idx in param_idx)
            {
                for (var i = 0; i < train_execs.Count; i++)
                {
                    param_arrays.Add(train_execs[i].ArgmentArrays[idx]);
                }
            }

            foreach (var idx in param_idx)
            {
                for (var i = 0; i < train_execs.Count; i++)
                {
                    grad_arrays.Add(train_execs[i].GradientArrays[idx]);
                }
            }

            for (var idx = 0; idx < aux_names.Count; idx++)
            {
                for (var i = 0; i < train_execs.Count; i++)
                {
                    aux_arrays.Add(train_execs[i].AuxiliaryArrays[idx]);
                }
            }

            this.slices = slices;
        }
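
For intuition about the shape arithmetic above, a sketch of the batch split, assuming SplitInputSlice divides the batch in proportion to work_load_list (consistent with the constructor at the top of this page):

// Illustrative only: with BatchSize = 128 and four equal work loads,
// SplitInputSlice should yield slices [0,32), [32,64), [64,96), [96,128),
// so each executor is bound with a leading data dimension of
// slices[i].End.Value - slices[i].Begin = 32.
var slices = ExecuterManager.SplitInputSlice(128, new[] { 1, 1, 1, 1 });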
Example #7
        public void Fit(DataIter train_data, DataIter eval_data = null, string eval_metric = "acc",
                        IEpochEndCallback[] epoch_end_callback  = null, IBatchEndCallback[] batch_end_callback = null,
                        string kvstore      = "local",
                        Optimizer optimizer = null, Dictionary <string, object> optimizer_params = null,
                        IScoreEndCallback[] eval_end_callback       = null,
                        IBatchEndCallback[] eval_batch_end_callback = null, Initializer initializer = null,
                        NDArrayDict arg_params = null,
                        NDArrayDict aux_params = null, bool allow_missing = false, bool force_rebind = false,
                        bool force_init        = false, int begin_epoch   = 0, int? num_epoch = null, EvalMetric validation_metric = null,
                        Monitor monitor        = null, Func <DataBatch, NDArrayDict> sparse_row_id_fn = null)
        {
            if (optimizer == null)
            {
                optimizer = new SGD();
            }
            Debug.Assert(num_epoch != null, "please specify number of epochs");
            this.Bind(data_shapes: train_data.ProvideData, label_shapes: train_data.ProvideLabel, for_training: true, force_rebind: force_rebind);
            if (monitor != null)
            {
                this.InstallMonitor(monitor);
            }

            this.InitParams(initializer: initializer, arg_params: arg_params, aux_params: aux_params, allow_missing: allow_missing, force_init: force_init);
            this.InitOptimizer(kvstore: kvstore, optimizer: optimizer, optimizer_params: optimizer_params);
            var eval_metric_func = EvalMetric.Create(eval_metric, null);
            if (validation_metric == null)
            {
                validation_metric = eval_metric_func;
            }

            //###############################################################################
            // training loop
            //###############################################################################
            foreach (var epoch in Enumerable.Range(begin_epoch, num_epoch.Value - begin_epoch))
            {
                var tic = DateTime.Now;
                eval_metric_func.Reset();
                var nbatch          = 0;
                var data_iter       = train_data;
                var end_of_batch    = false;
                var next_data_batch = data_iter.Next();
                Dictionary <string, float> eval_name_vals = new Dictionary <string, float>();
                while (!end_of_batch)
                {
                    var data_batch = next_data_batch;
                    if (monitor != null)
                    {
                        monitor.Tic();
                    }

                    this.ForwardBackward(data_batch);
                    this.Update();

                    UpdateMetric(eval_metric_func, data_batch.Label);

                    try
                    {
                        // pre fetch next batch
                        next_data_batch = data_iter.Next();
                        this.Prepare(next_data_batch, sparse_row_id_fn: sparse_row_id_fn);
                    }
                    catch (StopIteration)
                    {
                        end_of_batch = true;
                    }
                    if (monitor != null)
                    {
                        monitor.TocPrint();
                    }

                    if (end_of_batch)
                    {
                        eval_name_vals = eval_metric_func.GetGlobalNameValue();
                    }

                    if (batch_end_callback != null)
                    {
                        foreach (var callback in batch_end_callback)
                        {
                            callback.Invoke(epoch: epoch, nbatch: nbatch, eval_metric: eval_metric_func);
                        }
                    }
                    nbatch += 1;
                }

                // one epoch of training is finished
                foreach (var item in eval_name_vals)
                {
                    Logger.Info($"Epoch[{epoch}] Train-{item.Key}={item.Value}");
                }

                var toc = DateTime.Now;

                Logger.Info($"Epoch[{epoch}] Time cost={(toc - tic).TotalSeconds}");
                // sync aux params across devices
                (arg_params, aux_params) = this.GetParams();
                this.SetParams(arg_params, aux_params);
                if (epoch_end_callback != null)
                {
                    foreach (var callback in epoch_end_callback)
                    {
                        callback.Invoke(epoch, this.Symbol, arg_params, aux_params);
                    }
                }
                //----------------------------------------
                // evaluation on validation set
                if (eval_data != null)
                {
                    var res = this.Score(eval_data, validation_metric, score_end_callback: eval_end_callback, batch_end_callback: eval_batch_end_callback, epoch: epoch);
                    //TODO: pull this into default
                    foreach (var item in res)
                    {
                        Logger.Info($"Epoch[{epoch}] Validation-{item.Key}={item.Value}");
                    }
                }
                // end of 1 epoch, reset the data-iter for another epoch
                train_data.Reset();
            }
        }
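
A hedged usage sketch for this Fit overload (mod, trainIter and valIter are placeholders):

// Hypothetical usage: 10 epochs of SGD with accuracy reported on a validation set.
mod.Fit(trainIter,
        eval_data: valIter,
        eval_metric: "acc",
        optimizer: new SGD(),
        num_epoch: 10);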
Example #8
        public List <NDArrayList> Predict(DataIter eval_data, int? num_batch = null, bool merge_batches = true,
                                          bool reset = true, bool always_output_list = true, Func <DataBatch, NDArrayDict> sparse_row_id_fn = null)
        {
            if (!Binded || !ParamsInitialized)
            {
                throw new Exception("Module must be bound and parameters initialized");
            }

            if (reset)
            {
                eval_data.Reset();
            }

            var output_list  = new List <NDArrayList>();
            var output_list2 = new NDArrayList();

            while (!eval_data.End())
            {
                if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
                {
                    break;
                }

                var eval_batch = eval_data.Next();
                Prepare(eval_batch, sparse_row_id_fn);
                Forward(eval_batch, false);
                var pad     = eval_batch.Pad.Value;
                var outputs = new NDArrayList();
                foreach (var list in GetOutputs())
                {
                    foreach (var @out in list)
                    {
                        outputs.Add(@out[$"0:{@out.Shape[0] - pad}"].Copy());
                    }
                }

                output_list.Add(outputs.ToArray());
            }

            if (output_list.Count == 0)
            {
                return output_list;
            }

            if (merge_batches)
            {
                var num_outputs = output_list[0].Length;
                foreach (var @out in output_list)
                {
                    if (@out.Length != num_outputs)
                    {
                        throw new Exception("Cannot merge batches, as num of outputs is not the same " +
                                            "in mini-batches. Maybe bucketing is used?");
                    }
                }

                // Concatenate across mini-batches for each output index, so each
                // entry of the result spans the whole dataset for one output.
                for (var i = 0; i < num_outputs; i++)
                {
                    var batches = new NDArrayList();
                    foreach (var @out in output_list)
                    {
                        batches.Add(@out[i]);
                    }

                    output_list2.Add(nd.Concat(batches));
                }

                return new List <NDArrayList> {
                    output_list2.ToArray()
                };
            }

            return output_list;
        }
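
A hedged usage sketch for Predict (mod and testIter are placeholders); with merge_batches the result collapses to a single entry whose arrays are concatenated across all mini-batches:

// Hypothetical usage: run the full iterator and merge per-batch outputs.
var merged = mod.Predict(testIter, merge_batches: true);
var outputs = merged[0]; // one NDArray per symbol output, batches concatenated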
Example #9
        public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
        {
            string labelName = "label";

            var label = Symbol.Variable(labelName);

            List <uint> inputShape = new List <uint>();

            inputShape.Add(batchSize);
            inputShape.AddRange(InputShape);

            args["X"]       = new NDArray(new Shape(inputShape.ToArray()));
            args[labelName] = new NDArray(new Shape(batchSize));

            Model.InferArgsMap(mx.Device, args, args);

            var defaultInitializer = new Initializers.GlorotUniform();

            foreach (var arg in args)
            {
                if (ParamInitializers.ContainsKey(arg.Key))
                {
                    ParamInitializers[arg.Key].Generate(arg.Value);
                }
                else
                {
                    defaultInitializer.Generate(arg.Value);
                }
            }

            using (var exec = Model.SimpleBind(mx.Device, args))
            {
                var argNames = Model.ListArguments();

                // Start training
                var sw = new Stopwatch();
                for (var iter = 1; iter <= epochs; iter++)
                {
                    uint samples = 0;
                    train.BatchSize = batchSize;
                    train.Reset();
                    Metric.Reset();
                    TrainMetric.Reset();
                    sw.Restart();

                    while (train.IterNext())
                    {
                        samples += batchSize;
                        var dataBatch = train.Next();

                        // Set data and label
                        dataBatch.Data[0].CopyTo(args["X"]);
                        dataBatch.Label[0].CopyTo(args[labelName]);

                        // Compute gradients
                        exec.Forward(true);
                        exec.Backward();
                        TrainMetric.Update(args[labelName], exec.Output);

                        // Update parameters
                        for (var i = 0; i < argNames.Count; ++i)
                        {
                            if (argNames[i] == "X" || argNames[i] == labelName)
                            {
                                continue;
                            }

                            ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i], null);
                        }
                    }

                    sw.Stop();

                    if (validation != null)
                    {
                        validation.BatchSize = batchSize;
                        validation.Reset();
                        while (validation.IterNext())
                        {
                            var dataBatch = validation.Next();
                            dataBatch.Data[0].CopyTo(args["X"]);
                            dataBatch.Label[0].CopyTo(args[labelName]);
                            NDArray.WaitAll();
                            // Forward pass is enough as no gradient is needed when evaluating
                            exec.Forward(false);
                            Metric.Update(args[labelName], exec.Output);
                        }
                    }

                    var duration = sw.ElapsedMilliseconds == 0 ? 1 : sw.ElapsedMilliseconds;
                    if (validation == null)
                    {
                        Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec Train_Metric: {TrainMetric.Get()}");
                    }
                    else
                    {
                        Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec, Train_Metric: {TrainMetric.Get()}, Val_Metric: {Metric.Get()}");
                    }
                }
            }

            //MXNet.MXNotifyShutdown();
        }
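
A hedged usage sketch (the surrounding class is assumed to expose Model, TrainMetric, Metric and ModelOptimizer; model, trainIter and valIter are placeholders):

// Hypothetical usage: a simple training run with periodic validation.
model.Fit(trainIter, epochs: 10, batchSize: 64, validation: valIter);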
Example #10
        public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
        {
            var    args      = new SortedDictionary <string, NDArray>();
            string labelName = "label";
            var    label     = Symbol.Variable(labelName);

            args["X"]       = new NDArray(new Shape(batchSize, (uint)InputShape[0]));
            args[labelName] = new NDArray(new Shape(batchSize, (uint)OutputShape.Size));

            CompiledModel.InferArgsMap(GlobalParam.Device, args, args);

            var initializer = new SiaDNN.Initializers.GlorotUniform();

            foreach (var arg in args)
            {
                initializer.Operator(arg.Key, arg.Value);
            }

            ModelOptimizer.SetParam("rescale_grad", 1.0 / batchSize);

            using (var exec = CompiledModel.SimpleBind(GlobalParam.Device, args))
            {
                var argNames = CompiledModel.ListArguments();

                // Start training
                var sw = new Stopwatch();
                for (var iter = 0; iter < epochs; ++iter)
                {
                    uint samples = 0;
                    train.BatchSize = batchSize;
                    train.Reset();

                    sw.Restart();

                    while (train.Next())
                    {
                        samples += batchSize;
                        var dataBatch = train.GetDataBatch();
                        // Set data and label
                        dataBatch.Data.CopyTo(args["X"]);
                        dataBatch.Label.CopyTo(args[labelName]);

                        // Compute gradients
                        exec.Forward(true);
                        exec.Backward();
                        // Update parameters
                        for (var i = 0; i < argNames.Count; ++i)
                        {
                            if (argNames[i] == "X" || argNames[i] == labelName)
                            {
                                continue;
                            }

                            ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i]);
                        }

                        Metric.Update(dataBatch.Label, exec.Outputs[0]);
                    }

                    sw.Stop();

                    if (validation != null)
                    {
                        validation.BatchSize = batchSize;
                        validation.Reset();
                        while (validation.Next())
                        {
                            var dataBatch = validation.GetDataBatch();
                            dataBatch.Data.CopyTo(args["X"]);
                            dataBatch.Label.CopyTo(args[labelName]);
                            // Forward pass is enough as no gradient is needed when evaluating
                            exec.Forward(false);
                            Metric.Update(dataBatch.Label, exec.Outputs[0]);
                        }
                    }


                    var duration = sw.ElapsedMilliseconds / 1000.0;
                    if (validation == null)
                    {
                        Logging.LG($"Epoch: {iter} {samples / duration} samples/sec Train_Metric: {Metric.Get()}");
                    }
                    else
                    {
                        Logging.LG($"Epoch: {iter} {samples / duration} samples/sec, Train_Metric: {Metric.Get()},  Val_Metric: {Metric.Get()}");
                    }
                }
            }

            MXNet.MXNotifyShutdown();
        }