Example #1
0
        static void Main(string[] args)
        {
            // unzip archive
            if (!Directory.Exists("cat"))
            {
                Console.WriteLine("Unpacking data....");
                DataUtil.Unzip(@"..\..\..\..\..\catsanddogs.zip", ".");
            }

            // create mapping files
            if (!File.Exists("train_map.txt"))
            {
                Console.WriteLine("Creating mapping files...");
                CreateMapFiles();
            }

            // get a training and validation image reader
            var trainingReader   = DataUtil.GetImageReader("train_map.txt", imageWidth, imageHeight, numChannels, 2, randomizeData: true, augmentData: true);
            var validationReader = DataUtil.GetImageReader("validation_map.txt", imageWidth, imageHeight, numChannels, 2, randomizeData: false, augmentData: false);

            // build features and labels
            var features = NetUtil.Var(new int[] { imageHeight, imageWidth, numChannels }, DataType.Float);
            var labels   = NetUtil.Var(new int[] { 2 }, DataType.Float);

            // ******************
            // ADD YOUR CODE HERE
            // ******************

            CNTK.Function network = null; // fix this line!

            // print the network to the console
            Console.WriteLine("Neural Network architecture: ");
            Console.WriteLine(network.ToSummary());

            // set up the loss function and the classification error function
            var lossFunc  = CNTKLib.CrossEntropyWithSoftmax(network.Output, labels);
            var errorFunc = CNTKLib.ClassificationError(network.Output, labels);

            // use the Adam learning algorithm
            var learner = network.GetAdamLearner(
                learningRateSchedule: (0.0001, 1),
                momentumSchedule: (0.99, 1));

            // set up a trainer and an evaluator
            var trainer   = network.GetTrainer(learner, lossFunc, errorFunc);
            var evaluator = network.GetEvaluator(errorFunc);

            // declare some variables
            var result      = 0.0;
            var sampleCount = 0;
            var batchCount  = 0;
            var lines       = new List <List <double> >()
            {
                new List <double>(), new List <double>()
            };

            // train the network during several epochs
            Console.WriteLine("Training the neural network....");
            for (int epoch = 0; epoch < maxEpochs; epoch++)
            {
                Console.Write($"[{DateTime.Now:HH:mm:ss}] Training epoch {epoch+1}/{maxEpochs}... ");

                // train the network using random batches
                result      = 0.0;
                sampleCount = 0;
                batchCount  = 0;
                while (sampleCount < 2 * trainingSetSize)
                {
                    // get the current batch
                    var batch         = trainingReader.GetBatch(batchSize);
                    var featuresBatch = batch[trainingReader.StreamInfo("features")];
                    var labelsBatch   = batch[trainingReader.StreamInfo("labels")];

                    // train the network on the batch
                    var(Loss, Evaluation) = trainer.TrainBatch(
                        new[] {
Example #2
0
        /// <summary>
        /// Train the model for <c>NumberOfEpochs</c> epochs, evaluating on the
        /// validation set after each epoch when validation data is available.
        /// Training curves are recorded into <c>TrainingCurves</c>.
        /// </summary>
        /// <param name="threshold">
        /// Optional early-stopping threshold: when non-zero, training stops as soon
        /// as the epoch training metric drops below this value.
        /// </param>
        /// <exception cref="NotSupportedException">
        /// Thrown when the configured loss or accuracy function type is not recognized.
        /// </exception>
        public void Train(double threshold = 0)
        {
            // create model and variables
            features = CreateFeatureVariable();
            labels   = CreateLabelVariable();
            Model    = CreateModel(features);
            AssertSequenceLength();

            // set up the loss function; fail fast on an unknown enum value instead of
            // letting a null loss surface later as a NullReferenceException in CreateTrainer
            CNTK.Function lossFunction = null;
            switch (lossFunctionType)
            {
            case LossFunctionType.BinaryCrossEntropy: lossFunction = CNTK.CNTKLib.BinaryCrossEntropy(Model, labels); break;

            case LossFunctionType.MSE: lossFunction = CNTK.CNTKLib.SquaredError(Model, labels); break;

            case LossFunctionType.CrossEntropyWithSoftmax: lossFunction = CNTK.CNTKLib.CrossEntropyWithSoftmax(Model, labels); break;

            case LossFunctionType.Custom: lossFunction = CustomLossFunction(); break;

            default: throw new NotSupportedException($"Unsupported loss function type: {lossFunctionType}");
            }

            // set up the accuracy function, again failing fast on an unknown value
            CNTK.Function accuracyFunction = null;
            switch (accuracyFunctionType)
            {
            case AccuracyFunctionType.SameAsLoss: accuracyFunction = lossFunction; break;

            case AccuracyFunctionType.BinaryAccuracy: accuracyFunction = NetUtil.BinaryAccuracy(Model, labels); break;

            default: throw new NotSupportedException($"Unsupported accuracy function type: {accuracyFunctionType}");
            }

            // set up an adam learner
            var learner = Model.GetAdamLearner(
                (LearningRate, (uint)BatchSize), // remove batch_size?
                (0.9, (uint)BatchSize),          // remove batch_size?
                unitGain: false);

            // set up trainer
            trainer = CNTK.CNTKLib.CreateTrainer(Model, lossFunction, accuracyFunction, new CNTK.LearnerVector()
            {
                learner
            });

            // set up a scheduler to tweak the learning rate
            scheduler = new ReduceLROnPlateau(learner, LearningRate);

            // set up an evaluator only when validation data was supplied
            if (validationFeatures != null)
            {
                evaluator = CNTK.CNTKLib.CreateEvaluator(accuracyFunction);
            }

            // write the model summary
            Console.WriteLine("  Model architecture:");
            Console.WriteLine(Model.ToSummary());

            // clear the training curves
            TrainingCurves[0].Clear();
            TrainingCurves[1].Clear();

            // train for a certain number of epochs
            for (int epoch = 0; epoch < NumberOfEpochs; epoch++)
            {
                // use UtcNow so elapsed-time measurement is immune to DST/clock changes
                var epochStartTime = DateTime.UtcNow;

                // train and evaluate the model
                var epochTrainingMetric     = TrainBatches();
                var epochValidationAccuracy = EvaluateBatches();

                // add to training curve
                TrainingCurves[0].Add(epochTrainingMetric);
                TrainingCurves[1].Add(epochValidationAccuracy);

                // write current loss and accuracy
                var elapsedTime = DateTime.UtcNow.Subtract(epochStartTime);
                if (metricType == MetricType.Accuracy)
                {
                    Console.WriteLine($"Epoch {epoch + 1:D2}/{NumberOfEpochs}, Elapsed time: {elapsedTime.TotalSeconds:F3} seconds. " +
                                      $"Training Accuracy: {epochTrainingMetric:F3}. Validation Accuracy: {epochValidationAccuracy:F3}.");
                }
                else
                {
                    Console.WriteLine($"Epoch {epoch + 1:D2}/{NumberOfEpochs}, Elapsed time: {elapsedTime.TotalSeconds:F3} seconds, Training Loss: {epochTrainingMetric:F3}");
                }

                // abort training if the learning-rate scheduler says so
                if (scheduler.Update(epochTrainingMetric))
                {
                    break;
                }

                // early stopping: abort once the training metric drops below the threshold
                if ((threshold != 0) && (epochTrainingMetric < threshold))
                {
                    break;
                }
            }
        }