/// <summary>
/// Trains the wrapped model for a fixed number of minibatches drawn from the
/// data source built from <paramref name="datasetFile"/>.
/// </summary>
/// <remarks>
/// Each iteration pulls one minibatch (size 4), feeds the feature stream to the
/// model input and the label stream to the training output, runs one trainer
/// step and reports progress through <c>OnTrainingIterationPerformed</c>.
/// After 100 minibatches the saved model and the last progress snapshot are
/// reported through <c>OnTrainingFinished</c>.
/// </remarks>
/// <param name="datasetFile">Value handed to <c>CreateDataSource</c> to obtain the minibatch source.</param>
public override void Train(string datasetFile)
{
    const int minibatchSize = 4;
    const int maxMinibatches = 100;

    var dataSource = CreateDataSource(datasetFile);
    var trainer = CreateTrainer();

    TrainingProgressResponse latestProgress;
    var completedMinibatches = 0;

    do
    {
        var batch = dataSource.MinibatchSource.GetNextMinibatch(minibatchSize, _device);

        // Bind the feature stream to the model input and the label stream to the training output.
        var bindings = new Dictionary<Variable, MinibatchData>();
        bindings.Add(_modelWrapper.Input, batch[dataSource.FeatureStreamInfo]);
        bindings.Add(_modelWrapper.TrainingOutput, batch[dataSource.LabelStreamInfo]);

        trainer.TrainMinibatch(bindings, _device);

        // MinibatchesSeen is captured before the counter is advanced,
        // so the first report carries 0 and the last carries maxMinibatches - 1.
        latestProgress = new TrainingProgressResponse
        {
            MinibatchesSeen = completedMinibatches,
            Loss = trainer.PreviousMinibatchLossAverage(),
            EvaluationCriterion = trainer.PreviousMinibatchEvaluationAverage()
        };
        OnTrainingIterationPerformed(latestProgress);

        completedMinibatches++;
    }
    while (completedMinibatches < maxMinibatches);

    // Training budget exhausted: publish the saved model together with the final progress snapshot.
    var result = new TrainingResultResponse
    {
        NewModelData = _modelWrapper.Model.Save(),
        Progress = latestProgress
    };
    OnTrainingFinished(result);
}
/// <summary>
/// Raises <c>TrainingIterationPerformed</c> with this instance as the sender.
/// Does nothing when no handler is attached.
/// </summary>
/// <param name="trainingProgressResponse">Progress data forwarded unchanged to the handlers.</param>
protected void OnTrainingIterationPerformed(TrainingProgressResponse trainingProgressResponse)
{
    // Snapshot the delegate once so the null check and the call see the same value.
    var handlers = TrainingIterationPerformed;
    if (handlers != null)
    {
        handlers.Invoke(this, trainingProgressResponse);
    }
}
/// <summary>
/// Handler that prints one line summarising a training iteration: the minibatch
/// counter, the loss (labelled "CrossEntropyLoss") and the evaluation criterion.
/// </summary>
/// <param name="sender">Source of the notification; not used.</param>
/// <param name="trainingProgressResponse">Progress values to print.</param>
private static void TrainingIterationPerformed(
    object sender,
    TrainingProgressResponse trainingProgressResponse)
{
    var minibatchPart = $"Minibatch: {trainingProgressResponse.MinibatchesSeen} ";
    var lossPart = $"CrossEntropyLoss = {trainingProgressResponse.Loss} ";
    var criterionPart = $"EvaluationCriterion = {trainingProgressResponse.EvaluationCriterion}";

    // NOTE(review): WriteLine is presumably Console.WriteLine via `using static` — confirm.
    WriteLine(minibatchPart + lossPart + criterionPart);
}