Beispiel #1
0
        /// <summary>
        /// Trains a semantic-segmentation network on the PASCAL VOC2012 training listing, saves it to
        /// "semantic_segmentation_voc2012net.dnn", and prints its accuracy on the training and
        /// validation listings.
        /// </summary>
        /// <param name="args">Exactly one argument: the path to the VOC2012 dataset.</param>
        /// <returns>0 on success; 1 on bad usage, an empty dataset listing, or any exception.</returns>
        private static int Main(string[] args)
        {
            try
            {
                if (args.Length != 1)
                {
                    Console.WriteLine("To run this program you need a copy of the PASCAL VOC2012 dataset.");
                    Console.WriteLine();
                    Console.WriteLine("You call this program like this: ");
                    Console.WriteLine("./dnn_semantic_segmentation_train_ex /path/to/VOC2012");
                    return 1;
                }

                Console.WriteLine("\nSCANNING PASCAL VOC2012 DATASET\n");

                var listing = PascalVOC2012.GetPascalVoc2012TrainListing(args[0]).ToArray();
                Console.WriteLine($"images in dataset: {listing.Length}");
                if (listing.Length == 0)
                {
                    Console.WriteLine("Didn't find the VOC2012 dataset.");
                    return 1;
                }

                const double initialLearningRate = 0.1;
                const double weightDecay         = 0.0001;
                const double momentum            = 0.9;
                // Training stops once the trainer's learning rate falls below this value.
                const double minLearningRate     = 1e-4;
                // Number of crops per training step.
                const int    miniBatchSize       = 30;
                // Number of background threads producing random crops. Raise this if the
                // trainer is starved for data.
                const int    dataLoaderThreads   = 1;

                using (var net = new LossMulticlassLogPerPixel(0))
                using (var sgd = new Sgd((float)weightDecay, (float)momentum))
                using (var trainer = new DnnTrainer<LossMulticlassLogPerPixel>(net, sgd))
                {
                    trainer.BeVerbose();
                    trainer.SetLearningRate(initialLearningRate);
                    // Sync trainer state to this file; the interval 10 * 60 is presumably in seconds
                    // (i.e. every 10 minutes) - confirm against the DnnTrainer API.
                    trainer.SetSynchronizationFile("pascal_voc2012_trainer_state_file.dat", 10 * 60);
                    // This threshold is probably excessively large.
                    trainer.SetIterationsWithoutProgressThreshold(5000);
                    // Since the progress threshold is so large, we might as well set the batch
                    // normalization stats window to something big too.
                    Dlib.SetAllBnRunningStatsWindowSizes(net, 1000);

                    // Output training parameters.
                    Console.WriteLine();
                    Console.WriteLine(trainer);

                    var samples = new List<Matrix<RgbPixel>>();
                    var labels  = new List<Matrix<ushort>>();

                    try
                    {
                        // Start loader threads that read images from disk and pull out random crops.
                        // It's important to feed the GPU fast enough to keep it busy, and doing this
                        // data preparation on separate threads helps. Each thread puts its crops into
                        // the data queue (capacity 200).
                        using (var data = new Pipe<TrainingSample>(200))
                        {
                            var function = new Action<object>(seed =>
                            {
                                using (var rnd = new Rand((ulong)seed))
                                {
                                    while (data.IsEnabled)
                                    {
                                        // Pick a random input image.
                                        var imageInfo = listing[rnd.GetRandom32BitNumber() % listing.Length];

                                        // NOTE(review): an exception thrown here (e.g. an unreadable image
                                        // file) is unhandled on this thread and will terminate the process.
                                        // Load the input image, then the ground-truth (RGB) label image.
                                        using (var inputImage = Dlib.LoadImageAsMatrix<RgbPixel>(imageInfo.ImageFilename))
                                        using (var rgbLabelImage = Dlib.LoadImageAsMatrix<RgbPixel>(imageInfo.ClassLabelFilename))
                                        using (var indexLabelImage = new Matrix<ushort>())
                                        {
                                            // Convert the RGB label colors to class indexes.
                                            PascalVOC2012.RgbLabelImageToIndexLabelImage(rgbLabelImage, indexLabelImage);

                                            // Randomly pick a part of the image.
                                            var temp = new TrainingSample();
                                            RandomlyCropImage(inputImage, indexLabelImage, temp, rnd);

                                            // Push the result to be used by the trainer.
                                            data.Enqueue(temp);
                                        }
                                    }
                                }
                            });

                            var threads = Enumerable.Range(1, dataLoaderThreads).Select(i =>
                            {
                                var dataLoader = new Thread(new ParameterizedThreadStart(function))
                                {
                                    Name = $"dataLoader{i}"
                                };
                                dataLoader.Start((ulong)i);
                                return dataLoader;
                            }).ToArray();

                            try
                            {
                                // The main training loop. Keep making mini-batches and giving them to
                                // the trainer until the learning rate has fallen below minLearningRate.
                                while (trainer.GetLearningRate() >= minLearningRate)
                                {
                                    // Release the previous mini-batch before building the next one.
                                    samples.DisposeElement();
                                    labels.DisposeElement();
                                    samples.Clear();
                                    labels.Clear();

                                    // Make a mini-batch of miniBatchSize crops.
                                    while (samples.Count < miniBatchSize)
                                    {
                                        data.Dequeue(out var temp);

                                        samples.Add(temp.InputImage);
                                        labels.Add(temp.LabelImage);

                                        temp.Dispose();
                                    }

                                    LossMulticlassLogPerPixel.TrainOneStep(trainer, samples, labels);
                                }
                            }
                            finally
                            {
                                // Training finished (or was aborted by an exception): tell the loader
                                // threads to stop and wait for them BEFORE the pipe is disposed, so no
                                // thread is left using a disposed pipe and no foreground thread keeps
                                // the process alive.
                                data.Disable();
                                foreach (var thread in threads)
                                {
                                    thread.Join();
                                }
                            }

                            // Also wait for threaded processing to stop in the trainer.
                            trainer.GetNet();

                            net.Clean();
                            Console.WriteLine("saving network");
                            LossMulticlassLogPerPixel.Serialize(net, "semantic_segmentation_voc2012net.dnn");
                        }
                    }
                    finally
                    {
                        // The loop only releases a mini-batch when it starts the next one, so the last
                        // batch would otherwise never be disposed.
                        samples.DisposeElement();
                        labels.DisposeElement();
                        samples.Clear();
                        labels.Clear();
                    }

                    // Make a copy of the network to use it for inference.
                    // NOTE(review): the meaning of the CloneAs argument (1) is not visible here - confirm.
                    using (var anet = net.CloneAs(1))
                    {
                        Console.WriteLine("Testing the network...");

                        // Find the accuracy of the newly trained network on both the training and the validation sets.
                        Console.WriteLine($"train accuracy  :  {CalculateAccuracy(anet, PascalVOC2012.GetPascalVoc2012TrainListing(args[0]))}");
                        Console.WriteLine($"val accuracy    :  {CalculateAccuracy(anet, PascalVOC2012.GetPascalVoc2012ValListing(args[0]))}");
                    }
                }
            }
            catch (Exception e)
            {
                Console.WriteLine(e);
                return 1;
            }

            return 0;
        }