Ejemplo n.º 1
0
        /// <summary>
        /// Runs the evaluator over minibatches pulled from <paramref name="reader"/> and returns
        /// 1 minus the sample-weighted mean of the evaluator's criterion (the validation accuracy
        /// when the evaluator measures classification error).
        /// </summary>
        /// <param name="reader">Minibatch source exposing "features" and "labels" streams.</param>
        /// <returns>
        /// 1 - (sum of per-batch error * batch sample count) / (total samples read).
        /// NaN if no samples were read.
        /// </returns>
        double validation_phase(CNTK.MinibatchSource reader)
        {
            var featuresStreamInfo = reader.StreamInfo("features");
            var labelsStreamInfo   = reader.StreamInfo("labels");
            var num_samples        = 0;
            var total_error        = 0.0;

            // NOTE(review): reads 2 * validation_set_size samples; presumably validation_set_size
            // is a per-class count for a two-class problem — confirm against the caller.
            while (num_samples < 2 * validation_set_size)
            {
                var minibatchData = reader.GetNextMinibatch(minibatch_size, computeDevice);
                var arguments     = new test_feed_t {
                    { features_tensor, minibatchData[featuresStreamInfo] }, { label_tensor, minibatchData[labelsStreamInfo] }
                };
                var samples_in_batch = (int)(minibatchData[featuresStreamInfo].numberOfSamples);
                if (samples_in_batch == 0)
                {
                    // An empty minibatch would never advance num_samples; stop instead of looping forever.
                    break;
                }

                // Use the evaluator's own result. trainer.PreviousMinibatchEvaluationAverage() reports the
                // trainer's last *training* minibatch, not this validation batch.
                // Weight by sample count so a short final batch does not skew the mean.
                total_error += evaluator.TestMinibatch(arguments, computeDevice) * samples_in_batch;
                num_samples += samples_in_batch;
            }

            return 1.0 - (total_error / num_samples);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Builds a two-class classifier on top of <c>VGG16.get_model</c> and trains it.
        /// Inputs are 150x150x3 images scaled by 1/255. The head is Dense(256) -> ReLU ->
        /// Dropout(0.5) -> Dense(2). Training uses Adam (learning rate 0.0001, momentum 0.99) on
        /// augmented training minibatches, and the model is evaluated on the validation set after
        /// every epoch.
        /// Training stops after <c>max_epochs</c>, or earlier once the epoch's training error
        /// drops below 0.001.
        /// </summary>
        /// <param name="use_finetuning">Forwarded to <c>VGG16.get_model</c>; its effect is defined there.</param>
        /// <returns>
        /// A two-element list. Index 0 holds the per-epoch training accuracy and index 1 the
        /// per-epoch validation accuracy. Both are 1 minus the sample-weighted mean classification
        /// error.
        /// </returns>
        /// <exception cref="InvalidOperationException">If <c>batch_size</c> is not positive.</exception>
        List <List <double> > train_with_augmentation(bool use_finetuning)
        {
            // Samples consumed per epoch from each source.
            // NOTE(review): these must stay in sync with the ranges passed to create_minibatch_source
            // below (presumably 1000 train / 500 validation images per class for two classes) — confirm.
            const int num_training_samples   = 2000;
            const int num_validation_samples = 1000;

            if (batch_size <= 0)
            {
                // A non-positive batch size would never advance `pos`, so the loops below would spin forever.
                throw new InvalidOperationException("batch_size must be positive.");
            }

            var labels          = CNTK.Variable.InputVariable(new int[] { 2 }, CNTK.DataType.Float, "labels");
            var features        = CNTK.Variable.InputVariable(new int[] { 150, 150, 3 }, CNTK.DataType.Float, "features");
            var scalar_factor   = CNTK.Constant.Scalar <float>((float)(1.0 / 255.0), computeDevice);
            var scaled_features = CNTK.CNTKLib.ElementTimes(scalar_factor, features);

            var conv_base = VGG16.get_model(scaled_features, computeDevice, use_finetuning);
            var model     = Util.Dense(conv_base, 256, computeDevice);

            model = CNTK.CNTKLib.ReLU(model);
            model = CNTK.CNTKLib.Dropout(model, 0.5);
            model = Util.Dense(model, 2, computeDevice);

            var loss_function     = CNTK.CNTKLib.CrossEntropyWithSoftmax(model.Output, labels);
            var accuracy_function = CNTK.CNTKLib.ClassificationError(model.Output, labels);

            var pv        = new CNTK.ParameterVector((System.Collections.ICollection)model.Parameters());
            var learner   = CNTK.CNTKLib.AdamLearner(pv, new CNTK.TrainingParameterScheduleDouble(0.0001, 1), new CNTK.TrainingParameterScheduleDouble(0.99, 1));
            var trainer   = CNTK.Trainer.CreateTrainer(model, loss_function, accuracy_function, new CNTK.Learner[] { learner });
            var evaluator = CNTK.CNTKLib.CreateEvaluator(accuracy_function);

            var train_minibatch_source      = create_minibatch_source(features.Shape, 0, 1000, "train", is_training: true, use_augmentations: true);
            var validation_minibatch_source = create_minibatch_source(features.Shape, 1000, 500, "validation", is_training: false, use_augmentations: false);

            var train_featuresStreamInformation      = train_minibatch_source.StreamInfo("features");
            var train_labelsStreamInformation        = train_minibatch_source.StreamInfo("labels");
            var validation_featuresStreamInformation = validation_minibatch_source.StreamInfo("features");
            var validation_labelsStreamInformation   = validation_minibatch_source.StreamInfo("labels");

            var training_accuracy   = new List <double>();
            var validation_accuracy = new List <double>();

            for (int epoch = 0; epoch < max_epochs; epoch++)
            {
                // Stopwatch is monotonic; DateTime.Now can jump with clock or DST adjustments.
                var stopwatch = System.Diagnostics.Stopwatch.StartNew();

                // training phase: errors are weighted by sample count so a short final
                // minibatch does not count as much as a full one.
                var epoch_training_error = 0.0;
                var training_samples     = 0;
                var pos                  = 0;
                while (pos < num_training_samples)
                {
                    var pos_end         = Math.Min(pos + batch_size, num_training_samples);
                    var minibatch_data  = train_minibatch_source.GetNextMinibatch((uint)(pos_end - pos), computeDevice);
                    var feed_dictionary = new batch_t()
                    {
                        { features, minibatch_data[train_featuresStreamInformation] },
                        { labels, minibatch_data[train_labelsStreamInformation] }
                    };
                    trainer.TrainMinibatch(feed_dictionary, computeDevice);
                    var samples_in_batch = (int)(minibatch_data[train_featuresStreamInformation].numberOfSamples);
                    epoch_training_error += trainer.PreviousMinibatchEvaluationAverage() * samples_in_batch;
                    training_samples     += samples_in_batch;
                    pos = pos_end;
                }
                epoch_training_error /= training_samples;
                training_accuracy.Add(1.0 - epoch_training_error);

                // evaluation phase (same sample weighting as training)
                var epoch_validation_error = 0.0;
                var validation_samples     = 0;
                pos = 0;
                while (pos < num_validation_samples)
                {
                    var pos_end         = Math.Min(pos + batch_size, num_validation_samples);
                    var minibatch_data  = validation_minibatch_source.GetNextMinibatch((uint)(pos_end - pos), computeDevice);
                    var feed_dictionary = new CNTK.UnorderedMapVariableMinibatchData()
                    {
                        { features, minibatch_data[validation_featuresStreamInformation] },
                        { labels, minibatch_data[validation_labelsStreamInformation] }
                    };
                    var samples_in_batch = (int)(minibatch_data[validation_featuresStreamInformation].numberOfSamples);
                    // Evaluate on the same device the model and data were placed on.
                    epoch_validation_error += evaluator.TestMinibatch(feed_dictionary, computeDevice) * samples_in_batch;
                    validation_samples     += samples_in_batch;
                    pos = pos_end;
                }
                epoch_validation_error /= validation_samples;
                validation_accuracy.Add(1.0 - epoch_validation_error);

                var elapsedTime = stopwatch.Elapsed;
                Console.WriteLine($"Epoch {epoch + 1:D2}/{max_epochs}, training_accuracy={1.0 - epoch_training_error:F3}, validation accuracy:{1 - epoch_validation_error:F3}, elapsed time={elapsedTime.TotalSeconds:F1} seconds");

                // Early stop once the training set is essentially fit.
                if (epoch_training_error < 0.001)
                {
                    break;
                }
            }

            return(new List <List <double> >()
            {
                training_accuracy, validation_accuracy
            });
        }