示例#1
0
        /// <summary>
        /// Evaluates <c>Function</c> on the inputs selected by the active parameter set and
        /// emits every resulting <c>Value</c> through <c>WriteObject</c>.
        /// </summary>
        /// <remarks>
        /// <list type="bullet">
        /// <item>Parameter set "Hashtable": inputs come from <c>Arguments</c>.</item>
        /// <item>Parameter set "DataSourceSet": inputs come from <c>DataSourceSet</c>.</item>
        /// <item>Any other parameter set: minibatches are pulled from <c>Sampler</c> until one
        /// reports <c>SweepEnd</c> or the sampler returns null. Each minibatch is mapped to
        /// the function's inputs via <c>DataNameToInputMap</c>, and the outputs of all
        /// minibatches are concatenated in order.</item>
        /// </list>
        /// </remarks>
        protected override void ProcessRecord()
        {
            IList<Value> results;

            if (ParameterSetName == "Hashtable")
            {
                results = FunctionInvoke.Invoke(Function, Arguments, Device, false);
            }
            else if (ParameterSetName == "DataSourceSet")
            {
                results = FunctionInvoke.Invoke(Function, DataSourceSet, Device, false);
            }
            else
            {
                var map = new DataNameToInputMap(new Function[] { Function }, DataNameToInputMap);
                var values = new List<Value>();

                while (true)
                {
                    var batch = Sampler.GetNextMinibatch(Device);

                    // NOTE(review): GetNextMinibatch can apparently return null when the sampler
                    // is exhausted (other callers of it check for this). Stop here instead of
                    // dereferencing null in InitializeByMinibatch / SweepEnd.
                    if (batch == null)
                    {
                        break;
                    }

                    map.InitializeByMinibatch(batch);
                    values.AddRange(FunctionInvoke.Invoke(Function, batch, map, Device));

                    // One full pass over the data is enough for evaluation.
                    if (batch.SweepEnd)
                    {
                        break;
                    }
                }

                results = values;
            }

            foreach (var r in results)
            {
                WriteObject(r);
            }
        }
示例#2
0
        /// <summary>
        /// Lazily runs the training loop, yielding this session once per trained minibatch.
        /// </summary>
        /// <param name="maxIteration">
        /// Highest iteration number to run, inclusive. Iterations are numbered from 1.
        /// NOTE(review): with the default of <c>int.MaxValue</c>, the bound check can never fail
        /// (assuming <c>Iteration</c> is an int). The loop then ends only via the sampler or the
        /// stop flag.
        /// </param>
        /// <returns>
        /// A lazy sequence that repeatedly yields this same <see cref="TrainingSession"/> instance.
        /// Nothing runs until enumeration starts. The sequence ends when:
        /// <list type="bullet">
        /// <item>the sampler returns null;</item>
        /// <item><c>_stop</c> is set after the callbacks run (that iteration is not yielded);</item>
        /// <item>the iteration bound is exceeded.</item>
        /// </list>
        /// </returns>
        public IEnumerable<TrainingSession> GetIterator(int maxIteration = int.MaxValue)
        {
            // Reset per-run state.
            _stopwatch = Stopwatch.StartNew();
            _stop = false;

            Epoch = 1;
            EpochIncremented = false;

            // Start the learner from the scheduler's current rate, if a scheduler is configured.
            if (LearningRateScheduler != null)
            {
                Learner.ResetLearningRate(new TrainingParameterScheduleDouble(LearningRateScheduler.LearningRate));
            }

            Iteration = 1;
            while (Iteration <= maxIteration)
            {
                var batch = Sampler.GetNextMinibatch(TrainingDevice);
                if (batch == null)
                {
                    // No more data from the sampler.
                    yield break;
                }

                DataNameToInputMap.InitializeByMinibatch(batch);
                var inputs = DataNameToInputMap.GetVariableValueMap(batch);
                Trainer.TrainMinibatch(inputs, batch.SweepEnd, TrainingDevice);

                // Record statistics for the minibatch just trained. Metric is only refreshed
                // when the trainer has an evaluation function.
                SampleCount = (int)Trainer.PreviousMinibatchSampleCount();
                Loss = Trainer.PreviousMinibatchLossAverage();
                if (Trainer.EvaluationFunction() != null)
                {
                    Metric = Trainer.PreviousMinibatchEvaluationAverage();
                }

                foreach (var callback in Callbacks)
                {
                    callback.Run(this);
                }

                // _stop is checked right after the callbacks (presumably one of them can request
                // a stop). If set, this iteration is not handed to the consumer.
                if (_stop)
                {
                    yield break;
                }

                yield return this;

                // Everything below runs only when the consumer asks for the next element.
                if (LearningRateScheduler != null
                    && LearningRateScheduler.UpdateLearningRate(Epoch, Iteration, Loss))
                {
                    Learner.ResetLearningRate(new TrainingParameterScheduleDouble(LearningRateScheduler.LearningRate));
                }

                // Epoch bookkeeping happens after the yield. The consumer therefore sees the
                // sweep-end iteration under the old Epoch, and EpochIncremented is true on the
                // iteration that follows a sweep end.
                EpochIncremented = false;
                if (batch.SweepEnd)
                {
                    ++Epoch;
                    EpochIncremented = true;
                }

                ++Iteration;
            }
        }