예제 #1
0
        /// <summary>
        /// Trains the current <c>model</c> against <c>inputModel.LabelSequence</c> and returns it.
        /// Loss is <c>CrossEntropyWithSoftmax</c>, the evaluation metric is <c>ClassificationError</c>,
        /// and optimisation uses a momentum-SGD learner with per-sample gradient clipping.
        /// The created trainer is stored in the <c>trainer</c> field, and
        /// <c>TrainMinibatch(config)</c> is invoked once per epoch.
        /// </summary>
        /// <param name="config">Supplies the learning rate, momentum time constant and epoch count.</param>
        /// <returns>The same <c>model</c> instance, after training.</returns>
        private Function PerformTraining(TrainingConfig config)
        {
            // Training criteria: what is minimised and what is reported.
            var lossFunction       = CNTKLib.CrossEntropyWithSoftmax(model, inputModel.LabelSequence);
            var evaluationFunction = CNTKLib.ClassificationError(model, inputModel.LabelSequence);

            // Learning-rate schedule with a reference size of 1 (i.e. expressed per sample).
            var learningRateSchedule = new TrainingParameterScheduleDouble(config.LearningRate, 1);
            var momentumSchedule     = CNTKLib.MomentumAsTimeConstantSchedule(config.MomentumTimeConstant);

            // Per-sample gradient clipping threshold of 5.0, with truncation-style clipping enabled.
            var learnerOptions = new AdditionalLearningOptions();
            learnerOptions.gradientClippingThresholdPerSample = 5.0;
            learnerOptions.gradientClippingWithTruncation     = true;

            // Fourth MomentumSGDLearner argument: the unit-gain momentum flag.
            const bool unitGain = true;
            var sgdLearner = Learner.MomentumSGDLearner(model.Parameters(), learningRateSchedule, momentumSchedule, unitGain, learnerOptions);

            var learners = new List<Learner>();
            learners.Add(sgdLearner);
            trainer = Trainer.CreateTrainer(model, lossFunction, evaluationFunction, learners);

            for (int epoch = 0; epoch < config.Epochs; ++epoch)
            {
                TrainMinibatch(config);
            }

            return model;
        }
예제 #2
0
        /// <summary>
        /// Builds the network sized for <paramref name="source"/>, stores the data source, input
        /// variables and model on this instance, then trains the model with <paramref name="config"/>.
        /// </summary>
        /// <param name="config">Training hyper-parameters, forwarded to the private training overload.</param>
        /// <param name="source">Training data; its <c>SymbolsCount</c> sizes both the model and its inputs.</param>
        /// <returns>The trained model (the same instance that is stored in the <c>model</c> field).</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="config"/> or <paramref name="source"/> is null.
        /// </exception>
        public Function PerformTraining(TrainingConfig config, TrainingDataSource source)
        {
            // Validate before mutating any instance state, so a bad call leaves the object untouched.
            // Previously a null config only failed deep inside training, after source/inputModel/model
            // had already been overwritten.
            if (config == null)
            {
                throw new ArgumentNullException(nameof(config));
            }

            if (source == null)
            {
                throw new ArgumentNullException(nameof(source));
            }

            this.source = source;

            // NOTE(review): 2 and 256 are passed straight to CreateModel; their meaning is defined
            // there (possibly layer count / hidden size) — confirm and consider naming them.
            var modelSequence = CreateModel(source.SymbolsCount, 2, 256);

            inputModel = CreateInputs(source.SymbolsCount);
            model      = modelSequence(inputModel.InputSequence);

            return PerformTraining(config);
        }
예제 #3
0
        /// <summary>
        /// Runs one training epoch: feeds up to <c>minibatchesPerEpoch</c> minibatches, obtained from
        /// <c>source.GetData(j)</c>, through <c>trainer</c> on <c>device</c>.
        /// </summary>
        /// <param name="config">Supplies the minibatch size, the overall minibatch budget and the epoch count.</param>
        private void TrainMinibatch(TrainingConfig config)
        {
            // Bounded both by how many minibatches source.Length allows and by this epoch's share of the
            // overall minibatch budget (integer division, so any remainder is dropped).
            uint minibatchesPerEpoch = (uint)Math.Min(source.Length / config.MinibatchSize, config.MaxNumberOfMinibatches / config.Epochs);

            // NOTE(review): j restarts at 0 on every call, so each epoch replays the same minibatch indices.
            for (int j = 0; j < minibatchesPerEpoch; j++)
            {
                var trainingData = source.GetData(j);

                // CNTK Value objects are IDisposable wrappers around native buffers. Release each
                // minibatch's inputs once the trainer has consumed them, instead of leaking them until
                // finalization. The nested usings also free `features` if creating `labels` throws.
                using (var features = Value.CreateSequence<float>(inputModel.InputSequence.Shape,
                                                                  trainingData.InputSequence, device))
                using (var labels = Value.CreateSequence(inputModel.LabelSequence.Shape,
                                                         trainingData.OutputSequence, device))
                {
                    var arguments = new Dictionary<Variable, Value>
                    {
                        { inputModel.InputSequence, features },
                        { inputModel.LabelSequence, labels }
                    };

                    trainer.TrainMinibatch(arguments, device);
                }
            }
        }
예제 #4
0
 /// <summary>
 /// Creates the data source: the base constructor receives <paramref name="config"/>, then
 /// <c>LoadData</c> is called with <paramref name="filePath"/>.
 /// </summary>
 /// <param name="config">Training configuration forwarded to the base constructor.</param>
 /// <param name="filePath">Path handed to <c>LoadData</c> (presumably a text file, given the type name — confirm in LoadData).</param>
 public TextDataSource(TrainingConfig config, string filePath)
     : base(config)
 {
     // NOTE(review): if LoadData is virtual, this is a virtual call from a constructor — confirm no
     // further-derived type overrides it and relies on its own constructor having run.
     LoadData(filePath);
 }
예제 #5
0
 /// <summary>
 /// Initialises the data source from the training configuration.
 /// </summary>
 /// <param name="config">Supplies <c>MinibatchSize</c>, which is stored as <c>MinibatchLength</c>.</param>
 /// <exception cref="System.ArgumentNullException"><paramref name="config"/> is null.</exception>
 public TrainingDataSource(TrainingConfig config)
 {
     // Fail with a descriptive exception rather than a NullReferenceException on the first member access.
     // Fully qualified because this file's using directives are not visible here.
     if (config == null)
     {
         throw new System.ArgumentNullException(nameof(config));
     }

     MinibatchLength = config.MinibatchSize;
 }