/// <summary>
/// Applies one step of learning-rate decay to the optimizer attached to
/// <paramref name="session"/>.
/// </summary>
/// <remarks>
/// The new rate is <c>_initialRate / (1 + _scalingTicks * ScalingFactor)</c>,
/// where <c>_scalingTicks</c> is incremented on every invocation.
/// </remarks>
/// <param name="session">Session whose optimizer's learning rate is read and updated.</param>
protected override void DoAction(OptimizingSession session)
{
    // While _initialRate is still NaN, take the optimizer's current rate as the baseline.
    // NOTE(review): NaN appears to serve as the "not captured yet" sentinel - confirm at the field declaration.
    if (float.IsNaN(_initialRate))
    {
        _initialRate = session.Optimizer.LearningRate;
    }

    _scalingTicks++;

    // The denominator grows linearly with the number of invocations so far.
    var denominator = 1.0f + _scalingTicks * ScalingFactor;
    session.Optimizer.LearningRate = _initialRate / denominator;
}
/// <summary>
/// Creates a trainer for the given network, optimizer and training set.
/// </summary>
/// <param name="network">Network stored by the trainer and assigned to <paramref name="session"/>.</param>
/// <param name="optimizer">Optimizer stored by the trainer and assigned to <paramref name="session"/>.</param>
/// <param name="trainingSet">Data set exposed through <c>TrainingSet</c>.</param>
/// <param name="options">Trainer options, forwarded to the base constructor.</param>
/// <param name="session">Session forwarded to the base constructor; its <c>Optimizer</c> and <c>Network</c> are overwritten here.</param>
public OptimizingTrainer(
    NeuralNet<T> network,
    OptimizerBase<T> optimizer,
    IDataSet<T> trainingSet,
    OptimizingTrainerOptions options,
    OptimizingSession session)
    : base(options, session)
{
    _network = network;
    _optimizer = optimizer;
    TrainingSet = trainingSet;

    // TODO: This is not very good - the constructor mutates the session supplied by the caller.
    session.Optimizer = optimizer;
    session.Network = network;
}
/// <summary>
/// Loads data from <paramref name="batchesPath"/>, builds a network with an RMSProp optimizer,
/// wires up a trainer (optionally with a GUI window) and runs it through a console runner.
/// </summary>
/// <remarks>
/// Every 100 iterations a clone of the network is passed to <c>TestRNN</c>; the returned text is
/// written to the console and added to the session report. After each epoch the elapsed time is
/// printed and, when a GUI model exists, an error plot is added to the session report.
/// </remarks>
/// <param name="batchesPath">Path handed to the data provider's <c>Load</c>; its file name without extension is passed to the session constructor.</param>
/// <param name="configPath">When non-empty, the network is created from this path; otherwise it is created from the training set's input/target sizes.</param>
/// <param name="learningRate">First argument of the <c>RMSPropOptimizer</c> constructor.</param>
/// <param name="gpu">When true, <c>UseGpu</c> is called on the network and its state is transferred to host before each text sample.</param>
/// <param name="gui">When true, a <c>RetiaGui</c> is started showing a <c>TrainingWindow</c>.</param>
public void Learn(
    [Aliases("b"), Required] string batchesPath,
    [Aliases("c")] string configPath,
    [Aliases("r"), DefaultValue(0.0002f)] float learningRate,
    [DefaultValue(false)] bool gpu,
    [DefaultValue(true)] bool gui)
{
    MklProvider.TryUseMkl(true, ConsoleProgressWriter.Instance);

    Console.WriteLine($"Loading test set from {batchesPath}");
    _dataProvider.Load(batchesPath);

    var optimizer = new RMSPropOptimizer<float>(learningRate, 0.95f, 0.0f, 0.9f);

    // Build the network either from a config file or from the data set dimensions.
    LayeredNet<float> network;
    if (!string.IsNullOrEmpty(configPath))
    {
        network = CreateNetwork(configPath);
    }
    else
    {
        // NOTE(review): 128 is presumably the hidden layer size - confirm against CreateNetwork.
        network = CreateNetwork(_dataProvider.TrainingSet.InputSize, 128, _dataProvider.TrainingSet.TargetSize);
    }

    // Both construction paths attach the same optimizer before resetting it.
    network.Optimizer = optimizer;
    network.ResetOptimizer();

    if (gpu)
    {
        Console.WriteLine("Brace yourself for GPU!");
        network.UseGpu();
    }

    var trainOptions = new OptimizingTrainerOptions(SEQ_LEN)
    {
        ErrorFilterSize = 100,
        ReportMesages = true,
        MaxEpoch = 1000,
        ProgressWriter = ConsoleProgressWriter.Instance,
        ReportProgress = new EachIteration(10),
        // Scaler is scheduled on every iteration.
        LearningRateScaler = new ProportionalLearningRateScaler(new ActionSchedule(1, PeriodType.Iteration), 9.9e-5f)
    };

    string sessionName = Path.GetFileNameWithoutExtension(batchesPath);
    var session = new OptimizingSession(sessionName);
    var trainer = new OptimizingTrainer<float>(network, optimizer, _dataProvider.TrainingSet, trainOptions, session);

    // Stays null unless the GUI is enabled; assigned from inside the GUI callback.
    TypedTrainingModel<float> model = null;
    if (gui)
    {
        var guiHost = new RetiaGui();
        guiHost.RunAsync(() =>
        {
            model = new TypedTrainingModel<float>(trainer);
            return new TrainingWindow(model);
        });
    }

    var epochWatch = new Stopwatch();
    trainer.EpochReached += sess =>
    {
        epochWatch.Stop();
        Console.WriteLine($"Trained epoch in {epochWatch.Elapsed.TotalSeconds} s.");

        // Showcasing plot export
        if (model != null)
        {
            using (var plotStream = new MemoryStream())
            {
                model.ExportErrorPlot(plotStream, 600, 400);
                // Rewind so the report reads the plot from the start of the stream.
                plotStream.Seek(0, SeekOrigin.Begin);
                session.AddFileToReport("ErrorPlots\\plot.png", plotStream);
            }
        }

        epochWatch.Restart();
    };

    // Every 100 iterations: sample text from a clone of the network and record it.
    var samplingSchedule = new ActionSchedule(100, PeriodType.Iteration);
    trainer.PeriodicActions.Add(new UserAction(samplingSchedule, () =>
    {
        if (gpu)
        {
            network.TransferStateToHost();
        }

        var samplingNet = network.Clone(1, SEQ_LEN);
        string text = TestRNN(samplingNet, 500, _dataProvider.Vocab);

        Console.WriteLine(text);
        session.AddFileToReport("Generated\\text.txt", text);
        trainOptions.ProgressWriter.ItemComplete();
    }));

    var runner = ConsoleRunner.Create(trainer, network);
    epochWatch.Start();
    runner.Run();
}