/// <summary>
/// Invokes <c>Func</c> with <paramref name="session"/>, but only on iterations whose
/// <c>Iteration</c> is an exact multiple of <c>Step</c>; all other iterations are skipped.
/// </summary>
/// <param name="session">The current training session; its <c>Iteration</c> is tested against <c>Step</c>.</param>
/// <remarks>A <c>Step</c> of zero makes the modulus throw <see cref="System.DivideByZeroException"/>.</remarks>
public void Run(TrainingSession session)
{
    var remainder = session.Iteration % Step;
    if (remainder == 0)
    {
        Func.Invoke(session);
    }
}
/// <summary>
/// Drives <paramref name="session"/> through the iterations produced by
/// <c>session.GetIterator(maxIteration)</c> and yields a <see cref="TrainingProgress"/> report
/// whenever the iteration number is a multiple of <paramref name="progressOutputStep"/> or
/// equals <paramref name="maxIteration"/>.
/// </summary>
/// <param name="session">The session whose iterator produces the per-iteration results.</param>
/// <param name="maxIteration">Passed to <c>GetIterator</c>; a report is also emitted on this iteration.</param>
/// <param name="progressOutputStep">Report interval in iterations; used as a modulus, so it must be non-zero.</param>
/// <param name="logger">Optional; when non-null, each report is also logged via <c>Info</c>.</param>
/// <returns>
/// A lazily evaluated sequence of progress reports. <c>SampleCount</c> is cumulative from the start
/// of the run. <c>Loss</c> and <c>Metric</c> are the mean over the iterations accumulated since the
/// previous report.
/// </returns>
/// <remarks>
/// When console input is not redirected, pressing Ctrl+C ends the loop early without emitting a
/// further report. The previous <see cref="Console.TreatControlCAsInput"/> value is restored when
/// the enumeration completes or its enumerator is disposed.
/// </remarks>
public static IEnumerable<TrainingProgress> Start(TrainingSession session, int maxIteration, int progressOutputStep, Logger logger)
{
    var sampleCount = 0;
    var loss = 0.0;
    var metric = 0.0;

    // Number of iterations folded into loss/metric since the last report. The window that ends at
    // maxIteration can be shorter than progressOutputStep, so averages must divide by this count
    // rather than by progressOutputStep.
    var windowSize = 0;

    // Console.KeyAvailable throws InvalidOperationException, and TreatControlCAsInput can throw
    // IOException, when stdin is not an interactive console. Only poll for Ctrl+C when input is
    // not redirected.
    var pollKeyboard = !Console.IsInputRedirected;
    var oldControlCTreatment = false;
    if (pollKeyboard)
    {
        oldControlCTreatment = Console.TreatControlCAsInput;
        Console.TreatControlCAsInput = true;
    }

    try
    {
        foreach (var t in session.GetIterator(maxIteration))
        {
            if (pollKeyboard && Console.KeyAvailable)
            {
                var key = Console.ReadKey(true);
                if ((key.Modifiers & ConsoleModifiers.Control) != 0 && key.Key == ConsoleKey.C)
                {
                    break;
                }
            }

            sampleCount += t.SampleCount;
            loss += t.Loss;
            metric += t.Metric;
            ++windowSize;

            if (t.Iteration % progressOutputStep == 0 || t.Iteration == maxIteration)
            {
                // windowSize >= 1 here because it was incremented above.
                var p = new TrainingProgress
                {
                    Epoch = t.Epoch,
                    Iteration = t.Iteration,
                    SampleCount = sampleCount,
                    Loss = loss / windowSize,
                    Metric = metric / windowSize,
                    Validation = t.GetValidationMetric(),
                    LearningRate = t.Learner.LearningRate(),
                    Elapsed = t.Elapsed,
                };

                if (logger != null)
                {
                    logger.Info(p, "Training Progress");
                }

                yield return p;

                loss = 0.0;
                metric = 0.0;
                windowSize = 0;
            }
        }
    }
    finally
    {
        if (pollKeyboard)
        {
            Console.TreatControlCAsInput = oldControlCTreatment;
        }
    }
}
/// <summary>
/// Builds a <c>TrainingSession</c> from the cmdlet's parameters, runs it through
/// <c>TrainingLoop.Start</c>, and writes each progress report to the pipeline.
/// </summary>
/// <remarks>
/// <c>Sampler</c> and <c>ValidationSampler</c> are disposed when this method exits, including
/// when an exception is thrown.
/// </remarks>
protected override void EndProcessing()
{
    try
    {
        var session = new TrainingSession(Model, LossFunction, EvaluationFunction, Learner, LearningScheduler, Sampler, ValidationSampler, DataToInputMap, null, null, Callbacks);
        foreach (var progress in TrainingLoop.Start(session, MaxIteration, ProgressOutputStep, Logger))
        {
            WriteObject(progress);
        }
    }
    finally
    {
        // Null-conditional disposal: an unset sampler must not raise a NullReferenceException
        // that would mask an exception already propagating from the try block.
        // NOTE(review): presumably ValidationSampler is optional; confirm against the parameter declarations.
        // The nested try/finally keeps a throwing Sampler.Dispose() from leaking ValidationSampler.
        try
        {
            Sampler?.Dispose();
        }
        finally
        {
            ValidationSampler?.Dispose();
        }
    }
}
/// <summary>
/// Constructs a <c>TrainingSession</c> from the cmdlet's parameters and writes it to the pipeline.
/// This method does not iterate the session itself.
/// </summary>
/// <remarks>
/// The samplers are not disposed here. NOTE(review): presumably the caller that receives the
/// session takes over their lifetime; confirm.
/// </remarks>
protected override void EndProcessing()
{
    WriteObject(new TrainingSession(
        Model,
        LossFunction,
        EvaluationFunction,
        Learner,
        LearningScheduler,
        Sampler,
        ValidationSampler,
        DataToInputMap,
        null,
        null,
        Callbacks));
}