Code example #1
        // Records the per-epoch averages of the batch-level losses and metrics.
        public void Add(List<float> trainLoss, List<float> trainMetric, List<float> valLoss, List<float> valMetric)
        {
            TrainLoss.Add(trainLoss.Average());
            TrainMetric.Add(trainMetric.Average());

            // The validation lists are empty when no validation set was run,
            // so only append an average when there is something to average.
            if (valLoss.Count > 0)
            {
                ValLoss.Add(valLoss.Average());
            }

            if (valMetric.Count > 0)
            {
                ValMetric.Add(valMetric.Average());
            }
        }
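
For context, the Add method above reads like part of a training-history container: it appends the epoch's average batch loss and metric, and records validation averages only when validation batches were actually run. Below is a minimal, self-contained sketch of such a container; the class name TrainingHistory and its auto-properties are assumptions for illustration, not part of the original source.

using System.Collections.Generic;
using System.Linq;

public class TrainingHistory
{
    // Per-epoch averages; Add() appends one entry per epoch.
    public List<float> TrainLoss   { get; } = new List<float>();
    public List<float> TrainMetric { get; } = new List<float>();
    public List<float> ValLoss     { get; } = new List<float>();
    public List<float> ValMetric   { get; } = new List<float>();

    public void Add(List<float> trainLoss, List<float> trainMetric, List<float> valLoss, List<float> valMetric)
    {
        TrainLoss.Add(trainLoss.Average());
        TrainMetric.Add(trainMetric.Average());

        // Skip validation entries for epochs without a validation pass.
        if (valLoss.Count > 0) ValLoss.Add(valLoss.Average());
        if (valMetric.Count > 0) ValMetric.Add(valMetric.Average());
    }
}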
Code example #2
File: Training.cs  Project: sportbilly21/MxNet.Sharp
        public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
        {
            string labelName = "label";

            var label = Symbol.Variable(labelName);

            List<uint> inputShape = new List<uint>();

            inputShape.Add(batchSize);
            inputShape.AddRange(InputShape);

            args["X"]       = new NDArray(new Shape(inputShape.ToArray()));
            args[labelName] = new NDArray(new Shape(batchSize));

            Model.InferArgsMap(mx.Device, args, args);

            var defaultInitializer = new Initializers.GlorotUniform();

            foreach (var arg in args)
            {
                if (ParamInitializers.ContainsKey(arg.Key))
                {
                    ParamInitializers[arg.Key].Generate(arg.Value);
                }
                else
                {
                    defaultInitializer.Generate(arg.Value);
                }
            }

            using (var exec = Model.SimpleBind(mx.Device, args))
            {
                var argNames = Model.ListArguments();

                // Start training
                var sw = new Stopwatch();
                for (var iter = 1; iter <= epochs; iter++)
                {
                    uint samples = 0;
                    train.BatchSize = batchSize;
                    train.Reset();
                    Metric.Reset();
                    TrainMetric.Reset();
                    sw.Restart();

                    while (train.IterNext())
                    {
                        samples += batchSize;
                        var dataBatch = train.Next();

                        // Set data and label
                        dataBatch.Data[0].CopyTo(args["X"]);
                        dataBatch.Label[0].CopyTo(args[labelName]);

                        // Compute gradients
                        exec.Forward(true);
                        exec.Backward();
                        TrainMetric.Update(args[labelName], exec.Output);

                        // Update parameters
                        for (var i = 0; i < argNames.Count; ++i)
                        {
                            if (argNames[i] == "X" || argNames[i] == labelName)
                            {
                                continue;
                            }

                            ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i], null);
                        }
                    }

                    sw.Stop();

                    if (validation != null)
                    {
                        validation.BatchSize = batchSize;
                        validation.Reset();
                        while (validation.IterNext())
                        {
                            var dataBatch = validation.Next();
                            dataBatch.Data[0].CopyTo(args["X"]);
                            dataBatch.Label[0].CopyTo(args[labelName]);
                            NDArray.WaitAll();
                            // A forward pass is enough; no gradients are needed during evaluation
                            exec.Forward(false);
                            Metric.Update(args[labelName], exec.Output);
                        }
                    }

                    var duration = sw.ElapsedMilliseconds == 0 ? 1 : sw.ElapsedMilliseconds;
                    if (validation == null)
                    {
                        Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec Train_Metric: {TrainMetric.Get()}");
                    }
                    else
                    {
                        Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec, Train_Metric: {TrainMetric.Get()}, Val_Metric: {Metric.Get()}");
                    }
                }
            }

            //MXNet.MXNotifyShutdown();
        }
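
A hedged usage sketch of the Fit method follows. The enclosing trainer type is assumed to be named Training (the page only gives the file name Training.cs), and the caller is expected to supply already constructed MxNet.Sharp DataIter instances; these names are assumptions, not details taken from the original project.

        // Hypothetical caller; the Training type name and the iterator arguments are illustrative.
        public static void RunTraining(Training trainer, DataIter trainIter, DataIter valIter)
        {
            // Runs the loop above for 10 epochs with batches of 64; when valIter is null,
            // only the training metric is logged at the end of each epoch.
            trainer.Fit(trainIter, epochs: 10, batchSize: 64, validation: valIter, shuffle: false);
        }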