Ejemplo n.º 1
0
        /// <summary>
        /// Builds the convolutional classifier together with its loss/error functions, an Adam
        /// learner, a trainer and an evaluator, storing each in the corresponding field.
        /// </summary>
        /// <remarks>
        /// Inputs: a 28x28x1 "image_tensor" variable and a 10-element "label_tensor" variable.
        /// Layer stack: conv(32, 3x3, ReLU) -> max-pool(2x2, stride {2}) -> conv(64, 3x3, ReLU)
        /// -> max-pool(2x2, stride {2}) -> conv(64, 3x3, ReLU) -> dense(64, ReLU) -> dense(10).
        /// The final dense layer gets no activation argument; its output is fed straight into
        /// CrossEntropyWithSoftmax for the loss and ClassificationError for the metric.
        /// </remarks>
        void create_network()
        {
            Console.WriteLine($"Compute Device: {computeDevice.AsString()}");

            // Reusable layer recipes; each call builds fresh shape arrays, as the inline calls did.
            C.Function Conv3x3Relu(C.Function input, int filterCount) =>
                Layers.Convolution2D(input, filterCount, new[] { 3, 3 }, computeDevice, CC.ReLU);

            C.Function MaxPool2x2(C.Function input) =>
                CC.Pooling(input, C.PoolingType.Max, new[] { 2, 2 }, new[] { 2 });

            // input placeholders
            imageVariable       = Util.inputVariable(new[] { 28, 28, 1 }, "image_tensor");
            categoricalVariable = Util.inputVariable(new[] { 10 }, "label_tensor");

            // feature extractor followed by the dense classifier head
            // (order matters: layers are created, and their parameters allocated, in this sequence)
            network = imageVariable;
            network = Conv3x3Relu(network, 32);
            network = MaxPool2x2(network);
            network = Conv3x3Relu(network, 64);
            network = MaxPool2x2(network);
            network = Conv3x3Relu(network, 64);
            network = Layers.Dense(network, 64, computeDevice, activation: CC.ReLU);
            network = Layers.Dense(network, 10, computeDevice);

            Logging.detailed_summary(network);
            Logging.log_number_of_parameters(network);

            // training objective and reported metric
            loss_function = CC.CrossEntropyWithSoftmax(network, categoricalVariable);
            eval_function = CC.ClassificationError(network, categoricalVariable);

            // Adam optimizer over every parameter of the network
            var trainableParameters = new C.ParameterVector(network.Parameters().ToArray());
            var learningRate        = new C.TrainingParameterScheduleDouble(0.001 * batch_size, (uint)batch_size);
            var momentum            = new C.TrainingParameterScheduleDouble(0.9);
            const bool unitGain     = true;
            var varianceMomentum    = new C.TrainingParameterScheduleDouble(0.99);

            learner = CC.AdamLearner(trainableParameters, learningRate, momentum, unitGain, varianceMomentum);

            var learners = new C.LearnerVector(new C.Learner[] { learner });
            trainer   = CC.CreateTrainer(network, loss_function, eval_function, learners);
            evaluator = CC.CreateEvaluator(eval_function);
        }
        /// <summary>
        /// Create the neural network for this app: two 64-unit ReLU hidden layers and a single
        /// output unit, trained on mean squared error and monitored with mean absolute error.
        /// </summary>
        /// <remarks>
        /// Side effects: assigns the <c>features</c> (13 floats) and <c>labels</c> (1 float)
        /// input variables, and the <c>trainer</c> and <c>evaluator</c> built for this network.
        /// </remarks>
        /// <returns>The neural network to use</returns>
        public static CNTK.Function CreateNetwork()
        {
            // input variables: 13 feature values in, 1 target value
            features = NetUtil.Var(new[] { 13 }, DataType.Float);
            labels   = NetUtil.Var(new[] { 1 }, DataType.Float);

            // two ReLU hidden layers, then one output unit with no activation argument
            var firstHidden  = features.Dense(64, CNTKLib.ReLU);
            var secondHidden = firstHidden.Dense(64, CNTKLib.ReLU);
            var outputLayer  = secondHidden.Dense(1);
            var model        = outputLayer.ToNetwork();

            // MSE is the training loss; MAE is the reported error metric
            var loss   = NetUtil.MeanSquaredError(model.Output, labels);
            var metric = NetUtil.MeanAbsoluteError(model.Output, labels);

            // Adam learner: learning rate 0.001, momentum 0.9 (schedule tuples with second
            // element 1), unit gain enabled
            var adam = model.GetAdamLearner(
                learningRateSchedule: (0.001, 1),
                momentumSchedule: (0.9, 1),
                unitGain: true);

            // the trainer optimizes the loss; the evaluator only computes the error metric
            trainer   = model.GetTrainer(adam, loss, metric);
            evaluator = model.GetEvaluator(metric);

            return model;
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Train the model.
        /// </summary>
        /// <remarks>
        /// Creates the feature/label variables and <c>Model</c>. It then wires up the loss and
        /// accuracy functions, an Adam learner, the trainer, a <c>ReduceLROnPlateau</c> scheduler,
        /// and an evaluator (only when <c>validationFeatures</c> is not null). Finally it runs up to
        /// <c>NumberOfEpochs</c> epochs. Each epoch appends its training metric to
        /// <c>TrainingCurves[0]</c> and its validation metric to <c>TrainingCurves[1]</c>.
        /// Training stops early when <c>scheduler.Update</c> returns true or when
        /// <paramref name="threshold"/> is reached.
        /// </remarks>
        /// <param name="threshold">
        /// Early-stopping threshold. When non-zero, training stops after the first epoch whose
        /// training metric is below this value. The default of 0 disables the check.
        /// </param>
        /// <exception cref="NotSupportedException">
        /// <c>lossFunctionType</c> is not one of the loss function types handled here.
        /// </exception>
        public void Train(double threshold = 0)
        {
            // create model and variables
            features = CreateFeatureVariable();
            labels   = CreateLabelVariable();
            Model    = CreateModel(features);
            AssertSequenceLength();

            // set up loss function
            CNTK.Function lossFunction = null;
            switch (lossFunctionType)
            {
            case LossFunctionType.BinaryCrossEntropy: lossFunction = CNTK.CNTKLib.BinaryCrossEntropy(Model, labels); break;

            case LossFunctionType.MSE: lossFunction = CNTK.CNTKLib.SquaredError(Model, labels); break;

            case LossFunctionType.CrossEntropyWithSoftmax: lossFunction = CNTK.CNTKLib.CrossEntropyWithSoftmax(Model, labels); break;

            case LossFunctionType.Custom: lossFunction = CustomLossFunction(); break;

            default:
                // fail fast: otherwise a null loss would be handed to CreateTrainer below
                throw new NotSupportedException($"Unsupported loss function type: {lossFunctionType}");
            }

            // set up accuracy function
            // NOTE(review): any other AccuracyFunctionType leaves accuracy_function null, which is
            // then passed to CreateTrainer/CreateEvaluator; confirm whether that is intended.
            CNTK.Function accuracy_function = null;
            switch (accuracyFunctionType)
            {
            case AccuracyFunctionType.SameAsLoss: accuracy_function = lossFunction; break;

            case AccuracyFunctionType.BinaryAccuracy: accuracy_function = NetUtil.BinaryAccuracy(Model, labels); break;
            }

            // set up an adam learner
            // NOTE(review): both schedules are tied to BatchSize; the original author questioned
            // whether that is needed ("remove batch_size?"). Confirm against GetAdamLearner.
            var learner = Model.GetAdamLearner(
                (LearningRate, (uint)BatchSize),
                (0.9, (uint)BatchSize),
                unitGain: false);

            // set up trainer
            trainer = CNTK.CNTKLib.CreateTrainer(Model, lossFunction, accuracy_function, new CNTK.LearnerVector()
            {
                learner
            });

            // set up a scheduler to tweak the learning rate
            scheduler = new ReduceLROnPlateau(learner, LearningRate);

            // set up an evaluator (only needed when there is validation data)
            if (validationFeatures != null)
            {
                evaluator = CNTK.CNTKLib.CreateEvaluator(accuracy_function);
            }

            // write the model summary
            Console.WriteLine("  Model architecture:");
            Console.WriteLine(Model.ToSummary());

            // clear the training curves
            TrainingCurves[0].Clear();
            TrainingCurves[1].Clear();

            // train for a certain number of epochs
            for (int epoch = 0; epoch < NumberOfEpochs; epoch++)
            {
                // Stopwatch is monotonic; DateTime.Now can jump on system clock or DST changes
                var epochTimer = System.Diagnostics.Stopwatch.StartNew();

                // train and evaluate the model
                var epoch_training_metric     = TrainBatches();
                var epoch_validation_accuracy = EvaluateBatches();

                // add to training curve
                TrainingCurves[0].Add(epoch_training_metric);
                TrainingCurves[1].Add(epoch_validation_accuracy);

                // write current loss and accuracy
                var elapsedTime = epochTimer.Elapsed;
                if (metricType == MetricType.Accuracy)
                {
                    Console.WriteLine($"Epoch {epoch + 1:D2}/{NumberOfEpochs}, Elapsed time: {elapsedTime.TotalSeconds:F3} seconds. " +
                                      $"Training Accuracy: {epoch_training_metric:F3}. Validation Accuracy: {epoch_validation_accuracy:F3}.");
                }
                else
                {
                    Console.WriteLine($"Epoch {epoch + 1:D2}/{NumberOfEpochs}, Elapsed time: {elapsedTime.TotalSeconds:F3} seconds, Training Loss: {epoch_training_metric:F3}");
                }

                // abort training if scheduler says so
                if (scheduler.Update(epoch_training_metric))
                {
                    break;
                }

                // optional early stop once the training metric drops below the threshold
                if ((threshold != 0) && (epoch_training_metric < threshold))
                {
                    break;
                }
            }
        }