// Calculate the per-pixel accuracy on a dataset whose file names are supplied as a parameter.
// Every pixel whose ground-truth label is not LossMulticlassLogPerPixel.LabelToIgnore counts
// as either right (prediction == truth) or wrong; the result is right / (right + wrong).
// When no pixel is scored (e.g. an empty dataset) the division is 0 / 0.0, so NaN is returned.
//   anet    - network used to predict a class index for every pixel.
//   dataset - pairs of input image / RGB ground-truth label image file names.
private static double CalculateAccuracy(LossMulticlassLogPerPixel anet, IEnumerable<ImageInfo> dataset)
{
    // 64-bit counters: the number of scored pixels over a large dataset can exceed
    // int.MaxValue (about 8,200 images of 512x512 already pass 2^31 pixels).
    long numRight = 0;
    long numWrong = 0;

    foreach (var imageInfo in dataset)
    {
        // Load the input image.
        using (var inputImage = Dlib.LoadImageAsMatrix<RgbPixel>(imageInfo.ImageFilename))
        // Load the ground-truth (RGB) labels.
        using (var rgbLabelImage = Dlib.LoadImageAsMatrix<RgbPixel>(imageInfo.ClassLabelFilename))
        // Create predictions for each pixel. At this point, the type of each prediction
        // is an index (a value between 0 and 20). Note that the net may return an image
        // that is not exactly the same size as the input.
        using (var output = anet.Operator(inputImage))
        using (var temp = output.First())
        // Holds the ground truth converted from RGB colors to class indexes.
        using (var indexLabelImage = new Matrix<ushort>())
        {
            // Convert the RGB ground-truth labels to class indexes.
            PascalVOC2012.RgbLabelImageToIndexLabelImage(rgbLabelImage, indexLabelImage);

            // Crop the net output to be exactly the same size as the input.
            using (var chipDims = new ChipDims((uint)inputImage.Rows, (uint)inputImage.Columns))
            using (var chipDetails = new ChipDetails(Dlib.CenteredRect(temp.Columns / 2,
                                                                       temp.Rows / 2,
                                                                       (uint)inputImage.Columns,
                                                                       (uint)inputImage.Rows),
                                                     chipDims))
            using (var netOutput = Dlib.ExtractImageChip<ushort>(temp, chipDetails, InterpolationTypes.NearestNeighbor))
            {
                var nr = indexLabelImage.Rows;
                var nc = indexLabelImage.Columns;

                // Compare the predicted values to the ground-truth values.
                for (var r = 0; r < nr; ++r)
                {
                    for (var c = 0; c < nc; ++c)
                    {
                        var truth = indexLabelImage[r, c];

                        // Pixels marked "ignore" are excluded from both counts.
                        if (truth == LossMulticlassLogPerPixel.LabelToIgnore)
                            continue;

                        var prediction = netOutput[r, c];
                        if (prediction == truth)
                            ++numRight;
                        else
                            ++numWrong;
                    }
                }
            }
        }
    }

    // Return the accuracy estimate.
    return numRight / (double)(numRight + numWrong);
}
// Trains a semantic-segmentation network on the PASCAL VOC2012 training set, saves it to
// SemanticSegmentationNetFilename, and prints the per-pixel accuracy on the train and val sets.
//   args[0] - path to the VOC2012 dataset directory.
//   args[1] - optional positive mini-batch size (default 23).
// Returns 0 on success; 1 on a usage error, an invalid mini-batch size, a missing dataset,
// or any exception (which is printed).
private static int Main(string[] args)
{
    try
    {
        if (args.Length != 1 && args.Length != 2)
        {
            Console.WriteLine("To run this program you need a copy of the PASCAL VOC2012 dataset.");
            Console.WriteLine();
            Console.WriteLine("You call this program like this: ");
            Console.WriteLine("./dnn_semantic_segmentation_train_ex /path/to/VOC2012 [minibatch-size]");
            return 1;
        }

        // a mini-batch smaller than the default can be used with GPUs having less memory.
        // A size of 0 is rejected: the batch-building loop below would never add a sample.
        var minibatchSize = 23u;
        if (args.Length == 2 && (!uint.TryParse(args[1], out minibatchSize) || minibatchSize == 0))
        {
            Console.WriteLine($"invalid mini-batch size: {args[1]} (expected a positive integer)");
            return 1;
        }

        Console.WriteLine($"mini-batch size: {minibatchSize}");

        Console.WriteLine("\nSCANNING PASCAL VOC2012 DATASET\n");
        var listing = PascalVOC2012.GetPascalVoc2012TrainListing(args[0]).ToArray();
        Console.WriteLine($"images in dataset: {listing.Length}");
        if (listing.Length == 0)
        {
            Console.WriteLine("Didn't find the VOC2012 dataset.");
            return 1;
        }

        const double initialLearningRate = 0.1;
        const double weightDecay = 0.0001;
        const double momentum = 0.9;

        using (var bnet = new LossMulticlassLogPerPixel(2))
        using (var sgd = new Sgd((float)weightDecay, (float)momentum))
        using (var trainer = new DnnTrainer<LossMulticlassLogPerPixel>(bnet, sgd))
        {
            trainer.BeVerbose();
            trainer.SetLearningRate(initialLearningRate);
            trainer.SetSynchronizationFile("pascal_voc2012_trainer_state_file.dat", 10 * 60);
            // This threshold is probably excessively large.
            trainer.SetIterationsWithoutProgressThreshold(5000);
            // Since the progress threshold is so large might as well set the batch normalization
            // stats window to something big too.
            Dlib.SetAllBnRunningStatsWindowSizes(bnet, 1000);

            // Output training parameters.
            Console.WriteLine();
            Console.WriteLine(trainer);

            var samples = new List<Matrix<RgbPixel>>();
            var labels = new List<Matrix<ushort>>();

            // Data-loader threads read images from disk and pull out random crops, so the
            // trainer is fed fast enough to stay busy. Each loader puts its crops into the
            // data queue.
            using (var data = new Pipe<TrainingSample>(200))
            {
                var function = new Action<object>(seed =>
                {
                    using (var rnd = new Rand((ulong)seed))
                    {
                        while (data.IsEnabled)
                        {
                            // Pick a random input image.
                            var imageInfo = listing[rnd.GetRandom32BitNumber() % listing.Length];

                            // Load the input image.
                            using (var inputImage = Dlib.LoadImageAsMatrix<RgbPixel>(imageInfo.ImageFilename))
                            // Load the ground-truth (RGB) labels.
                            using (var rgbLabelImage = Dlib.LoadImageAsMatrix<RgbPixel>(imageInfo.ClassLabelFilename))
                            // Holds the ground truth converted from RGB colors to class indexes.
                            using (var indexLabelImage = new Matrix<ushort>())
                            {
                                // Convert the RGB ground-truth labels to class indexes.
                                PascalVOC2012.RgbLabelImageToIndexLabelImage(rgbLabelImage, indexLabelImage);

                                // Randomly pick a part of the image.
                                var temp = new TrainingSample();
                                RandomlyCropImage(inputImage, indexLabelImage, temp, rnd);

                                // Push the result to be used by the trainer.
                                data.Enqueue(temp);
                            }
                        }
                    }
                });

                // Only ONE loader thread is started here (the count is the second argument of
                // Enumerable.Range); each thread gets its own seed.
                var threads = Enumerable.Range(1, 1).Select(i =>
                {
                    var dataLoader = new Thread(new ParameterizedThreadStart(function)) { Name = $"dataLoader{i}" };
                    dataLoader.Start((ulong)i);
                    return dataLoader;
                }).ToArray();

                // The main training loop. Keep making mini-batches and giving them to the trainer.
                // We will run until the learning rate has dropped below 1e-4 (a factor of 1000
                // below the initial rate of 0.1).
                while (trainer.GetLearningRate() >= 1e-4)
                {
                    // Release the previous mini-batch before building the next one.
                    samples.DisposeElement();
                    labels.DisposeElement();
                    samples.Clear();
                    labels.Clear();

                    // make a mini-batch
                    while (samples.Count < minibatchSize)
                    {
                        data.Dequeue(out var temp);
                        samples.Add(temp.InputImage);
                        labels.Add(temp.LabelImage);
                        temp.Dispose();
                    }

                    LossMulticlassLogPerPixel.TrainOneStep(trainer, samples, labels);
                }

                // Training done, tell threads to stop and make sure to wait for them to finish before
                // moving on.
                data.Disable();
                foreach (var thread in threads)
                {
                    thread.Join();
                }

                // also wait for threaded processing to stop in the trainer.
                trainer.GetNet();

                // The loop only releases a batch when it builds the next one, so the final
                // mini-batch is still alive here; release it to avoid leaking native memory.
                samples.DisposeElement();
                labels.DisposeElement();
                samples.Clear();
                labels.Clear();

                bnet.Clean();
                Console.WriteLine("saving network");
                LossMulticlassLogPerPixel.Serialize(bnet, SemanticSegmentationNetFilename);
            }

            // Make a copy of the network to use it for inference.
            using (var anet = bnet.CloneAs(3))
            {
                Console.WriteLine("Testing the network...");

                // Find the accuracy of the newly trained network on both the training and the validation sets.
                // The training listing was already materialized above, so reuse it instead of rescanning.
                Console.WriteLine($"train accuracy : {CalculateAccuracy(anet, listing)}");
                Console.WriteLine($"val accuracy : {CalculateAccuracy(anet, PascalVOC2012.GetPascalVoc2012ValListing(args[0]))}");
            }
        }
    }
    catch (Exception e)
    {
        Console.WriteLine(e);
        return 1;
    }

    return 0;
}
// Runs a trained semantic-segmentation network over every supported image (.jpeg/.jpg/.png,
// case-insensitive) in the directory args[0]. For each image it shows the input next to the
// predicted RGB label map and prints the most prominent non-background class label, then
// waits for Enter. The network is read from 'semantic_segmentation_voc2012net.dnn' in the
// working directory. Any exception is printed to the console.
private static void Main(string[] args)
{
    if (args.Length != 1)
    {
        Console.WriteLine("You call this program like this: ");
        Console.WriteLine("./dnn_semantic_segmentation_ex /path/to/images");
        Console.WriteLine();
        Console.WriteLine("You will also need a trained 'semantic_segmentation_voc2012net.dnn' file.");
        Console.WriteLine("You can either train it yourself (see example program");
        Console.WriteLine("dnn_semantic_segmentation_train_ex), or download a");
        Console.WriteLine("copy from here: http://dlib.net/files/semantic_segmentation_voc2012net.dnn");
        return;
    }

    try
    {
        // Read the file containing the trained network from the working directory.
        using (var net = LossMulticlassLogPerPixel.Deserialize("semantic_segmentation_voc2012net.dnn"))
        // Show inference results in a window.
        using (var win = new ImageWindow())
        {
            // Find supported image files. The extension match ignores case so that files such
            // as "IMG_0001.JPG" are not silently skipped.
            var supportedExtensions = new[] { ".jpeg", ".jpg", ".png" };
            var files = Directory.GetFiles(args[0])
                                 .Where(s => supportedExtensions.Contains(Path.GetExtension(s), StringComparer.OrdinalIgnoreCase))
                                 .ToArray();

            Console.WriteLine($"Found {files.Length} images, processing...");
            foreach (var file in files)
            {
                // Load the input image.
                using (var inputImage = Dlib.LoadImageAsMatrix<RgbPixel>(file))
                // Create predictions for each pixel. At this point, the type of each prediction
                // is an index (a value between 0 and 20). Note that the net may return an image
                // that is not exactly the same size as the input.
                using (var output = net.Operator(inputImage))
                using (var temp = output.First())
                {
                    // Crop the returned image to be exactly the same size as the input.
                    var rect = Rectangle.CenteredRect((int)(temp.Columns / 2d), (int)(temp.Rows / 2d), (uint)inputImage.Columns, (uint)inputImage.Rows);
                    using (var dims = new ChipDims((uint)inputImage.Rows, (uint)inputImage.Columns))
                    using (var chipDetails = new ChipDetails(rect, dims))
                    using (var indexLabelImage = Dlib.ExtractImageChip<ushort>(temp, chipDetails, InterpolationTypes.NearestNeighbor))
                    // Convert the indexes to RGB values.
                    using (var rgbLabelImage = IndexLabelImageToRgbLabelImage(indexLabelImage))
                    // Show the input image on the left, and the predicted RGB labels on the right.
                    using (var joinedRow = Dlib.JoinRows(inputImage, rgbLabelImage))
                    {
                        win.SetImage(joinedRow);

                        // Find the most prominent class label from amongst the per-pixel predictions.
                        var classLabel = GetMostProminentNonBackgroundClassLabel(indexLabelImage);

                        Console.WriteLine($"{file} : {classLabel} - hit enter to process the next image");
                        // ReadLine actually waits for Enter (as the prompt says) and, unlike
                        // ReadKey, also works when standard input is redirected.
                        Console.ReadLine();
                    }
                }
            }
        }
    }
    catch (Exception e)
    {
        Console.WriteLine(e);
    }
}