Пример #1
0
        /// <summary>
        /// MNIST example: builds a sigmoid-affine / linear / soft-max network, attaches an
        /// Adam optimizer and hands the resulting trainer to a console runner.
        /// </summary>
        /// <param name="gui">
        /// When <c>true</c>, a <c>TrainingWindow</c> bound to the trainer is started through
        /// <c>RetiaGui.RunAsync</c> before the console runner is launched.
        /// </param>
        public void Mnist([DefaultValue(true)] bool gui)
        {
            const int batchSize = 128;
            const int hSize     = 20;

            MklProvider.TryUseMkl(true, ConsoleProgressWriter.Instance);

            // Dataset directory: <temp>/Retia_datasets/MNIST.
            string datasetDir = Path.Combine(Path.GetTempPath(), "Retia_datasets", "MNIST");
            DownloadDataset(datasetDir);

            Console.WriteLine("Loading training set");
            var trainingData = LoadTrainingSet(datasetDir);
            trainingData.BatchSize = batchSize;

            // Layers are created in the same order in which they are stacked in the network.
            var hiddenLayer  = new AffineLayer<float>(trainingData.InputSize, hSize, AffineActivation.Sigmoid);
            var outputLayer  = new LinearLayer<float>(hSize, trainingData.TargetSize);
            var softMaxLayer = new SoftMaxLayer<float>(trainingData.TargetSize);

            var net  = new LayeredNet<float>(batchSize, 1, hiddenLayer, outputLayer, softMaxLayer);
            var adam = new AdamOptimizer<float>();
            net.Optimizer = adam;

            // NOTE(review): the constructor argument (1) is presumably the sequence length,
            // as the Learn example passes SEQ_LEN in the same position - confirm.
            var options = new OptimizingTrainerOptions(1)
            {
                ErrorFilterSize = 100,
                MaxEpoch        = 1,
                ProgressWriter  = ConsoleProgressWriter.Instance,
                ReportProgress  = new EachIteration(100),
                ReportMesages   = true
            };
            var session = new OptimizingSession("MNIST");
            var trainer = new OptimizingTrainer<float>(net, adam, trainingData, options, session);

            if (gui)
            {
                // The GUI is started asynchronously; the window factory wraps the trainer in a typed model.
                new RetiaGui().RunAsync(() => new TrainingWindow(new TypedTrainingModel<float>(trainer)));
            }

            ConsoleRunner.Create(trainer, net).Run();
        }
Пример #2
0
        /// <summary>
        /// XOR example: trains a 2-3-1 tanh network with a mean-square-error function on
        /// <c>XorDataset</c> using RMSProp, showing progress in a training window, and stops
        /// the runner once the most recent reported raw error falls below 1e-7.
        /// </summary>
        public void Xor()
        {
            MklProvider.TryUseMkl(true, ConsoleProgressWriter.Instance);

            var rmsProp = new RMSPropOptimizer<float>(1e-3f);

            // Hidden layer first, then the output layer that carries the error function.
            var hiddenLayer = new AffineLayer<float>(2, 3, AffineActivation.Tanh);
            var outputLayer = new AffineLayer<float>(3, 1, AffineActivation.Tanh)
            {
                ErrorFunction = new MeanSquareError<float>()
            };

            var network = new LayeredNet<float>(1, 1, hiddenLayer, outputLayer);
            network.Optimizer = rmsProp;

            var dataset = new XorDataset(true);
            var options = new OptimizingTrainerOptions(1)
            {
                ErrorFilterSize    = 0,
                ReportProgress     = new EachIteration(1),
                ReportMesages      = true,
                ProgressWriter     = ConsoleProgressWriter.Instance,
                LearningRateScaler = new ProportionalLearningRateScaler(new EachIteration(1), 9e-5f)
            };
            var session = new OptimizingSession("XOR");
            var trainer = new OptimizingTrainer<float>(network, rmsProp, dataset, options, session);

            var runner = ConsoleRunner.Create(trainer, network);

            // Stop as soon as the latest reported error is small enough.
            trainer.TrainReport += (sender, args) =>
            {
                var latest = args.Errors.Last();
                if (latest.RawError < 1e-7f)
                {
                    runner.Stop();
                    Console.WriteLine("Finished training.");
                }
            };

            // The training window is always shown in this example.
            new RetiaGui().RunAsync(() => new TrainingWindow(new TypedTrainingModel<float>(trainer)));

            runner.Run();
        }
Пример #3
0
        /// <summary>
        /// Trains a network on the data loaded from <paramref name="batchesPath"/> with an
        /// RMSProp optimizer. On every epoch it prints the elapsed time and, when the GUI model
        /// exists, adds an exported error plot to the session report. Every 100 iterations it
        /// prints the text returned by <c>TestRNN</c> and adds it to the session report.
        /// </summary>
        /// <param name="batchesPath">Path passed to the data provider's <c>Load</c>; its file name (without extension) names the session.</param>
        /// <param name="configPath">Optional path passed to <c>CreateNetwork</c>; when null or empty a network with the built-in layout is created instead.</param>
        /// <param name="learningRate">Learning rate handed to the RMSProp optimizer.</param>
        /// <param name="gpu">When <c>true</c>, <c>UseGpu</c> is called on the network and its state is transferred to the host before each periodic text sample.</param>
        /// <param name="gui">When <c>true</c>, a training window is started through <c>RetiaGui.RunAsync</c>.</param>
        public void Learn(
            [Aliases("b"), Required] string batchesPath,
            [Aliases("c")] string configPath,
            [Aliases("r"), DefaultValue(0.0002f)] float learningRate,
            [DefaultValue(false)] bool gpu,
            [DefaultValue(true)] bool gui)
        {
            MklProvider.TryUseMkl(true, ConsoleProgressWriter.Instance);

            Console.WriteLine($"Loading test set from {batchesPath}");
            _dataProvider.Load(batchesPath);

            var rmsProp = new RMSPropOptimizer<float>(learningRate, 0.95f, 0.0f, 0.9f);

            // Either restore the network from a configuration file or build the default one.
            LayeredNet<float> network;
            if (!string.IsNullOrEmpty(configPath))
            {
                network = CreateNetwork(configPath);
            }
            else
            {
                network = CreateNetwork(_dataProvider.TrainingSet.InputSize, 128, _dataProvider.TrainingSet.TargetSize);
            }

            network.Optimizer = rmsProp;
            network.ResetOptimizer();

            if (gpu)
            {
                Console.WriteLine("Brace yourself for GPU!");
                network.UseGpu();
            }

            // The learning rate scaler is assigned last, after the other option values.
            var trainOptions = new OptimizingTrainerOptions(SEQ_LEN)
            {
                ErrorFilterSize    = 100,
                ReportMesages      = true,
                MaxEpoch           = 1000,
                ProgressWriter     = ConsoleProgressWriter.Instance,
                ReportProgress     = new EachIteration(10),
                LearningRateScaler = new ProportionalLearningRateScaler(new ActionSchedule(1, PeriodType.Iteration), 9.9e-5f)
            };

            var session = new OptimizingSession(Path.GetFileNameWithoutExtension(batchesPath));
            var trainer = new OptimizingTrainer<float>(network, rmsProp, _dataProvider.TrainingSet, trainOptions, session);

            // Assigned inside the window factory; stays null when the GUI is disabled
            // or until the factory has been invoked.
            TypedTrainingModel<float> model = null;

            if (gui)
            {
                new RetiaGui().RunAsync(() =>
                {
                    model = new TypedTrainingModel<float>(trainer);
                    return new TrainingWindow(model);
                });
            }

            var epochTimer = new Stopwatch();

            trainer.EpochReached += reachedSession =>
            {
                epochTimer.Stop();
                Console.WriteLine($"Trained epoch in {epochTimer.Elapsed.TotalSeconds} s.");

                // Showcasing plot export
                if (model != null)
                {
                    using (var plotStream = new MemoryStream())
                    {
                        model.ExportErrorPlot(plotStream, 600, 400);

                        // Rewind so the report reads the plot from the beginning.
                        plotStream.Seek(0, SeekOrigin.Begin);

                        session.AddFileToReport("ErrorPlots\\plot.png", plotStream);
                    }
                }

                epochTimer.Restart();
            };

            // Periodic action: runs every 100 iterations.
            var sampleAction = new UserAction(new ActionSchedule(100, PeriodType.Iteration), () =>
            {
                if (gpu)
                {
                    network.TransferStateToHost();
                }

                // NOTE(review): 500 is presumably the length of the generated sample - confirm against TestRNN.
                string text = TestRNN(network.Clone(1, SEQ_LEN), 500, _dataProvider.Vocab);
                Console.WriteLine(text);
                session.AddFileToReport("Generated\\text.txt", text);

                trainOptions.ProgressWriter.ItemComplete();
            });
            trainer.PeriodicActions.Add(sampleAction);

            var runner = ConsoleRunner.Create(trainer, network);

            epochTimer.Start();
            runner.Run();
        }