/// <summary>
/// Wires up the loss, evaluation metric, momentum-SGD learner and trainer for the
/// current <c>model</c>, then calls <c>TrainMinibatch</c> once per epoch.
/// </summary>
/// <param name="config">Supplies the learning rate, momentum time constant and epoch count.</param>
/// <returns>The <c>model</c> field after training.</returns>
private Function PerformTraining(TrainingConfig config)
{
    // Per-sample gradient clipping threshold given to the learner (truncating clip).
    const double clippingThreshold = 5.0;

    var labels = inputModel.LabelSequence;

    // Objective to minimise and metric to report.
    var loss = CNTKLib.CrossEntropyWithSoftmax(model, labels);
    var metric = CNTKLib.ClassificationError(model, labels);

    // NOTE(review): the `1` presumably makes this a per-sample rate (reference
    // minibatch size of 1) — confirm against the CNTK schedule documentation.
    var learningRate = new TrainingParameterScheduleDouble(config.LearningRate, 1);
    var momentum = CNTKLib.MomentumAsTimeConstantSchedule(config.MomentumTimeConstant);

    var options = new AdditionalLearningOptions();
    options.gradientClippingThresholdPerSample = clippingThreshold;
    options.gradientClippingWithTruncation = true;

    // The boolean argument is CNTK's unitGain flag for momentum SGD.
    var sgd = Learner.MomentumSGDLearner(model.Parameters(), learningRate, momentum, true, options);

    var learners = new List<Learner>();
    learners.Add(sgd);
    trainer = Trainer.CreateTrainer(model, loss, metric, learners);

    // Each TrainMinibatch call runs one full epoch's worth of minibatches.
    for (var epoch = 0; epoch < config.Epochs; epoch++)
    {
        TrainMinibatch(config);
    }

    return model;
}
/// <summary>
/// Builds a new model sized to <paramref name="source"/>'s symbol count, stores the
/// source and input variables on this instance, and trains the model.
/// </summary>
/// <param name="config">Training hyper-parameters; must not be null.</param>
/// <param name="source">Data source used for model sizing and for minibatches during training; must not be null.</param>
/// <returns>The trained model.</returns>
/// <exception cref="ArgumentNullException"><paramref name="config"/> or <paramref name="source"/> is null.</exception>
public Function PerformTraining(TrainingConfig config, TrainingDataSource source)
{
    // Validate before touching instance state so a bad call leaves this object unchanged.
    if (config is null)
    {
        throw new ArgumentNullException(nameof(config));
    }
    if (source is null)
    {
        throw new ArgumentNullException(nameof(source));
    }

    this.source = source;

    // NOTE(review): 2 and 256 are hard-coded model hyper-parameters (presumably
    // depth and hidden width) — confirm against CreateModel and consider moving to config.
    var modelSequence = CreateModel(source.SymbolsCount, 2, 256);
    inputModel = CreateInputs(source.SymbolsCount);
    model = modelSequence(inputModel.InputSequence);

    return PerformTraining(config);
}
/// <summary>
/// Runs one epoch: feeds minibatches <c>0 .. minibatchesPerEpoch - 1</c> from
/// <c>source</c> through <c>trainer</c>.
/// </summary>
/// <remarks>
/// The epoch length is the smaller of <c>source.Length / config.MinibatchSize</c> and
/// <c>config.MaxNumberOfMinibatches / config.Epochs</c>. Indices restart at 0 on every
/// call, so each epoch visits the same minibatches.
/// </remarks>
/// <param name="config">Supplies minibatch size, minibatch cap and epoch count.</param>
private void TrainMinibatch(TrainingConfig config)
{
    uint minibatchesPerEpoch = (uint)Math.Min(source.Length / config.MinibatchSize, config.MaxNumberOfMinibatches / config.Epochs);
    for (int j = 0; j < minibatchesPerEpoch; j++)
    {
        var trainingData = source.GetData(j);

        // CNTK Values wrap native buffers; dispose them once the trainer has consumed
        // them instead of leaving the native memory to the finalizer.
        using (var features = Value.CreateSequence<float>(inputModel.InputSequence.Shape, trainingData.InputSequence, device))
        using (var labels = Value.CreateSequence(inputModel.LabelSequence.Shape, trainingData.OutputSequence, device))
        {
            var arguments = new Dictionary<Variable, Value>
            {
                { inputModel.InputSequence, features },
                { inputModel.LabelSequence, labels },
            };
            trainer.TrainMinibatch(arguments, device);
        }
    }
}
/// <summary>
/// Creates a data source configured from <paramref name="config"/> and then passes
/// <paramref name="filePath"/> to <c>LoadData</c>.
/// </summary>
/// <param name="config">Training configuration forwarded to the base data source constructor.</param>
/// <param name="filePath">Path handed to <c>LoadData</c>.</param>
public TextDataSource(TrainingConfig config, string filePath)
    : base(config)
{
    // The base constructor has already run here, so base-initialised state is set.
    LoadData(filePath);
}
/// <summary>
/// Initialises the data source from the training configuration.
/// </summary>
/// <param name="config">Configuration whose <c>MinibatchSize</c> is copied into <c>MinibatchLength</c>.</param>
public TrainingDataSource(TrainingConfig config)
{
    MinibatchLength = config.MinibatchSize;
}