Beispiel #1
0
        /// <summary>
        /// Invokes <c>Func</c> with <paramref name="session"/> only on iterations whose number is a
        /// multiple of <c>Step</c>; on every other iteration this method does nothing.
        /// </summary>
        /// <param name="session">The training session whose <c>Iteration</c> decides whether to fire.</param>
        public void Run(TrainingSession session)
        {
            // NOTE(review): assumes Step is non-zero — a zero Step would fail in the modulo below.
            bool isScheduledIteration = session.Iteration % Step == 0;

            if (isScheduledIteration)
            {
                Func.Invoke(session);
            }
        }
Beispiel #2
0
        /// <summary>
        /// Iterates <paramref name="session"/> and periodically yields aggregated
        /// <see cref="TrainingProgress"/> reports.
        /// </summary>
        /// <remarks>
        /// The method is lazy: nothing runs until the returned sequence is enumerated.
        /// When standard input is an interactive console, Ctrl+C is read as ordinary input while
        /// enumerating, and pressing it stops the loop (no further reports are produced). The
        /// previous <see cref="Console.TreatControlCAsInput"/> value is restored when enumeration
        /// completes or the enumerator is disposed. When standard input is redirected, Ctrl+C
        /// interception is skipped entirely because the console APIs involved throw in that case.
        /// </remarks>
        /// <param name="session">The training session to iterate.</param>
        /// <param name="maxIteration">
        /// Passed to <c>session.GetIterator</c>; a report is also emitted at the iteration equal to this value.
        /// </param>
        /// <param name="progressOutputStep">
        /// A report is emitted whenever the iteration number is a multiple of this value.
        /// </param>
        /// <param name="logger">Optional; when non-null, each report is passed to <c>logger.Info</c>.</param>
        /// <returns>A lazily evaluated sequence of progress reports.</returns>
        public static IEnumerable<TrainingProgress> Start(TrainingSession session, int maxIteration, int progressOutputStep, Logger logger)
        {
            // Cumulative number of samples since enumeration started; never reset between reports.
            int sampleCount = 0;

            // Sums since the last report, plus how many iterations contributed to them, so the
            // reported values are true averages even for windows shorter than progressOutputStep.
            var loss = 0.0;
            var metric = 0.0;
            int accumulatedIterations = 0;

            // Console.TreatControlCAsInput (IOException) and Console.KeyAvailable
            // (InvalidOperationException) throw when standard input is redirected.
            bool interceptControlC = !Console.IsInputRedirected;
            bool oldControlCTreatment = false;

            if (interceptControlC)
            {
                oldControlCTreatment = Console.TreatControlCAsInput;
                Console.TreatControlCAsInput = true;
            }

            try
            {
                foreach (var t in session.GetIterator(maxIteration))
                {
                    if (interceptControlC && Console.KeyAvailable)
                    {
                        var key = Console.ReadKey(true);
                        if ((key.Modifiers & ConsoleModifiers.Control) != 0 && key.Key == ConsoleKey.C)
                        {
                            break;
                        }
                    }

                    sampleCount += t.SampleCount;
                    loss += t.Loss;
                    metric += t.Metric;
                    ++accumulatedIterations;

                    if (t.Iteration % progressOutputStep == 0 || t.Iteration == maxIteration)
                    {
                        var p = new TrainingProgress
                        {
                            Epoch = t.Epoch,
                            Iteration = t.Iteration,
                            SampleCount = sampleCount,
                            // Divide by the iterations actually accumulated: the final window (ending at
                            // maxIteration) or a window ending at iteration 0 can be shorter than the step.
                            Loss = loss / accumulatedIterations,
                            Metric = metric / accumulatedIterations,
                            Validation = t.GetValidationMetric(),
                            LearningRate = t.Learner.LearningRate(),
                            Elapsed = t.Elapsed
                        };

                        if (logger != null)
                        {
                            logger.Info(p, "Training Progress");
                        }

                        yield return p;

                        loss = 0.0;
                        metric = 0.0;
                        accumulatedIterations = 0;
                    }
                }
            }
            finally
            {
                if (interceptControlC)
                {
                    Console.TreatControlCAsInput = oldControlCTreatment;
                }
            }
        }
 /// <summary>
 /// Builds a <see cref="TrainingSession"/> from the configured members, runs
 /// <c>TrainingLoop.Start</c> over it and emits every progress report via <c>WriteObject</c>.
 /// Both samplers are disposed afterwards, whether training finished normally or threw.
 /// </summary>
 protected override void EndProcessing()
 {
     try
     {
         var session = new TrainingSession(Model, LossFunction, EvaluationFunction, Learner, LearningScheduler, Sampler, ValidationSampler, DataToInputMap, null, null, Callbacks);
         foreach (var progress in TrainingLoop.Start(session, MaxIteration, ProgressOutputStep, Logger))
         {
             WriteObject(progress);
         }
     }
     finally
     {
         // Dispose each sampler independently. The ValidationSampler is released even if disposing
         // Sampler throws. Null checks avoid a NullReferenceException from this finally block, which
         // would hide the exception that ended training.
         // NOTE(review): presumably ValidationSampler is an optional parameter — confirm.
         try
         {
             if (Sampler != null)
             {
                 Sampler.Dispose();
             }
         }
         finally
         {
             if (ValidationSampler != null)
             {
                 ValidationSampler.Dispose();
             }
         }
     }
 }
        /// <summary>
        /// Constructs a <see cref="TrainingSession"/> from the configured model, loss and evaluation
        /// functions, learner, scheduler, samplers, input map and callbacks, and writes it out with
        /// <c>WriteObject</c>. The session is not iterated here.
        /// </summary>
        protected override void EndProcessing()
        {
            WriteObject(new TrainingSession(
                Model,
                LossFunction,
                EvaluationFunction,
                Learner,
                LearningScheduler,
                Sampler,
                ValidationSampler,
                DataToInputMap,
                null, // NOTE(review): meaning of these two null arguments not visible here — confirm.
                null,
                Callbacks));
        }