/// <summary>
/// TrainAndEvaluateWithFlowerData shows how to do transfer learning with a MinibatchSource. MinibatchSource is constructed with
/// a map file that contains image file paths and labels. Data loading, image preprocessing, and batch randomization are handled
/// by MinibatchSource.
/// </summary>
/// <param name="device">CPU or GPU device to run</param>
/// <param name="forceReTrain">Force to train the model if true. If false,
/// it only evaluates the model if it exists.</param>
public static void TrainAndEvaluateWithFlowerData(DeviceDescriptor device, bool forceReTrain = false)
{
    string flowerFolder = Path.Combine(ExampleImageFoler, "Flowers");
    string flowersTrainingMap = Path.Combine(flowerFolder, "1k_img_map.txt");
    string flowersValidationMap = Path.Combine(flowerFolder, "val_map.txt");
    int flowerModelNumClasses = 102;

    string flowerModelFile = Path.Combine(CurrentFolder, "FlowersTransferLearning.model");

    // If the model exists and it is not set to force retrain, validate the model and return.
    if (File.Exists(flowerModelFile) && !forceReTrain)
    {
        ValidateModelWithMinibatchSource(flowerModelFile, flowersValidationMap,
            imageDims, flowerModelNumClasses, device);
        return;
    }

    // prepare training data: MinibatchSource handles loading, preprocessing and randomization
    MinibatchSource minibatchSource = CreateMinibatchSource(flowersTrainingMap,
        imageDims, flowerModelNumClasses);
    var featureStreamInfo = minibatchSource.StreamInfo("image");
    var labelStreamInfo = minibatchSource.StreamInfo("labels");

    string predictionNodeName = "prediction";
    Variable imageInput, labelInput;
    Function trainingLoss, predictionError;

    // create a transfer model: reuse the base ResNet up to its last hidden layer and
    // attach a new classification head with flowerModelNumClasses outputs.
    Function transferLearningModel = CreateTransferLearningModel(BaseResnetModelFile, featureNodeName,
        predictionNodeName, lastHiddenNodeName, flowerModelNumClasses, device,
        out imageInput, out labelInput, out trainingLoss, out predictionError);

    // prepare for training
    int numMinibatches = 100;
    int minibatchSize = 50; // was misspelled "minibatchbSize"
    float learningRatePerMinibatch = 0.2F;
    float momentumPerMinibatch = 0.9F;
    float l2RegularizationWeight = 0.05F;
    AdditionalLearningOptions additionalLearningOptions =
        new AdditionalLearningOptions() { l2RegularizationWeight = l2RegularizationWeight };
    IList<Learner> parameterLearners = new List<Learner>()
    {
        Learner.MomentumSGDLearner(transferLearningModel.Parameters(),
            new TrainingParameterScheduleDouble(learningRatePerMinibatch, 0),
            new TrainingParameterScheduleDouble(momentumPerMinibatch, 0),
            true,
            additionalLearningOptions)
    };
    var trainer = Trainer.CreateTrainer(transferLearningModel, trainingLoss, predictionError, parameterLearners);

    // train the model
    int outputFrequencyInMinibatches = 1;
    for (int minibatchCount = 0; minibatchCount < numMinibatches; ++minibatchCount)
    {
        var minibatchData = minibatchSource.GetNextMinibatch((uint)minibatchSize, device);
        trainer.TrainMinibatch(new Dictionary<Variable, MinibatchData>()
        {
            { imageInput, minibatchData[featureStreamInfo] },
            { labelInput, minibatchData[labelStreamInfo] }
        }, device);
        TestHelper.PrintTrainingProgress(trainer, minibatchCount, outputFrequencyInMinibatches);
    }

    // save the model
    transferLearningModel.Save(flowerModelFile);

    // validate the trained model
    ValidateModelWithMinibatchSource(flowerModelFile, flowersValidationMap,
        imageDims, flowerModelNumClasses, device);
}
/// <summary>
/// TrainAndEvaluateWithAnimalData shows how to do transfer learning without using a MinibatchSource.
/// Training and evaluation data are prepared as appropriate in the code.
/// Batching is done explicitly in the code as well.
/// Because the amount of animal data is limited, it is fine to work this way.
/// In real scenarios, it is recommended to code efficiently for data preprocessing and batching,
/// probably with parallelization and streaming as what has been done in MinibatchSource.
/// </summary>
/// <param name="device">CPU or GPU device to run</param>
/// <param name="forceRetrain">Force to train the model if true. If false,
/// it only evaluates the model if it exists.</param>
public static void TrainAndEvaluateWithAnimalData(DeviceDescriptor device, bool forceRetrain = false)
{
    string animalDataFolder = Path.Combine(ExampleImageFoler, "Animals");
    string[] animals = new string[] { "Sheep", "Wolf" };
    int animalModelNumClasses = 2;
    string animalsModelFile = Path.Combine(CurrentFolder, "AnimalsTransferLearning.model");

    // If the model exists and it is not set to force retrain, validate the model and return.
    if (File.Exists(animalsModelFile) && !forceRetrain)
    {
        ValidateModelWithoutMinibatchSource(animalsModelFile, Path.Combine(animalDataFolder, "Test"),
            animals, imageDims, animalModelNumClasses, device);
        return;
    }

    // prepare training data: (filepath, label index, image pixel data) per sample
    List<Tuple<string, int, float[]>> trainingDataMap =
        PrepareTrainingDataFromSubfolders(Path.Combine(animalDataFolder, "Train"), animals, imageDims);

    // prepare the transfer model
    string predictionNodeName = "prediction";
    Variable imageInput, labelInput;
    Function trainingLoss, predictionError;
    Function transferLearningModel = CreateTransferLearningModel(
        Path.Combine(ExampleImageFoler, BaseResnetModelFile), featureNodeName,
        predictionNodeName, lastHiddenNodeName, animalModelNumClasses, device,
        out imageInput, out labelInput, out trainingLoss, out predictionError);

    // prepare for training
    int numMinibatches = 5;
    float learningRatePerMinibatch = 0.2F;
    float momentumPerMinibatch = 0.9F; // was "learningmomentumPerMinibatch"; renamed for clarity/consistency
    float l2RegularizationWeight = 0.1F;
    AdditionalLearningOptions additionalLearningOptions =
        new AdditionalLearningOptions() { l2RegularizationWeight = l2RegularizationWeight };
    IList<Learner> parameterLearners = new List<Learner>()
    {
        Learner.MomentumSGDLearner(transferLearningModel.Parameters(),
            new TrainingParameterScheduleDouble(learningRatePerMinibatch, 0),
            new TrainingParameterScheduleDouble(momentumPerMinibatch, 0),
            true,
            additionalLearningOptions)
    };
    var trainer = Trainer.CreateTrainer(transferLearningModel, trainingLoss, predictionError, parameterLearners);

    // train the model: each outer iteration sweeps all training data in explicit batches
    for (int minibatchCount = 0; minibatchCount < numMinibatches; ++minibatchCount)
    {
        Value imageBatch, labelBatch;
        int batchCount = 0, batchSize = 15;
        while (GetImageAndLabelMinibatch(trainingDataMap, batchSize, batchCount++,
            imageDims, animalModelNumClasses, device, out imageBatch, out labelBatch))
        {
            //TODO: sweepEnd should be set properly.
#pragma warning disable 618
            trainer.TrainMinibatch(new Dictionary<Variable, Value>()
            {
                { imageInput, imageBatch },
                { labelInput, labelBatch }
            }, device);
#pragma warning restore 618
            TestHelper.PrintTrainingProgress(trainer, minibatchCount, 1);
        }
    }

    // save the trained model
    transferLearningModel.Save(animalsModelFile);

    // done with training, continue with validation
    double error = ValidateModelWithoutMinibatchSource(animalsModelFile,
        Path.Combine(animalDataFolder, "Test"), animals,
        imageDims, animalModelNumClasses, device);
    // label the output instead of printing a bare number
    Console.WriteLine($"Validation error: {error}");
}
/// <summary>
/// Creates the learner based on learning parameters.
/// ToDo: Not all learner parameters defined
/// </summary>
/// <param name="network">Network model being trained</param>
/// <param name="lrParams">Learning parameters.</param>
/// <returns>A single-element list holding the configured learner.</returns>
/// <exception cref="NotSupportedException">Thrown when the learner type is not supported.</exception>
private List<Learner> createLearners(Function network, LearningParameters lrParams)
{
    // learning rate and momentum values
    var lr = new TrainingParameterScheduleDouble(lrParams.LearningRate);
    var mm = CNTKLib.MomentumAsTimeConstantSchedule(lrParams.Momentum);
    var addParam = new AdditionalLearningOptions();

    // optional L1/L2 regularization weights (only applied when positive)
    if (lrParams.L1Regularizer > 0)
    {
        addParam.l1RegularizationWeight = lrParams.L1Regularizer;
    }
    if (lrParams.L2Regularizer > 0)
    {
        addParam.l2RegularizationWeight = lrParams.L2Regularizer;
    }

    // Each branch of the original if/else chain built an identical one-element list;
    // collapse the duplication into a single switch plus one list construction.
    Learner learner;
    switch (lrParams.LearnerType)
    {
        case LearnerType.MomentumSGDLearner:
            // SGD with momentum - rate, momentum and regularizers
            learner = Learner.MomentumSGDLearner(network.Parameters(), lr, mm, true, addParam);
            break;
        case LearnerType.SGDLearner:
            // plain SGD - rate and regularizers only
            learner = Learner.SGDLearner(network.Parameters(), lr, addParam);
            break;
        case LearnerType.FSAdaGradLearner:
            // FSAdaGrad - rate and momentum
            learner = CNTKLib.FSAdaGradLearner(new ParameterVector(network.Parameters().ToList()), lr, mm);
            break;
        case LearnerType.AdamLearner:
            // Adam - rate and momentum
            learner = CNTKLib.AdamLearner(new ParameterVector(network.Parameters().ToList()), lr, mm);
            break;
        case LearnerType.AdaGradLearner:
            // AdaGrad - rate and regularizers; needAveMultiplier = false as in the original
            learner = CNTKLib.AdaGradLearner(new ParameterVector(network.Parameters().ToList()), lr, false, addParam);
            break;
        default:
            // NotSupportedException (derives from Exception, so existing catch blocks still match)
            throw new NotSupportedException("Learner type is not supported!");
    }

    return new List<Learner>() { learner };
}
/// <summary>
/// Loads the character-level training corpus, builds an LSTM sequence model,
/// and trains it for a fixed number of epochs, sampling generated text and
/// saving a model checkpoint after every epoch.
/// </summary>
public Program()
{
    SetDevice();

    var trainingFile = "tinyshakespeare.txt";
    if (!loadData(trainingFile))
    {
        return;
    }
    Console.WriteLine($"Data { trainingFile } has { text.Length } characters, with { characters.Count } unique characters.");

    var inputModel = CreateInputs(characters.Count);
    var modelSequence = CreateModel(characters.Count, 2, 256);
    var model = modelSequence(inputModel.InputSequence);

    // var model = CreateModel(inputModel.InputSequence, characters.Count, 1, 256);
    // model.Save("dummymodel.dnn");
    // return;

    // Setup the criteria (loss and metric)
    var crossEntropy = CNTKLib.CrossEntropyWithSoftmax(model, inputModel.LabelSequence);
    var errors = CNTKLib.ClassificationError(model, inputModel.LabelSequence);

    // Instantiate the trainer object to drive the model training
    var learningRatePerSample = new TrainingParameterScheduleDouble(0.001);
    var momentumTimeConstant = CNTKLib.MomentumAsTimeConstantSchedule(1100);
    var additionalParameters = new AdditionalLearningOptions
    {
        // clip gradients to stabilize RNN training
        gradientClippingThresholdPerSample = 5.0,
        gradientClippingWithTruncation = true
    };
    var learner = Learner.MomentumSGDLearner(model.Parameters(), learningRatePerSample,
        momentumTimeConstant, false, additionalParameters);
    var trainer = Trainer.CreateTrainer(model, crossEntropy, errors, new List<Learner>() { learner });

    var epochs = 50;
    var minibatchSize = 100;
    var maxNumberOfMinibatches = int.MaxValue;
    var sampleFrequency = 1000;
    var minibatchesPerEpoch = Math.Min(text.Length / minibatchSize, maxNumberOfMinibatches / epochs);

    var parameterTensor = model.Parameters();
    // use long: summing Shape.TotalSize over many tensors can overflow int for large models
    var sumOfParameters = 0L;
    foreach (var parameter in parameterTensor)
    {
        sumOfParameters += parameter.Shape.TotalSize;
    }
    Console.WriteLine($"Training { sumOfParameters } parameters in { parameterTensor.Count } parameter tensors");
    Console.WriteLine($"Running { epochs } epochs with { minibatchesPerEpoch } minibatches per epoch");
    Console.WriteLine();

    for (int i = 0; i < epochs; i++)
    {
        var start = DateTime.Now;
        Console.WriteLine($"Running training on epoch {i + 1} of {epochs}");
        for (int j = 0; j < minibatchesPerEpoch; j++)
        {
            var trainingData = GetData(j, minibatchSize, text, charToIndex, characters.Count);
            var arguments = new Dictionary<Variable, Value>();
            var features = Value.CreateSequence<float>(inputModel.InputSequence.Shape,
                trainingData.InputSequence, device);
            arguments.Add(inputModel.InputSequence, features);
            var labels = Value.CreateSequence(inputModel.LabelSequence.Shape,
                trainingData.OutputSequence, device);
            arguments.Add(inputModel.LabelSequence, labels);
            trainer.TrainMinibatch(arguments, device);

            var globalMinibatch = i * minibatchesPerEpoch + j;
            if (globalMinibatch % sampleFrequency == 0)
            {
                // periodically generate sample text to eyeball training progress
                Sample(model, 50);
            }
            if (globalMinibatch % 100 == 0)
            {
                var minibatchId = j + 1;
                var minibatchEndId = j + 100;
                var trainingLossValue = trainer.PreviousMinibatchLossAverage();
                var evaluationValue = trainer.PreviousMinibatchEvaluationAverage();
                Console.WriteLine(
                    $"Epoch {(i + 1), 3}: Minibatch [{minibatchId, 6}-{minibatchEndId, 6}] CrossEntropyLoss = {trainingLossValue:F6}, EvaluationCriterion = {evaluationValue:F3}");
            }
        }
        var end = DateTime.Now;
        var epochLength = end - start;
        Console.WriteLine(
            $"Finished epoch {i + 1} in {epochLength.TotalSeconds} seconds ({epochLength.Hours:00}:{epochLength.Minutes:00}:{epochLength.Seconds:00}.{epochLength.Milliseconds:000})");

        // checkpoint after every epoch
        var modelFilename = $"newmodels/shakespeare_epoch{ i + 1 }.dnn";
        model.Save(modelFilename);
        Console.WriteLine($"Saved model to { modelFilename }");
    }
    Console.ReadKey();
}
internal static global::System.Runtime.InteropServices.HandleRef getCPtr(AdditionalLearningOptions obj) { return((obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr); }