Example No. 1
        public Dictionary <string, float> Score(DataIter eval_data, EvalMetric eval_metric, int? num_batch = null,
                                                IBatchEndCallback[] batch_end_callback         = null,
                                                IScoreEndCallback[] score_end_callback         = null, bool reset = true, int epoch = 0,
                                                Func <DataBatch, NDArrayDict> sparse_row_id_fn = null)
        {
            if (!Binded || !ParamsInitialized)
            {
                throw new Exception("Module must be bound and parameters initialized");
            }

            if (reset)
            {
                eval_data.Reset();
            }

            eval_metric.Reset();
            var actual_num_batch = 0;

            while (!eval_data.End())
            {
                if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
                {
                    break;
                }

                var eval_batch = eval_data.Next();
                Prepare(eval_batch, sparse_row_id_fn);
                Forward(eval_batch, false);
                UpdateMetric(eval_metric, eval_batch.Label, true);
                if (batch_end_callback != null)
                {
                    foreach (var callback in batch_end_callback)
                    {
                        callback.Invoke(epoch, eval_data.Cursor, eval_metric);
                    }
                }

                actual_num_batch++;
            }

            if (score_end_callback != null)
            {
                foreach (var callback in score_end_callback)
                {
                    callback.Invoke(epoch, actual_num_batch, eval_metric, new FuncArgs());
                }
            }

            return(eval_metric.GetNameValue());
        }
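
A minimal usage sketch for Score (hedged: mod, valIter, and the Accuracy metric class are assumed names, not taken from the snippet above):

        // Evaluate a bound, initialized module on a validation iterator.
        var metric = new Accuracy();                // assumed EvalMetric implementation
        var results = mod.Score(valIter, metric);   // forward passes only, no gradients
        foreach (var kv in results)
            Console.WriteLine($"{kv.Key} = {kv.Value}");
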
Example No. 2
        public IEnumerable <(NDArrayList, int, DataBatch)> IterPredict(DataIter eval_data, int? num_batch = null,
                                                                       bool reset = true, int epoch = 0, Func <DataBatch, NDArrayDict> sparse_row_id_fn = null)
        {
            if (!Binded || !ParamsInitialized)
            {
                throw new Exception("Module must be bound and parameters initialized");
            }

            if (reset)
            {
                eval_data.Reset();
            }

            while (!eval_data.End())
            {
                if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
                {
                    break;
                }

                var eval_batch = eval_data.Next();
                Prepare(eval_batch, sparse_row_id_fn);
                Forward(eval_batch, false);
                var pad     = eval_batch.Pad.Value;
                var outputs = new NDArrayList();
                foreach (var list in GetOutputs())
                {
                    foreach (var @out in list)
                    {
                        outputs.Add(@out[$"0:{@out.Shape[0] - pad}"]);
                    }
                }

                yield return(outputs.ToArray(), eval_data.Cursor, eval_batch);
            }
        }
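
IterPredict streams results batch by batch instead of buffering them, which keeps memory flat on large datasets. A hedged usage sketch (mod and valIter are assumed names):

        // Consume predictions lazily; each tuple is (outputs, cursor, batch).
        foreach (var (outputs, cursor, _) in mod.IterPredict(valIter))
        {
            // outputs has the pad rows already sliced off
            Console.WriteLine($"Batch {cursor}: {outputs.Length} output arrays");
        }
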
Example No. 3
        internal static void TrainMultiDevice(Symbol symbol, Context[] ctx, string[] arg_names, string[] param_names,
                                              string[] aux_names, NDArrayDict arg_params, NDArrayDict aux_params,
                                              int begin_epoch, int end_epoch, int? epoch_size, Optimizer optimizer,
                                              KVStore kvstore, bool update_on_kvstore, DataIter train_data,
                                              DataIter eval_data = null, EvalMetric eval_metric = null,
                                              IEpochEndCallback epoch_end_callback = null,
                                              IBatchEndCallback batch_end_callback = null, int[] work_load_list = null,
                                              Monitor monitor = null, IEvalEndCallback eval_end_callback = null,
                                              IEvalBatchEndCallback eval_batch_end_callback = null,
                                              Func <int, Symbol> sym_gen = null)
        {
            var executor_manager = new DataParallelExecutorManager(symbol: symbol,
                                                                   ctx: ctx,
                                                                   train_data: train_data,
                                                                   arg_names: arg_names,
                                                                   param_names: param_names,
                                                                   aux_names: aux_names,
                                                                   work_load_list: work_load_list,
                                                                   sym_gen: sym_gen);

            if (monitor != null)
            {
                executor_manager.InstallMonitor(monitor);
            }

            executor_manager.SetParams(arg_params, aux_params);
            Updater updater = null;

            if (!update_on_kvstore)
            {
                updater = Optimizer.GetUpdater(optimizer);
            }
            else
            {
                kvstore.SetOptimizer(optimizer);
            }

            if (kvstore != null)
            {
                InitializeKVStore(kvstore: kvstore,
                                  param_arrays: new List <NDArrayList>()
                {
                    executor_manager.ParamArrays
                },
                                  arg_params: arg_params,
                                  param_names: executor_manager.param_names,
                                  update_on_kvstore: update_on_kvstore);
            }

            train_data.Reset();
            for (int epoch = begin_epoch; epoch < end_epoch; epoch++)
            {
                var tic = DateTime.Now;
                eval_metric.Reset();
                int nbatch = 0;
                while (true)
                {
                    bool do_reset = true;
                    while (!train_data.End())
                    {
                        var data_batch = train_data.Next();
                        executor_manager.LoadDataBatch(data_batch);
                        if (monitor != null)
                        {
                            monitor.Tic();
                        }

                        executor_manager.Forward(true);
                        executor_manager.Backward();
                        if (update_on_kvstore)
                        {
                            if (kvstore.Type.Contains("nccl"))
                            {
                                UpdateParamsOnKVStoreNCCL(new List <NDArrayList>()
                                {
                                    executor_manager.ParamArrays
                                }, new List <NDArrayList>()
                                {
                                    executor_manager.GradArrays
                                }, kvstore, executor_manager.param_names);
                            }
                            else
                            {
                                UpdateParamsOnKVStore(new List <NDArrayList>()
                                {
                                    executor_manager.ParamArrays
                                }, new List <NDArrayList>()
                                {
                                    executor_manager.GradArrays
                                }, kvstore, executor_manager.param_names);
                            }
                        }
                        else
                        {
                            UpdateParams(new List <NDArrayList>()
                            {
                                executor_manager.ParamArrays
                            }, new List <NDArrayList>()
                            {
                                executor_manager.GradArrays
                            }, updater, ctx.Length, kvstore, executor_manager.param_names);
                        }

                        if (monitor != null)
                        {
                            monitor.TocPrint();
                        }

                        executor_manager.UpdateMetric(eval_metric, data_batch.Label);
                        nbatch++;
                        if (batch_end_callback != null)
                        {
                            MultipleCallbacks(batch_end_callback, epoch, nbatch, eval_metric);
                        }

                        if (epoch_size.HasValue && nbatch >= epoch_size.Value)
                        {
                            do_reset = false;
                            break;
                        }
                    }

                    if (do_reset)
                    {
                        Logger.Info($"Epoch[{epoch}] Resetting Data Iterator");
                        train_data.Reset();
                    }

                    // Exit when the iterator finished a full pass (no epoch_size)
                    // or the requested number of batches has been processed.
                    if (!epoch_size.HasValue || nbatch >= epoch_size.Value)
                    {
                        break;
                    }
                }

                var toc = DateTime.Now;
                Logger.Info($"Epoch[{epoch}] Time cost={(toc - tic).TotalSeconds}");

                if (epoch_end_callback != null || epoch + 1 == end_epoch)
                {
                    executor_manager.CopyTo(arg_params, aux_params);
                }

                MultipleCallbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params);

                if (eval_data != null)
                {
                    eval_metric.Reset();
                    eval_data.Reset();
                    int total_num_batch = 0;
                    while (!eval_data.End())
                    {
                        var eval_batch = eval_data.Next();
                        executor_manager.LoadDataBatch(eval_batch);
                        executor_manager.Forward();
                        executor_manager.UpdateMetric(eval_metric, eval_batch.Label);
                        if (eval_batch_end_callback != null)
                        {
                            MultipleCallbacks(eval_batch_end_callback, epoch, total_num_batch, eval_metric);
                        }

                        total_num_batch++;
                    }

                    if (eval_end_callback != null)
                    {
                        MultipleCallbacks(eval_end_callback, epoch, eval_metric);
                    }
                }
            }
        }
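
The nested while (true) above lets one logical epoch span several passes over the iterator when epoch_size is set. A condensed, standalone sketch of that control flow (illustrative names only; trainData and epochSize are assumptions):

        int nbatch = 0;
        while (true)
        {
            bool doReset = true;
            while (!trainData.End())                      // consume the iterator
            {
                trainData.Next();
                nbatch++;
                if (epochSize.HasValue && nbatch >= epochSize.Value)
                {
                    doReset = false;                      // stop mid-pass, keep the cursor
                    break;
                }
            }

            if (doReset)
                trainData.Reset();                        // exhausted: rewind for the next pass

            // Leave when the iterator finished naturally (no epoch_size) or the quota is met.
            if (!epochSize.HasValue || nbatch >= epochSize.Value)
                break;
        }
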
Example No. 4
        public List <NDArrayList> Predict(DataIter eval_data, int? num_batch = null, bool merge_batches = true,
                                          bool reset = true, bool always_output_list = true, Func <DataBatch, NDArrayDict> sparse_row_id_fn = null)
        {
            if (!Binded || !ParamsInitialized)
            {
                throw new Exception("Module must be bound and parameters initialized");
            }

            if (reset)
            {
                eval_data.Reset();
            }

            var output_list  = new List <NDArrayList>();
            var output_list2 = new NDArrayList();

            while (!eval_data.End())
            {
                if (num_batch.HasValue && eval_data.Cursor == num_batch.Value)
                {
                    break;
                }

                var eval_batch = eval_data.Next();
                Prepare(eval_batch, sparse_row_id_fn);
                Forward(eval_batch, false);
                var pad     = eval_batch.Pad.Value;
                var outputs = new NDArrayList();
                foreach (var list in GetOutputs())
                {
                    foreach (var @out in list)
                    {
                        outputs.Add(@out[$"0:{@out.Shape[0] - pad}"].Copy());
                    }
                }

                output_list.Add(outputs.ToArray());
            }

            if (output_list.Count == 0)
            {
                return(output_list);
            }

            if (merge_batches)
            {
                var num_outputs = output_list[0].Length;
                foreach (var @out in output_list)
                {
                    if (@out.Length != num_outputs)
                    {
                        throw new Exception("Cannot merge batches, as num of outputs is not the same " +
                                            "in mini-batches. Maybe bucketing is used?");
                    }

                    output_list2.Add(nd.Concat(@out));
                }

                return(new List <NDArrayList> {
                    output_list2.ToArray()
                });
            }

            return(output_list);
        }
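
A hedged usage sketch for Predict (mod and valIter are assumed names): with merge_batches left at its default, the per-batch outputs are concatenated and a single NDArrayList comes back.

        var preds = mod.Predict(valIter);        // merge_batches defaults to true
        var merged = preds[0];                   // one NDArrayList after merging
        Console.WriteLine($"Merged outputs: {merged.Length}");
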
Example No. 5
        public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
        {
            string labelName = "label";

            var label = Symbol.Variable(labelName);

            List <uint> inputShape = new List <uint>();

            inputShape.Add(batchSize);
            inputShape.AddRange(InputShape);

            args["X"]       = new NDArray(new Shape(inputShape.ToArray()));
            args[labelName] = new NDArray(new Shape(batchSize));

            Model.InferArgsMap(mx.Device, args, args);

            var defaultInitializer = new Initializers.GlorotUniform();

            foreach (var arg in args)
            {
                if (ParamInitializers.ContainsKey(arg.Key))
                {
                    ParamInitializers[arg.Key].Generate(arg.Value);
                }
                else
                {
                    defaultInitializer.Generate(arg.Value);
                }
            }

            using (var exec = Model.SimpleBind(mx.Device, args))
            {
                var argNames = Model.ListArguments();

                // Start training
                var sw = new Stopwatch();
                for (var iter = 1; iter <= epochs; iter++)
                {
                    uint samples = 0;
                    train.BatchSize = batchSize;
                    train.Reset();
                    Metric.Reset();
                    TrainMetric.Reset();
                    sw.Restart();

                    while (train.IterNext())
                    {
                        samples += batchSize;
                        var dataBatch = train.Next();

                        // Set data and label
                        dataBatch.Data[0].CopyTo(args["X"]);
                        dataBatch.Label[0].CopyTo(args[labelName]);

                        // Compute gradients
                        exec.Forward(true);
                        exec.Backward();
                        TrainMetric.Update(args[labelName], exec.Output);

                        // Update parameters
                        for (var i = 0; i < argNames.Count; ++i)
                        {
                            if (argNames[i] == "X" || argNames[i] == labelName)
                            {
                                continue;
                            }

                            ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i], null);
                        }
                    }

                    sw.Stop();

                    if (validation != null)
                    {
                        validation.BatchSize = batchSize;
                        validation.Reset();
                        while (validation.IterNext())
                        {
                            var dataBatch = validation.Next();
                            dataBatch.Data[0].CopyTo(args["X"]);
                            dataBatch.Label[0].CopyTo(args[labelName]);
                            NDArray.WaitAll();
                            // Forward pass is enough as no gradient is needed when evaluating
                            exec.Forward(false);
                            Metric.Update(args[labelName], exec.Output);
                        }
                    }

                    var duration = sw.ElapsedMilliseconds == 0 ? 1 : sw.ElapsedMilliseconds;
                    if (validation == null)
                    {
                        Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec Train_Metric: {TrainMetric.Get()}");
                    }
                    else
                    {
                        Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec, Train_Metric: {TrainMetric.Get()}, Val_Metric: {Metric.Get()}");
                    }
                }
            }

            //MXNet.MXNotifyShutdown();
        }
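
A hedged usage sketch (model, trainIter, and valIter are assumed instances of the surrounding class and of DataIter):

        // Train for 10 epochs, validating after each epoch.
        model.Fit(trainIter, epochs: 10, batchSize: 32, validation: valIter);
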
Example No. 6
        public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
        {
            var    args      = new SortedDictionary <string, NDArray>();
            string labelName = "label";
            var    label     = Symbol.Variable(labelName);

            args["X"]       = new NDArray(new Shape(batchSize, (uint)InputShape[0]));
            args[labelName] = new NDArray(new Shape(batchSize, (uint)OutputShape.Size));

            CompiledModel.InferArgsMap(GlobalParam.Device, args, args);

            var initializer = new SiaDNN.Initializers.GlorotUniform();

            foreach (var arg in args)
            {
                initializer.Operator(arg.Key, arg.Value);
            }

            ModelOptimizer.SetParam("rescale_grad", 1.0 / batchSize);

            using (var exec = CompiledModel.SimpleBind(GlobalParam.Device, args))
            {
                var argNames = CompiledModel.ListArguments();

                // Start training
                var sw = new Stopwatch();
                for (var iter = 0; iter < epochs; ++iter)
                {
                    uint samples = 0;
                    train.BatchSize = batchSize;
                    train.Reset();

                    sw.Restart();

                    while (train.Next())
                    {
                        samples += batchSize;
                        var dataBatch = train.GetDataBatch();
                        // Set data and label
                        dataBatch.Data.CopyTo(args["X"]);
                        dataBatch.Label.CopyTo(args[labelName]);

                        // Compute gradients
                        exec.Forward(true);
                        exec.Backward();
                        // Update parameters
                        for (var i = 0; i < argNames.Count; ++i)
                        {
                            if (argNames[i] == "X" || argNames[i] == labelName)
                            {
                                continue;
                            }

                            ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i]);
                        }

                        Metric.Update(dataBatch.Label, exec.Outputs[0]);
                    }

                    sw.Stop();

                    if (validation != null)
                    {
                        validation.BatchSize = batchSize;
                        validation.Reset();
                        while (validation.Next())
                        {
                            var dataBatch = validation.GetDataBatch();
                            dataBatch.Data.CopyTo(args["X"]);
                            dataBatch.Label.CopyTo(args[labelName]);
                            // Forward pass is enough as no gradient is needed when evaluating
                            exec.Forward(false);
                            Metric.Update(dataBatch.Label, exec.Outputs[0]);
                        }
                    }


                    var duration = sw.ElapsedMilliseconds / 1000.0;
                    if (validation == null)
                    {
                        Logging.LG($"Epoch: {iter} {samples / duration} samples/sec Train_Metric: {Metric.Get()}");
                    }
                    else
                    {
                        Logging.LG($"Epoch: {iter} {samples / duration} samples/sec, Train_Metric: {Metric.Get()},  Val_Metric: {Metric.Get()}");
                    }
                }
            }

            MXNet.MXNotifyShutdown();
        }
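
One detail worth noting above is SetParam("rescale_grad", 1.0 / batchSize): mini-batch gradients are summed across samples, so rescaling by 1/batchSize turns the update into a per-sample average and keeps the learning rate comparable across batch sizes. A tiny numeric illustration (plain C#, no library calls):

        double lr = 0.01, batchSize = 64;
        double gradSum = 128.0;                       // gradient summed over the mini-batch
        double step = lr * (gradSum / batchSize);     // rescale_grad = 1/batchSize averages it
        Console.WriteLine(step);                      // prints 0.02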