Example #1
        /// <summary>
        /// TrainAndEvaluateWithFlowerData shows how to do transfer learning with a MinibatchSource. MinibatchSource is constructed with
        /// a map file that contains image file paths and labels. Data loading, image preprocessing, and batch randomization are handled
        /// by MinibatchSource.
        /// </summary>
        /// <param name="device">CPU or GPU device to run</param>
        /// <param name="forceReTrain">Force to train the model if true. If false,
        /// it only evaluates the model is it exists. </param>
        public static void TrainAndEvaluateWithFlowerData(DeviceDescriptor device, bool forceReTrain = false)
        {
            string flowerFolder          = Path.Combine(ExampleImageFoler, "Flowers");
            string flowersTrainingMap    = Path.Combine(flowerFolder, "1k_img_map.txt");
            string flowersValidationMap  = Path.Combine(flowerFolder, "val_map.txt");
            int    flowerModelNumClasses = 102;

            string flowerModelFile = Path.Combine(CurrentFolder, "FlowersTransferLearning.model");

            // If the model exists and retraining is not forced, validate the model and return.
            if (File.Exists(flowerModelFile) && !forceReTrain)
            {
                ValidateModelWithMinibatchSource(flowerModelFile, flowersValidationMap,
                                                 imageDims, flowerModelNumClasses, device);
                return;
            }

            // prepare training data
            MinibatchSource minibatchSource = CreateMinibatchSource(flowersTrainingMap,
                                                                    imageDims, flowerModelNumClasses);
            var featureStreamInfo = minibatchSource.StreamInfo("image");
            var labelStreamInfo   = minibatchSource.StreamInfo("labels");

            string   predictionNodeName = "prediction";
            Variable imageInput, labelInput;
            Function trainingLoss, predictionError;

            // create a transfer model
            Function transferLearningModel = CreateTransferLearningModel(BaseResnetModelFile, featureNodeName,
                                                                         predictionNodeName, lastHiddenNodeName, flowerModelNumClasses, device,
                                                                         out imageInput, out labelInput, out trainingLoss, out predictionError);

            // prepare for training
            int   numMinibatches           = 100;
            int   minibatchSize            = 50;
            float learningRatePerMinibatch = 0.2F;
            float momentumPerMinibatch     = 0.9F;
            float l2RegularizationWeight   = 0.05F;

            AdditionalLearningOptions additionalLearningOptions = new AdditionalLearningOptions()
            {
                l2RegularizationWeight = l2RegularizationWeight
            };

            IList <Learner> parameterLearners = new List <Learner>()
            {
                Learner.MomentumSGDLearner(transferLearningModel.Parameters(),
                                           new TrainingParameterScheduleDouble(learningRatePerMinibatch, 0),
                                           new TrainingParameterScheduleDouble(momentumPerMinibatch, 0),
                                           true,
                                           additionalLearningOptions)
            };
            var trainer = Trainer.CreateTrainer(transferLearningModel, trainingLoss, predictionError, parameterLearners);

            // train the model
            int outputFrequencyInMinibatches = 1;

            for (int minibatchCount = 0; minibatchCount < numMinibatches; ++minibatchCount)
            {
                var minibatchData = minibatchSource.GetNextMinibatch((uint)minibatchSize, device);

                trainer.TrainMinibatch(new Dictionary <Variable, MinibatchData>()
                {
                    { imageInput, minibatchData[featureStreamInfo] },
                    { labelInput, minibatchData[labelStreamInfo] }
                }, device);
                TestHelper.PrintTrainingProgress(trainer, minibatchCount, outputFrequencyInMinibatches);
            }

            // save the model
            transferLearningModel.Save(flowerModelFile);

            // validate the trained model
            ValidateModelWithMinibatchSource(flowerModelFile, flowersValidationMap,
                                             imageDims, flowerModelNumClasses, device);
        }
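
The CreateMinibatchSource helper called above is not shown in this example. Below is a minimal sketch of what it could look like, based on CNTK's ImageDeserializer reader API; the scale transform and the MaxSweeps setting are assumptions.

        // A sketch of the CreateMinibatchSource helper (assumed implementation).
        private static MinibatchSource CreateMinibatchSource(string mapFile, int[] imageDims, int numClasses)
        {
            // scale every image to the input dimensions the base model expects
            List <CNTKDictionary> transforms = new List <CNTKDictionary>()
            {
                CNTKLib.ReaderScale(imageDims[0], imageDims[1], imageDims[2])
            };

            // the map file lists "<image path>\t<label>" pairs; the deserializer exposes
            // them as the "image" and "labels" streams queried at the call site above
            CNTKDictionary deserializerConfiguration = CNTKLib.ImageDeserializer(
                mapFile, "labels", (uint)numClasses, "image", transforms);

            MinibatchSourceConfig config = new MinibatchSourceConfig(
                new List <CNTKDictionary> { deserializerConfiguration })
            {
                MaxSweeps = MinibatchSource.InfinitelyRepeat // keep cycling over the data
            };

            return CNTKLib.CreateCompositeMinibatchSource(config);
        }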
Example #2
        /// <summary>
        /// TrainAndEvaluateWithAnimalData shows how to do transfer learning without using a MinibatchSource.
        /// Training and evaluation data are prepared as appropriate in the code.
        /// Batching is done explicitly in the code as well.
        /// Because the amount of animal data is small, explicit batching in code works well enough here.
        /// In real scenarios, data preprocessing and batching should be implemented efficiently,
        /// ideally with parallelization and streaming, as MinibatchSource does.
        /// </summary>
        /// <param name="device">CPU or GPU device to run</param>
        /// <param name="forceRetrain">Force to train the model if true. If false,
        /// it only evaluates the model is it exists. </param>
        public static void TrainAndEvaluateWithAnimalData(DeviceDescriptor device, bool forceRetrain = false)
        {
            string animalDataFolder = Path.Combine(ExampleImageFoler, "Animals");

            string[] animals = new string[] { "Sheep", "Wolf" };
            int      animalModelNumClasses = 2;
            string   animalsModelFile      = Path.Combine(CurrentFolder, "AnimalsTransferLearning.model");

            // If the model exists and retraining is not forced, validate the model and return.
            if (File.Exists(animalsModelFile) && !forceRetrain)
            {
                ValidateModelWithoutMinibatchSource(animalsModelFile, Path.Combine(animalDataFolder, "Test"), animals,
                                                    imageDims, animalModelNumClasses, device);
                return;
            }

            List <Tuple <string, int, float[]> > trainingDataMap =
                PrepareTrainingDataFromSubfolders(Path.Combine(animalDataFolder, "Train"), animals, imageDims);

            // prepare the transfer model
            string   predictionNodeName = "prediction";
            Variable imageInput, labelInput;
            Function trainingLoss, predictionError;
            Function transferLearningModel = CreateTransferLearningModel(Path.Combine(ExampleImageFoler, BaseResnetModelFile), featureNodeName, predictionNodeName,
                                                                         lastHiddenNodeName, animalModelNumClasses, device,
                                                                         out imageInput, out labelInput, out trainingLoss, out predictionError);

            // prepare for training
            int   numMinibatches               = 5;
            float learningRatePerMinibatch     = 0.2F;
            float momentumPerMinibatch         = 0.9F;
            float l2RegularizationWeight       = 0.1F;

            AdditionalLearningOptions additionalLearningOptions =
                new AdditionalLearningOptions()
            {
                l2RegularizationWeight = l2RegularizationWeight
            };
            IList <Learner> parameterLearners = new List <Learner>()
            {
                Learner.MomentumSGDLearner(transferLearningModel.Parameters(),
                                           new TrainingParameterScheduleDouble(learningRatePerMinibatch, 0),
                                           new TrainingParameterScheduleDouble(momentumPerMinibatch, 0),
                                           true,
                                           additionalLearningOptions)
            };
            var trainer = Trainer.CreateTrainer(transferLearningModel, trainingLoss, predictionError, parameterLearners);

            // train the model (each outer iteration is a full sweep over the training data)
            for (int minibatchCount = 0; minibatchCount < numMinibatches; ++minibatchCount)
            {
                Value imageBatch, labelBatch;
                int   batchCount = 0, batchSize = 15;
                while (GetImageAndLabelMinibatch(trainingDataMap, batchSize, batchCount++,
                                                 imageDims, animalModelNumClasses, device, out imageBatch, out labelBatch))
                {
                    //TODO: sweepEnd should be set properly.
#pragma warning disable 618
                    trainer.TrainMinibatch(new Dictionary <Variable, Value>()
                    {
                        { imageInput, imageBatch },
                        { labelInput, labelBatch }
                    }, device);
#pragma warning restore 618
                    TestHelper.PrintTrainingProgress(trainer, minibatchCount, 1);
                }
            }

            // save the trained model
            transferLearningModel.Save(animalsModelFile);

            // done with training, continue with validation
            double error = ValidateModelWithoutMinibatchSource(animalsModelFile, Path.Combine(animalDataFolder, "Test"), animals,
                                                               imageDims, animalModelNumClasses, device);
            Console.WriteLine($"Validation error: {error}");
        }
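
The GetImageAndLabelMinibatch helper is likewise not shown. Here is a hypothetical sketch reconstructed from its call site, assuming each training tuple holds (filePath, classIndex, preprocessedPixels) as produced by PrepareTrainingDataFromSubfolders:

        // A sketch of the GetImageAndLabelMinibatch helper (assumed implementation).
        // Returns false once the training data has been exhausted.
        private static bool GetImageAndLabelMinibatch(List <Tuple <string, int, float[]> > trainingDataMap,
                                                      int batchSize, int batchIndex, int[] imageDims, int numClasses,
                                                      DeviceDescriptor device, out Value imageBatch, out Value labelBatch)
        {
            int start = batchIndex * batchSize;
            if (start >= trainingDataMap.Count)
            {
                imageBatch = null;
                labelBatch = null;
                return false;
            }

            int count     = Math.Min(batchSize, trainingDataMap.Count - start);
            var imageData = new List <float>();
            var labelData = new List <float>();
            for (int i = start; i < start + count; i++)
            {
                imageData.AddRange(trainingDataMap[i].Item3);   // preprocessed pixel values
                var oneHot = new float[numClasses];
                oneHot[trainingDataMap[i].Item2] = 1;           // one-hot encode the class index
                labelData.AddRange(oneHot);
            }

            imageBatch = Value.CreateBatch(NDShape.CreateNDShape(imageDims), imageData, device);
            labelBatch = Value.CreateBatch(NDShape.CreateNDShape(new int[] { numClasses }), labelData, device);
            return true;
        }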
Example #3
        /// <summary>
        /// Creates the learner based on learning parameters.
        /// ToDo: Not all learner parameters are defined yet.
        /// </summary>
        /// <param name="network">Network model being trained</param>
        /// <param name="lrParams">Learning parameters.</param>
        /// <returns>A single-element list containing the configured learner.</returns>
        private List <Learner> createLearners(Function network, LearningParameters lrParams)
        {
            //learning rate and momentum schedules
            var lr       = new TrainingParameterScheduleDouble(lrParams.LearningRate);
            var mm       = CNTKLib.MomentumAsTimeConstantSchedule(lrParams.Momentum);
            var addParam = new AdditionalLearningOptions();

            //optional L1/L2 regularization
            if (lrParams.L1Regularizer > 0)
            {
                addParam.l1RegularizationWeight = lrParams.L1Regularizer;
            }
            if (lrParams.L2Regularizer > 0)
            {
                addParam.l2RegularizationWeight = lrParams.L2Regularizer;
            }

            var parameters = network.Parameters();

            switch (lrParams.LearnerType)
            {
            //MomentumSGDLearner - rate, momentum, and regularizers
            case LearnerType.MomentumSGDLearner:
                return new List <Learner> { Learner.MomentumSGDLearner(parameters, lr, mm, true, addParam) };

            //SGDLearner - rate and regularizers
            case LearnerType.SGDLearner:
                return new List <Learner> { Learner.SGDLearner(parameters, lr, addParam) };

            //FSAdaGradLearner - rate and momentum
            case LearnerType.FSAdaGradLearner:
                return new List <Learner> { CNTKLib.FSAdaGradLearner(new ParameterVector(parameters.ToList()), lr, mm) };

            //AdamLearner - rate and momentum
            case LearnerType.AdamLearner:
                return new List <Learner> { CNTKLib.AdamLearner(new ParameterVector(parameters.ToList()), lr, mm) };

            //AdaGradLearner - rate and regularizers
            case LearnerType.AdaGradLearner:
                return new List <Learner> { CNTKLib.AdaGradLearner(new ParameterVector(parameters.ToList()), lr, false, addParam) };

            default:
                throw new NotSupportedException($"Learner type '{lrParams.LearnerType}' is not supported.");
            }
        }
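
A hypothetical usage of this helper; the LearningParameters values and the network, loss, and evalError variables below are placeholders for illustration:

            // Build a momentum SGD learner with L2 regularization and wire it into a trainer.
            var lrParams = new LearningParameters()
            {
                LearnerType   = LearnerType.MomentumSGDLearner,
                LearningRate  = 0.01,
                Momentum      = 1100,   // consumed as a momentum time constant above
                L2Regularizer = 0.005
            };
            var learners = createLearners(network, lrParams);
            var trainer  = Trainer.CreateTrainer(network, loss, evalError, learners);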
Example #4
        public Program()
        {
            SetDevice();
            var trainingFile = "tinyshakespeare.txt";

            if (!loadData(trainingFile))
            {
                return;
            }
            Console.WriteLine($"Data { trainingFile } has { text.Length } characters, with { characters.Count } unique characters.");

            var inputModel    = CreateInputs(characters.Count);
            var modelSequence = CreateModel(characters.Count, 2, 256);
            var model         = modelSequence(inputModel.InputSequence);

            //  Setup the criteria (loss and metric)
            var crossEntropy = CNTKLib.CrossEntropyWithSoftmax(model, inputModel.LabelSequence);
            var errors       = CNTKLib.ClassificationError(model, inputModel.LabelSequence);

            //  Instantiate the trainer object to drive the model training
            var learningRatePerSample = new TrainingParameterScheduleDouble(0.001);
            var momentumTimeConstant  = CNTKLib.MomentumAsTimeConstantSchedule(1100);
            var additionalParameters  = new AdditionalLearningOptions
            {
                gradientClippingThresholdPerSample = 5.0,
                gradientClippingWithTruncation     = true
            };
            var learner = Learner.MomentumSGDLearner(model.Parameters(), learningRatePerSample, momentumTimeConstant, false, additionalParameters);
            var trainer = Trainer.CreateTrainer(model, crossEntropy, errors, new List <Learner>()
            {
                learner
            });

            var epochs                 = 50;
            var minibatchSize          = 100;
            var maxNumberOfMinibatches = int.MaxValue;
            var sampleFrequency        = 1000;
            var minibatchesPerEpoch    = Math.Min(text.Length / minibatchSize, maxNumberOfMinibatches / epochs);
            var parameterTensor        = model.Parameters();
            var sumOfParameters        = 0;

            foreach (var parameter in parameterTensor)
            {
                sumOfParameters += parameter.Shape.TotalSize;
            }
            Console.WriteLine($"Training { sumOfParameters } parameter in { parameterTensor.Count } parameter tensors");
            Console.WriteLine($"Running { epochs } epochs with { minibatchesPerEpoch } minibatches per epoch");
            Console.WriteLine();

            for (int i = 0; i < epochs; i++)
            {
                var start = DateTime.Now;
                Console.WriteLine($"Running training on epoch {i + 1} of {epochs}");
                for (int j = 0; j < minibatchesPerEpoch; j++)
                {
                    var trainingData = GetData(j, minibatchSize, text, charToIndex, characters.Count);
                    var arguments    = new Dictionary <Variable, Value>();
                    var features     = Value.CreateSequence <float>(inputModel.InputSequence.Shape,
                                                                    trainingData.InputSequence, device);
                    arguments.Add(inputModel.InputSequence, features);
                    var labels = Value.CreateSequence(inputModel.LabelSequence.Shape,
                                                      trainingData.OutputSequence, device);
                    arguments.Add(inputModel.LabelSequence, labels);
                    trainer.TrainMinibatch(arguments, device);

                    var globalMinibatch = i * minibatchesPerEpoch + j;
                    if (globalMinibatch % sampleFrequency == 0)
                    {
                        Sample(model, 50);
                    }
                    if (globalMinibatch % 100 == 0)
                    {
                        var minibatchId       = j + 1;
                        var minibatchEndId    = j + 100;
                        var trainingLossValue = trainer.PreviousMinibatchLossAverage();
                        var evaluationValue   = trainer.PreviousMinibatchEvaluationAverage();
                        Console.WriteLine(
                            $"Epoch {(i + 1), 3}: Minibatch [{minibatchId, 6}-{minibatchEndId, 6}] CrossEntropyLoss = {trainingLossValue:F6}, EvaluationCriterion = {evaluationValue:F3}");
                    }
                }
                var end         = DateTime.Now;
                var epochLength = end - start;
                Console.WriteLine(
                    $"Finished epoch {i + 1} in {epochLength.TotalSeconds} seconds ({epochLength.Hours:00}:{epochLength.Minutes:00}:{epochLength.Seconds:00}.{epochLength.Milliseconds:000})");
                var modelFilename = $"newmodels/shakespeare_epoch{ i + 1 }.dnn";
                Directory.CreateDirectory("newmodels"); // Save throws if the target directory is missing
                model.Save(modelFilename);
                Console.WriteLine($"Saved model to { modelFilename }");
            }
            Console.ReadKey();
        }
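
The GetData helper that feeds the training loop is not shown. Here is a minimal sketch of what it might do, assuming one-hot encoded characters where the label sequence is the input shifted by one character; the signature and named-tuple return type are reconstructed from the call site:

        // A sketch of the GetData helper (assumed implementation).
        private static (float[] InputSequence, float[] OutputSequence) GetData(int batchIndex, int minibatchSize,
                                                                               string text, IDictionary <char, int> charToIndex, int vocabSize)
        {
            // wrap around so every window leaves room for the one-character-shifted labels
            int start  = (batchIndex * minibatchSize) % (text.Length - minibatchSize - 1);
            var input  = new float[minibatchSize * vocabSize];
            var output = new float[minibatchSize * vocabSize];

            for (int i = 0; i < minibatchSize; i++)
            {
                input[i * vocabSize + charToIndex[text[start + i]]]      = 1;    // current character
                output[i * vocabSize + charToIndex[text[start + i + 1]]] = 1;    // next character
            }
            return (input, output);
        }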
Example #5
        // SWIG-generated interop helper: returns a HandleRef wrapping the native C++
        // pointer held by an AdditionalLearningOptions instance, or a null handle when
        // obj is null, so the pointer can be passed across the P/Invoke boundary.
        internal static global::System.Runtime.InteropServices.HandleRef getCPtr(AdditionalLearningOptions obj)
        {
            return (obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr;
        }