/// <summary>
/// Entry point: unpacks the cats-and-dogs dataset, builds training and
/// validation image readers and the CNTK input variables, then trains the
/// network for <c>maxEpochs</c> epochs.
/// NOTE(review): this source chunk is truncated mid-statement at the
/// TrainBatch call at the bottom — the remainder of the loop is not visible here.
/// </summary>
static void Main(string[] args)
{
    // unzip archive — skipped when the "cat" folder already exists from a previous run
    if (!Directory.Exists("cat"))
    {
        Console.WriteLine("Unpacking data....");
        DataUtil.Unzip(@"..\..\..\..\..\catsanddogs.zip", ".");
    }

    // create mapping files — skipped when they already exist
    if (!File.Exists("train_map.txt"))
    {
        Console.WriteLine("Creating mapping files...");
        CreateMapFiles();
    }

    // get a training and validation image reader;
    // training data is shuffled and augmented, validation data is read as-is
    var trainingReader = DataUtil.GetImageReader("train_map.txt", imageWidth, imageHeight, numChannels, 2, randomizeData: true, augmentData: true);
    var validationReader = DataUtil.GetImageReader("validation_map.txt", imageWidth, imageHeight, numChannels, 2, randomizeData: false, augmentData: false);

    // build features and labels; the label vector has 2 slots
    // (presumably one-hot cat/dog — confirm against CreateMapFiles)
    var features = NetUtil.Var(new int[] { imageHeight, imageWidth, numChannels }, DataType.Float);
    var labels = NetUtil.Var(new int[] { 2 }, DataType.Float);

    // ******************
    // ADD YOUR CODE HERE
    // ******************
    // NOTE(review): exercise placeholder — while network stays null, every
    // dereference below (ToSummary, Output, ...) throws NullReferenceException.
    CNTK.Function network = null; // fix this line!

    // print the network to the console
    Console.WriteLine("Neural Network architecture: ");
    Console.WriteLine(network.ToSummary());

    // set up the loss function and the classification error function
    var lossFunc = CNTKLib.CrossEntropyWithSoftmax(network.Output, labels);
    var errorFunc = CNTKLib.ClassificationError(network.Output, labels);

    // use the Adam learning algorithm
    var learner = network.GetAdamLearner(
        learningRateSchedule: (0.0001, 1),
        momentumSchedule: (0.99, 1));

    // set up a trainer and an evaluator
    var trainer = network.GetTrainer(learner, lossFunc, errorFunc);
    var evaluator = network.GetEvaluator(errorFunc);

    // declare some variables
    var result = 0.0;
    var sampleCount = 0;
    var batchCount = 0;
    // two curves, presumably index 0 = training and index 1 = validation —
    // confirm against the (not visible) plotting code
    var lines = new List <List <double> >() { new List <double>(), new List <double>() };

    // train the network during several epochs
    Console.WriteLine("Training the neural network....");
    for (int epoch = 0; epoch < maxEpochs; epoch++)
    {
        Console.Write($"[{DateTime.Now:HH:mm:ss}] Training epoch {epoch+1}/{maxEpochs}... ");

        // train the network using random batches
        result = 0.0;
        sampleCount = 0;
        batchCount = 0;
        // NOTE(review): loop budget is 2 * trainingSetSize samples —
        // confirm the reason for the factor of 2 against the full source
        while (sampleCount < 2 * trainingSetSize)
        {
            // get the current batch
            var batch = trainingReader.GetBatch(batchSize);
            var featuresBatch = batch[trainingReader.StreamInfo("features")];
            var labelsBatch = batch[trainingReader.StreamInfo("labels")];

            // train the network on the batch
            // NOTE(review): source chunk ends mid-statement here
            var (Loss, Evaluation) = trainer.TrainBatch(
                new[] {
/// <summary>
/// Builds the model and its training machinery (loss, accuracy, Adam learner,
/// trainer, LR scheduler, optional evaluator), then trains for up to
/// <c>NumberOfEpochs</c> epochs while recording the training curves.
/// Training stops early when the scheduler gives up or the training metric
/// drops below <paramref name="threshold"/>.
/// </summary>
/// <param name="threshold">
/// Optional early-stop threshold; when non-zero, training ends as soon as the
/// epoch training metric falls below this value.
/// </param>
public void Train(double threshold = 0)
{
    // create model and variables
    features = CreateFeatureVariable();
    labels = CreateLabelVariable();
    Model = CreateModel(features);
    AssertSequenceLength();

    // set up the loss function
    CNTK.Function lossFunction = null;
    switch (lossFunctionType)
    {
        case LossFunctionType.BinaryCrossEntropy:
            lossFunction = CNTK.CNTKLib.BinaryCrossEntropy(Model, labels);
            break;
        case LossFunctionType.MSE:
            lossFunction = CNTK.CNTKLib.SquaredError(Model, labels);
            break;
        case LossFunctionType.CrossEntropyWithSoftmax:
            lossFunction = CNTK.CNTKLib.CrossEntropyWithSoftmax(Model, labels);
            break;
        case LossFunctionType.Custom:
            lossFunction = CustomLossFunction();
            break;
        default:
            // fail fast instead of handing a null loss to CreateTrainer,
            // which would otherwise surface as an opaque NullReferenceException
            throw new InvalidOperationException($"Unsupported loss function type: {lossFunctionType}");
    }

    // set up the accuracy function
    CNTK.Function accuracyFunction = null;
    switch (accuracyFunctionType)
    {
        case AccuracyFunctionType.SameAsLoss:
            accuracyFunction = lossFunction;
            break;
        case AccuracyFunctionType.BinaryAccuracy:
            accuracyFunction = NetUtil.BinaryAccuracy(Model, labels);
            break;
    }

    // set up an Adam learner
    // NOTE(review): CNTK training schedules are (value, minibatchSize) pairs;
    // the original carried "remove batch_size?" questions here — confirm the
    // intended schedule unit before changing these pairs.
    var learner = Model.GetAdamLearner(
        (LearningRate, (uint)BatchSize),
        (0.9, (uint)BatchSize),
        unitGain: false);

    // set up the trainer
    trainer = CNTK.CNTKLib.CreateTrainer(Model, lossFunction, accuracyFunction, new CNTK.LearnerVector() { learner });

    // set up a scheduler to tweak the learning rate when training plateaus
    scheduler = new ReduceLROnPlateau(learner, LearningRate);

    // set up an evaluator only when validation data was supplied
    if (validationFeatures != null)
    {
        evaluator = CNTK.CNTKLib.CreateEvaluator(accuracyFunction);
    }

    // write the model summary
    Console.WriteLine(" Model architecture:");
    Console.WriteLine(Model.ToSummary());

    // clear the training curves
    TrainingCurves[0].Clear();
    TrainingCurves[1].Clear();

    // train for up to NumberOfEpochs epochs
    for (int epoch = 0; epoch < NumberOfEpochs; epoch++)
    {
        // UtcNow for duration measurement: DateTime.Now can jump across
        // DST/clock adjustments and skew the reported elapsed time
        var epochStartTime = DateTime.UtcNow;

        // train and evaluate the model
        var trainingMetric = TrainBatches();
        var validationAccuracy = EvaluateBatches();

        // add to the training curves (0 = training metric, 1 = validation accuracy)
        TrainingCurves[0].Add(trainingMetric);
        TrainingCurves[1].Add(validationAccuracy);

        // write current loss and accuracy
        var elapsedTime = DateTime.UtcNow.Subtract(epochStartTime);
        if (metricType == MetricType.Accuracy)
        {
            Console.WriteLine($"Epoch {epoch + 1:D2}/{NumberOfEpochs}, Elapsed time: {elapsedTime.TotalSeconds:F3} seconds. " +
                $"Training Accuracy: {trainingMetric:F3}. Validation Accuracy: {validationAccuracy:F3}.");
        }
        else
        {
            Console.WriteLine($"Epoch {epoch + 1:D2}/{NumberOfEpochs}, Elapsed time: {elapsedTime.TotalSeconds:F3} seconds, Training Loss: {trainingMetric:F3}");
        }

        // abort training if the learning-rate scheduler says so
        if (scheduler.Update(trainingMetric))
        {
            break;
        }

        // abort training once the metric beats the caller-supplied threshold
        if ((threshold != 0) && (trainingMetric < threshold))
        {
            break;
        }
    }
}