public void Mnist([DefaultValue(true)] bool gui)
{
    const int batchSize = 128;
    const int hSize = 20;

    MklProvider.TryUseMkl(true, ConsoleProgressWriter.Instance);

    string dataDir = Path.Combine(Path.GetTempPath(), "Retia_datasets", "MNIST");
    DownloadDataset(dataDir);

    Console.WriteLine("Loading training set");
    var trainSet = LoadTrainingSet(dataDir);
    trainSet.BatchSize = batchSize;

    // Sigmoid hidden layer -> linear projection -> softmax over the target classes.
    var network = new LayeredNet<float>(batchSize, 1,
        new AffineLayer<float>(trainSet.InputSize, hSize, AffineActivation.Sigmoid),
        new LinearLayer<float>(hSize, trainSet.TargetSize),
        new SoftMaxLayer<float>(trainSet.TargetSize));

    var optimizer = new AdamOptimizer<float>();
    network.Optimizer = optimizer;

    var trainer = new OptimizingTrainer<float>(network, optimizer, trainSet,
        new OptimizingTrainerOptions(1)
        {
            ErrorFilterSize = 100,
            MaxEpoch = 1,
            ProgressWriter = ConsoleProgressWriter.Instance,
            ReportProgress = new EachIteration(100),
            ReportMesages = true
        },
        new OptimizingSession("MNIST"));

    RetiaGui retiaGui;
    if (gui)
    {
        retiaGui = new RetiaGui();
        retiaGui.RunAsync(() => new TrainingWindow(new TypedTrainingModel<float>(trainer)));
    }

    var runner = ConsoleRunner.Create(trainer, network);
    runner.Run();
}
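// Illustrative sketch (not part of Retia): the per-sample forward pass that the MNIST
// network above computes, written with plain arrays so the layer stack is explicit.
// The weight/bias parameter names are hypothetical; Retia manages these internally.
private static float[] MnistForwardSketch(float[] input, float[,] w1, float[] b1, float[,] w2, float[] b2)
{
    // Hidden layer: affine transform followed by element-wise sigmoid.
    int hSize = b1.Length;
    var hidden = new float[hSize];
    for (int i = 0; i < hSize; i++)
    {
        float sum = b1[i];
        for (int j = 0; j < input.Length; j++)
            sum += w1[i, j] * input[j];
        hidden[i] = 1.0f / (1.0f + (float)Math.Exp(-sum));
    }

    // Output layer: linear transform producing one score per class.
    int tSize = b2.Length;
    var logits = new float[tSize];
    float max = float.MinValue;
    for (int i = 0; i < tSize; i++)
    {
        float sum = b2[i];
        for (int j = 0; j < hSize; j++)
            sum += w2[i, j] * hidden[j];
        logits[i] = sum;
        if (sum > max) max = sum;
    }

    // Softmax: exponentiate (shifted by the max for numerical stability) and normalize.
    var probs = new float[tSize];
    float norm = 0f;
    for (int i = 0; i < tSize; i++)
    {
        probs[i] = (float)Math.Exp(logits[i] - max);
        norm += probs[i];
    }
    for (int i = 0; i < tSize; i++)
        probs[i] /= norm;

    return probs;
}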
public void Xor()
{
    MklProvider.TryUseMkl(true, ConsoleProgressWriter.Instance);

    var optimizer = new RMSPropOptimizer<float>(1e-3f);

    // Two tanh affine layers; the output layer uses mean square error as its criterion.
    var net = new LayeredNet<float>(1, 1,
        new AffineLayer<float>(2, 3, AffineActivation.Tanh),
        new AffineLayer<float>(3, 1, AffineActivation.Tanh) { ErrorFunction = new MeanSquareError<float>() })
    {
        Optimizer = optimizer
    };

    var trainer = new OptimizingTrainer<float>(net, optimizer, new XorDataset(true),
        new OptimizingTrainerOptions(1)
        {
            ErrorFilterSize = 0,
            ReportProgress = new EachIteration(1),
            ReportMesages = true,
            ProgressWriter = ConsoleProgressWriter.Instance,
            LearningRateScaler = new ProportionalLearningRateScaler(new EachIteration(1), 9e-5f)
        },
        new OptimizingSession("XOR"));

    var runner = ConsoleRunner.Create(trainer, net);

    // Stop as soon as the raw training error is effectively zero.
    trainer.TrainReport += (sender, args) =>
    {
        if (args.Errors.Last().RawError < 1e-7f)
        {
            runner.Stop();
            Console.WriteLine("Finished training.");
        }
    };

    var gui = new RetiaGui();
    gui.RunAsync(() => new TrainingWindow(new TypedTrainingModel<float>(trainer)));

    runner.Run();
}
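// Illustrative sketch (not part of Retia): the truth table the XOR example above trains on
// and the mean square error criterion configured on its output layer. The "predicted"
// values are hypothetical network outputs, used only to show the error computation.
private static void XorErrorSketch()
{
    float[][] inputs    = { new[] { 0f, 0f }, new[] { 0f, 1f }, new[] { 1f, 0f }, new[] { 1f, 1f } };
    float[]   targets   = { 0f, 1f, 1f, 0f };
    float[]   predicted = { 0.02f, 0.97f, 0.96f, 0.03f };

    float mse = 0f;
    for (int i = 0; i < targets.Length; i++)
    {
        Console.WriteLine($"({inputs[i][0]}, {inputs[i][1]}) -> target {targets[i]}, predicted {predicted[i]}");
        float diff = predicted[i] - targets[i];
        mse += diff * diff;
    }
    mse /= targets.Length;

    // The trainer above stops once the raw error it reports drops below 1e-7.
    Console.WriteLine($"Mean square error: {mse}");
}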
public void Learn(
    [Aliases("b"), Required] string batchesPath,
    [Aliases("c")] string configPath,
    [Aliases("r"), DefaultValue(0.0002f)] float learningRate,
    [DefaultValue(false)] bool gpu,
    [DefaultValue(true)] bool gui)
{
    MklProvider.TryUseMkl(true, ConsoleProgressWriter.Instance);

    Console.WriteLine($"Loading test set from {batchesPath}");
    _dataProvider.Load(batchesPath);

    var optimizer = new RMSPropOptimizer<float>(learningRate, 0.95f, 0.0f, 0.9f);

    // Either build a fresh network from the training set dimensions or load one from a config file.
    LayeredNet<float> network;
    if (string.IsNullOrEmpty(configPath))
    {
        network = CreateNetwork(_dataProvider.TrainingSet.InputSize, 128, _dataProvider.TrainingSet.TargetSize);
        network.Optimizer = optimizer;
    }
    else
    {
        network = CreateNetwork(configPath);
        network.Optimizer = optimizer;
    }

    network.ResetOptimizer();

    if (gpu)
    {
        Console.WriteLine("Brace yourself for GPU!");
        network.UseGpu();
    }

    var trainOptions = new OptimizingTrainerOptions(SEQ_LEN)
    {
        ErrorFilterSize = 100,
        ReportMesages = true,
        MaxEpoch = 1000,
        ProgressWriter = ConsoleProgressWriter.Instance,
        ReportProgress = new EachIteration(10)
    };
    trainOptions.LearningRateScaler = new ProportionalLearningRateScaler(new ActionSchedule(1, PeriodType.Iteration), 9.9e-5f);

    var session = new OptimizingSession(Path.GetFileNameWithoutExtension(batchesPath));
    var trainer = new OptimizingTrainer<float>(network, optimizer, _dataProvider.TrainingSet, trainOptions, session);

    RetiaGui retiaGui;
    TypedTrainingModel<float> model = null;
    if (gui)
    {
        retiaGui = new RetiaGui();
        retiaGui.RunAsync(() =>
        {
            model = new TypedTrainingModel<float>(trainer);
            return new TrainingWindow(model);
        });
    }

    var epochWatch = new Stopwatch();
    trainer.EpochReached += sess =>
    {
        epochWatch.Stop();
        Console.WriteLine($"Trained epoch in {epochWatch.Elapsed.TotalSeconds} s.");

        // Showcasing plot export
        if (model != null)
        {
            using (var stream = new MemoryStream())
            {
                model.ExportErrorPlot(stream, 600, 400);
                stream.Seek(0, SeekOrigin.Begin);
                session.AddFileToReport("ErrorPlots\\plot.png", stream);
            }
        }

        epochWatch.Restart();
    };

    // Every 100 iterations, sample text from the network and attach it to the session report.
    trainer.PeriodicActions.Add(new UserAction(new ActionSchedule(100, PeriodType.Iteration), () =>
    {
        if (gpu)
        {
            network.TransferStateToHost();
        }

        string text = TestRNN(network.Clone(1, SEQ_LEN), 500, _dataProvider.Vocab);
        Console.WriteLine(text);
        session.AddFileToReport("Generated\\text.txt", text);

        trainOptions.ProgressWriter.ItemComplete();
    }));

    var runner = ConsoleRunner.Create(trainer, network);

    epochWatch.Start();
    runner.Run();
}
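// Illustrative sketch (not Retia's implementation): a textbook RMSProp update with momentum,
// the general kind of optimizer the Learn example configures above. The parameter names
// (decayRate, momentum, epsilon) and their mapping to the RMSPropOptimizer constructor
// arguments are assumptions; Retia applies its update internally per layer.
private static void RmsPropStepSketch(
    float[] weights, float[] gradients, float[] cache, float[] velocity,
    float learningRate = 0.0002f, float decayRate = 0.95f, float momentum = 0.9f, float epsilon = 1e-8f)
{
    for (int i = 0; i < weights.Length; i++)
    {
        // Exponentially decayed average of squared gradients.
        cache[i] = decayRate * cache[i] + (1f - decayRate) * gradients[i] * gradients[i];

        // Momentum on the scaled gradient, then apply the step to the weight.
        velocity[i] = momentum * velocity[i]
                    - learningRate * gradients[i] / ((float)Math.Sqrt(cache[i]) + epsilon);
        weights[i] += velocity[i];
    }
}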