Example #1
        public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
        {
            string labelName = "label";

            var label = Symbol.Variable(labelName);

            var inputShape = new List<uint>();

            inputShape.Add(batchSize);
            inputShape.AddRange(InputShape);

            args["X"]       = new NDArray(new Shape(inputShape.ToArray()));
            args[labelName] = new NDArray(new Shape(batchSize));

            // Infer shapes for all remaining model parameters and allocate them.
            Model.InferArgsMap(mx.Device, args, args);

            var defaultInitializer = new Initializers.GlorotUniform();

            // Use a per-parameter initializer where one is registered,
            // otherwise fall back to Glorot uniform.
            foreach (var arg in args)
            {
                if (ParamInitializers.ContainsKey(arg.Key))
                {
                    ParamInitializers[arg.Key].Generate(arg.Value);
                }
                else
                {
                    defaultInitializer.Generate(arg.Value);
                }
            }

            using (var exec = Model.SimpleBind(mx.Device, args))
            {
                var argNames = Model.ListArguments();

                // Start training
                var sw = new Stopwatch();
                for (var iter = 1; iter <= epochs; iter++)
                {
                    uint samples = 0;
                    train.BatchSize = batchSize;
                    train.Reset();
                    Metric.Reset();
                    TrainMetric.Reset();
                    sw.Restart();

                    while (train.IterNext())
                    {
                        samples += batchSize;
                        var dataBatch = train.Next();

                        // Set data and label
                        dataBatch.Data[0].CopyTo(args["X"]);
                        dataBatch.Label[0].CopyTo(args[labelName]);

                        // Compute gradients
                        exec.Forward(true);
                        exec.Backward();
                        TrainMetric.Update(args[labelName], exec.Output);

                        // Update parameters
                        for (var i = 0; i < argNames.Count; ++i)
                        {
                            if (argNames[i] == "X" || argNames[i] == labelName)
                            {
                                continue;
                            }

                            ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i], null);
                        }
                    }

                    sw.Stop();

                    if (validation != null)
                    {
                        validation.BatchSize = batchSize;
                        validation.Reset();
                        while (validation.IterNext())
                        {
                            var dataBatch = validation.Next();
                            dataBatch.Data[0].CopyTo(args["X"]);
                            dataBatch.Label[0].CopyTo(args[labelName]);
                            NDArray.WaitAll();
                            // Forward pass is enough as no gradient is needed when evaluating
                            exec.Forward(false);
                            Metric.Update(args[labelName], exec.Output);
                        }
                    }

                    var duration = sw.ElapsedMilliseconds == 0 ? 1 : sw.ElapsedMilliseconds;
                    if (validation == null)
                    {
                        Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec Train_Metric: {TrainMetric.Get()}");
                    }
                    else
                    {
                        Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec, Train_Metric: {TrainMetric.Get()}, Val_Metric: {Metric.Get()}");
                    }
                }
            }

            //MXNet.MXNotifyShutdown();
        }
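
For context, here is a minimal sketch of how a Fit method like the one above might be called. The MnistIter type, the file paths, and the model variable are hypothetical stand-ins for whatever concrete DataIter and model class the surrounding project provides:

        // Sketch only: MnistIter and these paths are illustrative assumptions,
        // not part of the example's API.
        var trainIter = new MnistIter("data/train-images", "data/train-labels");
        var valIter   = new MnistIter("data/t10k-images", "data/t10k-labels");

        // Ten epochs at 64 samples per gradient step; the validation iterator
        // is evaluated forward-only at the end of each epoch.
        model.Fit(trainIter, epochs: 10, batchSize: 64, validation: valIter);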
Example #2
        public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
        {
            var    args      = new SortedDictionary<string, NDArray>();
            string labelName = "label";
            var    label     = Symbol.Variable(labelName);

            args["X"]       = new NDArray(new Shape(batchSize, (uint)InputShape[0]));
            args[labelName] = new NDArray(new Shape(batchSize, (uint)OutputShape.Size));

            // Infer shapes for the remaining model parameters and allocate them.
            CompiledModel.InferArgsMap(GlobalParam.Device, args, args);

            var initializer = new SiaDNN.Initializers.GlorotUniform();

            // Glorot uniform initialization for every argument array.
            foreach (var arg in args)
            {
                initializer.Operator(arg.Key, arg.Value);
            }

            // Scale gradients by 1/batchSize so the effective learning rate
            // does not depend on the batch size.
            ModelOptimizer.SetParam("rescale_grad", 1.0 / batchSize);

            using (var exec = CompiledModel.SimpleBind(GlobalParam.Device, args))
            {
                var argNames = CompiledModel.ListArguments();

                // Start training
                var sw = new Stopwatch();
                for (var iter = 0; iter < epochs; ++iter)
                {
                    uint samples = 0;
                    train.BatchSize = batchSize;
                    train.Reset();
                    Metric.Reset();

                    sw.Restart();

                    while (train.Next())
                    {
                        samples += batchSize;
                        var dataBatch = train.GetDataBatch();
                        // Set data and label
                        dataBatch.Data.CopyTo(args["X"]);
                        dataBatch.Label.CopyTo(args[labelName]);

                        // Compute gradients
                        exec.Forward(true);
                        exec.Backward();
                        // Update parameters
                        for (var i = 0; i < argNames.Count; ++i)
                        {
                            if (argNames[i] == "X" || argNames[i] == labelName)
                            {
                                continue;
                            }

                            ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i]);
                        }

                        Metric.Update(dataBatch.Label, exec.Outputs[0]);
                    }

                    sw.Stop();

                    // Capture the training metric before the validation pass
                    // below reuses the same Metric instance.
                    var trainMetric = Metric.Get();

                    if (validation != null)
                    {
                        validation.BatchSize = batchSize;
                        validation.Reset();
                        Metric.Reset();
                        while (validation.Next())
                        {
                            var dataBatch = validation.GetDataBatch();
                            dataBatch.Data.CopyTo(args["X"]);
                            dataBatch.Label.CopyTo(args[labelName]);
                            // Forward pass is enough as no gradient is needed when evaluating
                            exec.Forward(false);
                            Metric.Update(dataBatch.Label, exec.Outputs[0]);
                        }
                    }

                    var duration = Math.Max(sw.ElapsedMilliseconds, 1) / 1000.0;
                    if (validation == null)
                    {
                        Logging.LG($"Epoch: {iter} {samples / duration} samples/sec, Train_Metric: {trainMetric}");
                    }
                    else
                    {
                        Logging.LG($"Epoch: {iter} {samples / duration} samples/sec, Train_Metric: {trainMetric}, Val_Metric: {Metric.Get()}");
                    }
                }
            }

            MXNet.MXNotifyShutdown();
        }
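
The second example differs from the first mainly in that it rescales gradients by 1/batchSize before the parameter update, so the effective step size does not depend on the batch size. Below is a minimal sketch of how such an optimizer might be configured before calling Fit, assuming this binding mirrors MXNet's cpp-package optimizer API; OptimizerRegistry.Find and the parameter keys are assumptions, not confirmed members of this library:

        // Sketch, assuming an OptimizerRegistry like MXNet's cpp-package;
        // the "sgd" factory name and parameter keys here are assumptions.
        var opt = OptimizerRegistry.Find("sgd");
        opt.SetParam("lr", 0.01);       // base learning rate
        opt.SetParam("momentum", 0.9);  // classical momentum
        opt.SetParam("wd", 1e-4);       // weight decay (L2 regularization)

        // Matches the rescale_grad call above: dividing each gradient by the
        // batch size keeps updates stable when batchSize changes.
        opt.SetParam("rescale_grad", 1.0 / 32);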