Esempio n. 1
0
        /// <summary>
        /// Serializes <paramref name="net"/> as the new "best" model file when
        /// <paramref name="accuracy"/> beats every previously saved best file for
        /// this <paramref name="postfix"/>, then deletes the superseded files.
        /// </summary>
        private static void UpdateBestModelFile(LossMulticlassLog net, double accuracy, string output, string basename, string postfix)
        {
            // Use one fixed culture for BOTH writing and parsing the accuracy that
            // is embedded in the file name. The previous code formatted with the
            // current culture but parsed with en-US, so on locales using a decimal
            // comma the old best files could never be recognized (and therefore
            // never cleaned up).
            var culture    = new CultureInfo("en-US");
            var candidates = new Dictionary<string, double>();

            foreach (var file in Directory.GetFiles(output, $"{basename}_{postfix}_best_*.tmp"))
            {
                var value = Path.GetFileNameWithoutExtension(file).Replace($"{basename}_{postfix}_best_", "");
                if (double.TryParse(value, NumberStyles.Float, culture, out var tmp))
                {
                    candidates.Add(file, tmp);
                }
            }

            // Save when there is no previous best file or the new accuracy beats all of them.
            var save = !candidates.Any() || candidates.All(pair => pair.Value < accuracy);
            if (!save)
            {
                return;
            }

            // "R" round-trips the double so the TryParse above recovers the exact value.
            var path = Path.Combine(output, $"{basename}_{postfix}_best_{accuracy.ToString("R", culture)}.tmp");
            LossMulticlassLog.Serialize(net, path);
            Logger.Info($"Best Accuracy Model file is saved for {postfix} [{accuracy}]");

            // Delete the files we have just superseded. Best-effort: a failed
            // delete only logs and does not abort the remaining deletions.
            foreach (var (key, _) in candidates)
            {
                try
                {
                    File.Delete(key);
                }
                catch (Exception e)
                {
                    Logger.Error($"Failed to delete '{key}'. Reason: {e.Message}");
                }
            }
        }
Esempio n. 2
0
        /// <summary>
        /// Loads the train and test datasets, deserializes the trained model given
        /// by <paramref name="parameter"/> and runs a full validation pass over
        /// both splits, printing the results to the console.
        /// </summary>
        /// <param name="parameter">Carries the dataset directory and model file path.</param>
        private void Test(Parameter parameter)
        {
            try
            {
                IList <Matrix <C> > trainingImages;
                IList <T>           trainingLabels;
                IList <Matrix <C> > testingImages;
                IList <T>           testingLabels;

                Logger.Info("Start load train images");
                Load(parameter.Dataset, "train", out trainingImages, out trainingLabels);
                Logger.Info($"Load train images: {trainingImages.Count}");

                Logger.Info("Start load test images");
                Load(parameter.Dataset, "test", out testingImages, out testingLabels);
                Logger.Info($"Load test images: {testingImages.Count}");
                Logger.Info("");

                // So with that out of the way, we can make a network instance.
                var networkId = SetupNetwork();

                using (var net = LossMulticlassLog.Deserialize(parameter.Model, networkId))
                {
                    // Switch the native network to evaluation mode before inference.
                    this.SetEvalMode(networkId, net);
                    var validationParameter = new ValidationParameter <T, C>
                    {
                        BaseName       = Path.GetFileNameWithoutExtension(parameter.Model),
                        Trainer        = net,
                        TrainingImages = trainingImages,
                        TrainingLabels = trainingLabels,
                        TestingImages  = testingImages,
                        TestingLabels  = testingLabels,
                        UseConsole     = true,
                        SaveToXml      = false,
                        OutputDiffLog  = true,
                        Output         = Path.GetDirectoryName(parameter.Model)
                    };

                    Validation(validationParameter, out _, out _);
                }
            }
            catch (Exception e)
            {
                // Log the whole exception (type and stack trace), not just the
                // message, so failures are diagnosable from the log alone.
                Logger.Error(e.ToString());
            }
        }
Esempio n. 3
0
        /// <summary>
        /// Sets the running-stats window size of every batch-normalization layer
        /// in <paramref name="net"/> through the native binding.
        /// </summary>
        /// <param name="net">The target network. Must not be null or disposed.</param>
        /// <param name="newWindowSize">The new window size to apply.</param>
        /// <exception cref="ArgumentNullException"><paramref name="net"/> is null.</exception>
        /// <exception cref="NotSupportNetworkTypeException">The native layer does not support this network type.</exception>
        public static void SetAllBnRunningStatsWindowSizes(LossMulticlassLog net, uint newWindowSize)
        {
            if (net is null)
            {
                throw new ArgumentNullException(nameof(net));
            }

            net.ThrowIfDisposed();

            var error = NativeMethods.set_all_bn_running_stats_window_sizes_loss_multiclass_log(net.NativePtr, net.NetworkType, newWindowSize);
            if (error == NativeMethods.ErrorType.DnnNotSupportNetworkType)
            {
                throw new NotSupportNetworkTypeException(net.NetworkType);
            }
        }
Esempio n. 4
0
        /// <summary>
        /// Initializes a new instance of the <see cref="SimpleAgeEstimator"/> class with the model file path that this estimator uses.
        /// </summary>
        /// <param name="modelPath">The model file path that this estimator uses.</param>
        /// <exception cref="FileNotFoundException">The model file is not found.</exception>
        public SimpleAgeEstimator(string modelPath)
        {
            if (!File.Exists(modelPath))
            {
                // Use the (message, fileName) overload: the single-string
                // constructor treats its argument as the Message and leaves
                // FileNotFoundException.FileName null.
                throw new FileNotFoundException("The model file is not found.", modelPath);
            }

            // Create the native network-type descriptor and register it once;
            // if an identical descriptor is already registered, release ours.
            var ret       = NativeMethods.LossMulticlassLog_age_train_type_create();
            var networkId = LossMulticlassLogRegistry.GetId(ret);

            if (LossMulticlassLogRegistry.Contains(networkId))
            {
                NativeMethods.LossMulticlassLog_age_train_type_delete(ret);
            }
            else
            {
                LossMulticlassLogRegistry.Add(ret);
            }

            this._Network = LossMulticlassLog.Deserialize(modelPath, networkId);
        }
Esempio n. 5
0
 // Puts the native emotion network into evaluation mode (via the native
 // *_eval binding) before inference is run.
 protected override void SetEvalMode(int networkId, LossMulticlassLog net)
 {
     NativeMethods.LossMulticlassLog_emotion_train_type_eval(networkId, net.NativePtr);
 }
Esempio n. 6
0
        /// <summary>
        /// Runs the network over the training and testing sets, reports the
        /// classification accuracy for each, and optionally dumps the network
        /// parameters to XML.
        /// </summary>
        /// <param name="baseName">Base path (without extension) used for the optional XML dump.</param>
        /// <param name="net">The trained network to evaluate.</param>
        /// <param name="useConsole">Write per-split right/wrong/accuracy lines to the console.</param>
        /// <param name="saveToXml">Save the network parameters to "{baseName}.xml".</param>
        /// <param name="trainAccuracy">Receives the training-set accuracy in [0, 1].</param>
        /// <param name="testAccuracy">Receives the testing-set accuracy in [0, 1].</param>
        private static void Validation(string baseName,
                                       LossMulticlassLog net,
                                       IList <Matrix <RgbPixel> > trainingImages,
                                       IList <uint> trainingLabels,
                                       IList <Matrix <RgbPixel> > testingImages,
                                       IList <uint> testingLabels,
                                       bool useConsole,
                                       bool saveToXml,
                                       out double trainAccuracy,
                                       out double testAccuracy)
        {
            // The train and test evaluations were previously duplicated inline;
            // both now go through the same helper.
            trainAccuracy = ComputeAccuracy(net, trainingImages, trainingLabels, useConsole, "training");
            testAccuracy  = ComputeAccuracy(net, testingImages, testingLabels, useConsole, "testing");

            // Finally, you can also save network parameters to XML files if you want to do
            // something with the network in another tool.  For example, you could use dlib's
            // tools/convert_dlib_nets_to_caffe to convert the network to a caffe model.
            if (saveToXml)
            {
                Dlib.NetToXml(net, $"{baseName}.xml");
            }
        }

        /// <summary>
        /// Classifies <paramref name="images"/> with <paramref name="net"/> and
        /// returns the fraction predicted equal to <paramref name="labels"/>.
        /// </summary>
        /// <param name="tag">Split name ("training"/"testing") used in console output.</param>
        private static double ComputeAccuracy(LossMulticlassLog net,
                                              IList <Matrix <RgbPixel> > images,
                                              IList <uint> labels,
                                              bool useConsole,
                                              string tag)
        {
            using (var predictedLabels = net.Operator(images))
            {
                var numRight = 0;
                var numWrong = 0;

                // And then let's see if it classified them correctly.
                for (var i = 0; i < images.Count; ++i)
                {
                    if (predictedLabels[i] == labels[i])
                    {
                        ++numRight;
                    }
                    else
                    {
                        ++numWrong;
                    }
                }

                if (useConsole)
                {
                    Console.WriteLine($"{tag} num_right: {numRight}");
                    Console.WriteLine($"{tag} num_wrong: {numWrong}");
                    Console.WriteLine($"{tag} accuracy:  {numRight / (double)(numRight + numWrong)}");
                }

                return numRight / (double)(numRight + numWrong);
            }
        }
Esempio n. 7
0
        /// <summary>
        /// Trains an age-classification network over the given dataset: loads the
        /// train/test splits, builds shuffled mini-batches per epoch, steps the
        /// trainer, periodically validates, and finally serializes the cleaned
        /// network to "{baseName}.dat".
        /// </summary>
        /// <param name="baseName">Base path for the sync file, log file, XML and .dat outputs.</param>
        /// <param name="dataset">Dataset directory passed to Load.</param>
        /// <param name="epoch">Maximum number of epochs to run.</param>
        /// <param name="validation">Validate every this many epochs (e &gt; 0 and e % validation == 0).</param>
        /// <param name="useMean">When true, loads "train.mean.bmp" and passes it to Load.</param>
        private static void Train(string baseName, string dataset, uint epoch, double learningRate, double minLearningRate, uint miniBatchSize, uint validation, bool useMean)
        {
            try
            {
                IList <Matrix <RgbPixel> > trainingImages;
                IList <uint> trainingLabels;
                IList <Matrix <RgbPixel> > testingImages;
                IList <uint> testingLabels;

                var mean = useMean ? Path.Combine(dataset, "train.mean.bmp") : null;

                Console.WriteLine("Start load train images");
                Load("train", dataset, mean, out trainingImages, out trainingLabels);
                Console.WriteLine($"Load train images: {trainingImages.Count}");

                Console.WriteLine("Start load test images");
                Load("test", dataset, mean, out testingImages, out testingLabels);
                Console.WriteLine($"Load test images: {testingImages.Count}");

                // So with that out of the way, we can make a network instance.
                var trainNet  = NativeMethods.LossMulticlassLog_age_train_type_create();
                var networkId = LossMulticlassLogRegistry.GetId(trainNet);
                LossMulticlassLogRegistry.Add(trainNet);

                using (var net = new LossMulticlassLog(networkId))
                    using (var trainer = new DnnTrainer <LossMulticlassLog>(net))
                    {
                        trainer.SetLearningRate(learningRate);
                        trainer.SetMinLearningRate(minLearningRate);
                        trainer.SetMiniBatchSize(miniBatchSize);
                        trainer.BeVerbose();
                        // Checkpoint training state to the sync file every 180 seconds.
                        trainer.SetSynchronizationFile(baseName, 180);

                        // create array box
                        // Pre-allocate the per-iteration batch arrays once; they are
                        // refilled (not reallocated) at the start of every epoch.
                        var trainingImagesCount = trainingImages.Count;
                        var trainingLabelsCount = trainingLabels.Count;

                        var maxIteration = (int)Math.Ceiling(trainingImagesCount / (float)miniBatchSize);
                        var imageBatches = new Matrix <RgbPixel> [maxIteration][];
                        var labelBatches = new uint[maxIteration][];
                        for (var i = 0; i < maxIteration; i++)
                        {
                            if (miniBatchSize <= trainingImagesCount - i * miniBatchSize)
                            {
                                imageBatches[i] = new Matrix <RgbPixel> [miniBatchSize];
                                labelBatches[i] = new uint[miniBatchSize];
                            }
                            else
                            {
                                // Last (partial) batch holds the remainder.
                                // NOTE(review): assumes trainingLabelsCount == trainingImagesCount —
                                // verify the loader guarantees this.
                                imageBatches[i] = new Matrix <RgbPixel> [trainingImagesCount % miniBatchSize];
                                labelBatches[i] = new uint[trainingLabelsCount % miniBatchSize];
                            }
                        }

                        using (var fs = new FileStream($"{baseName}.log", FileMode.Create, FileAccess.Write, FileShare.Write))
                            using (var sw = new StreamWriter(fs, Encoding.UTF8))
                                for (var e = 0; e < epoch; e++)
                                {
                                    // Shuffle the sample order by sorting indices on random GUIDs.
                                    var randomArray = Enumerable.Range(0, trainingImagesCount).OrderBy(i => Guid.NewGuid()).ToArray();
                                    var index       = 0;
                                    for (var i = 0; i < imageBatches.Length; i++)
                                    {
                                        var currentImages = imageBatches[i];
                                        var currentLabels = labelBatches[i];
                                        for (var j = 0; j < imageBatches[i].Length; j++)
                                        {
                                            var rIndex = randomArray[index];
                                            currentImages[j] = trainingImages[rIndex];
                                            currentLabels[j] = trainingLabels[rIndex];
                                            index++;
                                        }
                                    }

                                    for (var i = 0; i < maxIteration; i++)
                                    {
                                        LossMulticlassLog.TrainOneStep(trainer, imageBatches[i], labelBatches[i]);
                                    }

                                    var lr   = trainer.GetLearningRate();
                                    var loss = trainer.GetAverageLoss();

                                    var trainLog = $"Epoch: {e}, learning Rate: {lr}, average loss: {loss}";
                                    Console.WriteLine(trainLog);
                                    sw.WriteLine(trainLog);

                                    // Periodic validation (quiet, no XML dump).
                                    if (e > 0 && e % validation == 0)
                                    {
                                        Validation(baseName, net, trainingImages, trainingLabels, testingImages, testingLabels, false, false, out var trainAccuracy, out var testAccuracy);

                                        var validationLog = $"Epoch: {e}, train accuracy: {trainAccuracy}, test accuracy: {testAccuracy}";
                                        Console.WriteLine(validationLog);
                                        sw.WriteLine(validationLog);
                                    }

                                    // Stop early once the trainer has decayed below the floor.
                                    if (lr < minLearningRate)
                                    {
                                        break;
                                    }
                                }

                        // wait for training threads to stop
                        trainer.GetNet();
                        Console.WriteLine("done training");

                        net.Clean();
                        LossMulticlassLog.Serialize(net, $"{baseName}.dat");

                        // Now let's run the training images through the network.  This statement runs all the
                        // images through it and asks the loss layer to convert the network's raw output into
                        // labels.  In our case, these labels are the numbers between 0 and 9.
                        Validation(baseName, net, trainingImages, trainingLabels, testingImages, testingLabels, true, true, out _, out _);
                    }
            }
            catch (Exception e)
            {
                Console.WriteLine(e);
            }
        }
Esempio n. 8
0
        /// <summary>
        /// Builds the command-line application with its "clean", "train", "test",
        /// "eval" and "demo" sub-commands, then executes it against
        /// <paramref name="args"/>.
        /// </summary>
        /// <param name="args">The raw command-line arguments.</param>
        /// <returns>The exit code from the executed sub-command (0 on success, -1 on bad input).</returns>
        public int Start(string[] args)
        {
            var app = new CommandLineApplication(false);

            app.Name        = this._Name;
            app.Description = this._Description;
            app.HelpOption("-h|--help");

            app.Command("clean", command =>
            {
                var outputOption = command.Option("-o|--output", "The output directory path.", CommandOptionType.SingleValue);

                command.OnExecute(() =>
                {
                    if (!outputOption.HasValue() || !Directory.Exists(outputOption.Value()))
                    {
                        // Fixed: the quote around the path was previously unbalanced.
                        Logger.Error($"'{outputOption.Value()}' is missing or output option is not specified");
                        return(-1);
                    }

                    Logger.Info($"             Output: {outputOption.Value()}");
                    Logger.Info("");

                    Clean(outputOption.Value());

                    return(0);
                });
            });

            app.Command("train", command =>
            {
                const uint epochDefault             = 300;
                const double learningRateDefault    = 0.001d;
                const double minLearningRateDefault = 0.00001d;
                const uint minBatchSizeDefault      = 256;
                const uint validationDefault        = 30;

                var datasetOption         = command.Option("-d|--dataset", "The directory of dataset", CommandOptionType.SingleValue);
                var epochOption           = command.Option("-e|--epoch", $"The epoch. Default is {epochDefault}", CommandOptionType.SingleValue);
                var learningRateOption    = command.Option("-l|--lr", $"The learning rate. Default is {learningRateDefault}", CommandOptionType.SingleValue);
                var minLearningRateOption = command.Option("-m|--min-lr", $"The minimum learning rate. Default is {minLearningRateDefault}", CommandOptionType.SingleValue);
                var minBatchSizeOption    = command.Option("-b|--min-batchsize", $"The minimum batch size. Default is {minBatchSizeDefault}", CommandOptionType.SingleValue);
                var validationOption      = command.Option("-v|--validation-interval", $"The interval of validation. Default is {validationDefault}", CommandOptionType.SingleValue);
                var useMeanOption         = command.Option("-u|--use-mean", "Use mean image", CommandOptionType.NoValue);
                var outputOption          = command.Option("-o|--output", "The output directory path.", CommandOptionType.SingleValue);

                command.OnExecute(() =>
                {
                    // Each option falls back to its default; an unparsable value aborts with -1.
                    var dataset = datasetOption.Value();
                    if (!datasetOption.HasValue() || !Directory.Exists(dataset))
                    {
                        Logger.Error("dataset does not exist");
                        return(-1);
                    }

                    var epoch = epochDefault;
                    if (epochOption.HasValue() && !uint.TryParse(epochOption.Value(), out epoch))
                    {
                        Logger.Error("epoch is invalid value");
                        return(-1);
                    }

                    var learningRate = learningRateDefault;
                    if (learningRateOption.HasValue() && !double.TryParse(learningRateOption.Value(), NumberStyles.Float, Thread.CurrentThread.CurrentCulture.NumberFormat, out learningRate))
                    {
                        Logger.Error("learning rate is invalid value");
                        return(-1);
                    }

                    var minLearningRate = minLearningRateDefault;
                    if (minLearningRateOption.HasValue() && !double.TryParse(minLearningRateOption.Value(), NumberStyles.Float, Thread.CurrentThread.CurrentCulture.NumberFormat, out minLearningRate))
                    {
                        Logger.Error("minimum learning rate is invalid value");
                        return(-1);
                    }

                    var minBatchSize = minBatchSizeDefault;
                    if (minBatchSizeOption.HasValue() && !uint.TryParse(minBatchSizeOption.Value(), out minBatchSize))
                    {
                        Logger.Error("minimum batch size is invalid value");
                        return(-1);
                    }

                    // Note: a failed TryParse leaves validation == 0, which the
                    // second clause also rejects.
                    var validation = validationDefault;
                    if (validationOption.HasValue() && !uint.TryParse(validationOption.Value(), out validation) || validation == 0)
                    {
                        Logger.Error("validation interval is invalid value");
                        return(-1);
                    }

                    var output = "result";
                    if (outputOption.HasValue())
                    {
                        output = outputOption.Value();
                    }

                    Directory.CreateDirectory(output);

                    var useMean = useMeanOption.HasValue();

                    Logger.Info($"            Dataset: {dataset}");
                    Logger.Info($"              Epoch: {epoch}");
                    Logger.Info($"      Learning Rate: {learningRate}");
                    Logger.Info($"  Min Learning Rate: {minLearningRate}");
                    Logger.Info($"     Min Batch Size: {minBatchSize}");
                    Logger.Info($"Validation Interval: {validation}");
                    Logger.Info($"           Use Mean: {useMean}");
                    Logger.Info($"             Output: {output}");
                    Logger.Info("");

                    var name      = this.GetBaseName(epoch, learningRate, minLearningRate, minBatchSize);
                    var baseName  = Path.Combine(output, name);
                    var parameter = new Parameter
                    {
                        BaseName        = baseName,
                        Dataset         = dataset,
                        Output          = output,
                        Epoch           = epoch,
                        LearningRate    = learningRate,
                        MinLearningRate = minLearningRate,
                        MiniBatchSize   = minBatchSize,
                        Validation      = validation
                    };

                    Train(parameter);

                    return(0);
                });
            });

            app.Command("test", command =>
            {
                var datasetOption = command.Option("-d|--dataset", "The directory of dataset", CommandOptionType.SingleValue);
                var modelOption   = command.Option("-m|--model", "The model file path", CommandOptionType.SingleValue);

                command.OnExecute(() =>
                {
                    var dataset = datasetOption.Value();
                    if (!datasetOption.HasValue() || !Directory.Exists(dataset))
                    {
                        Logger.Error("dataset does not exist");
                        return(-1);
                    }

                    var model = modelOption.Value();
                    if (!modelOption.HasValue() || !File.Exists(model))
                    {
                        Logger.Error("model does not exist");
                        return(-1);
                    }

                    Logger.Info($"Dataset: {dataset}");
                    Logger.Info($"  Model: {model}");
                    Logger.Info("");

                    var parameter = new Parameter
                    {
                        Dataset = dataset,
                        Model   = model
                    };

                    Test(parameter);

                    return(0);
                });
            });

            app.Command("eval", command =>
            {
                var imageOption = command.Option("-i|--image", "The image file.", CommandOptionType.SingleValue);
                var modelOption = command.Option("-m|--model", "The model file path", CommandOptionType.SingleValue);

                command.OnExecute(() =>
                {
                    var image = imageOption.Value();
                    if (!imageOption.HasValue() || !File.Exists(image))
                    {
                        Logger.Error("image does not exist");
                        return(-1);
                    }

                    var model = modelOption.Value();
                    if (!modelOption.HasValue() || !File.Exists(model))
                    {
                        Logger.Error("model file does not exist");
                        return(-1);
                    }

                    Logger.Info($"Image File: {image}");
                    Logger.Info($"     Model: {model}");
                    Logger.Info("");

                    var networkId = SetupNetwork();

                    // Detect the first face, crop it to a square via the four corner
                    // points, then classify the crop with the deserialized network.
                    using (var net = LossMulticlassLog.Deserialize(model, networkId))
                        using (var fr = FaceRecognition.Create("Models"))
                            using (var img = FaceRecognition.LoadImageFile(image))
                            {
                                var location = fr.FaceLocations(img).FirstOrDefault();
                                if (location == null)
                                {
                                    Logger.Info("Missing face");
                                    return(-1);
                                }

                                var rect   = new Rectangle(location.Left, location.Top, location.Right, location.Bottom);
                                var dPoint = new[]
                                {
                                    new DPoint(rect.Left, rect.Top),
                                    new DPoint(rect.Right, rect.Top),
                                    new DPoint(rect.Left, rect.Bottom),
                                    new DPoint(rect.Right, rect.Bottom),
                                };
                                using (var tmp = Dlib.LoadImageAsMatrix <byte>(image))
                                {
                                    using (var face = Dlib.ExtractImage4Points(tmp, dPoint, this.Size, this.Size))
                                    {
                                        this.SetEvalMode(networkId, net);
                                        using (var predictedLabels = net.Operator(face))
                                            Logger.Info($"{this.Cast(predictedLabels[0])}");
                                    }
                                }
                            }

                    return(0);
                });
            });

            app.Command("demo", command =>
            {
                command.HelpOption("-?|-h|--help");
                var imageOption     = command.Option("-i|--image", "test image file", CommandOptionType.SingleValue);
                var modelOption     = command.Option("-m|--model", "model file", CommandOptionType.SingleValue);
                var directoryOption = command.Option("-d|--directory", "model files directory path", CommandOptionType.SingleValue);

                command.OnExecute(() =>
                {
                    if (!imageOption.HasValue())
                    {
                        Console.WriteLine("image option is missing");
                        app.ShowHelp();
                        return(-1);
                    }

                    if (!directoryOption.HasValue())
                    {
                        Console.WriteLine("directory option is missing");
                        app.ShowHelp();
                        return(-1);
                    }

                    if (!modelOption.HasValue())
                    {
                        Console.WriteLine("model option is missing");
                        app.ShowHelp();
                        return(-1);
                    }

                    var modelFile = modelOption.Value();
                    if (!File.Exists(modelFile))
                    {
                        Console.WriteLine($"'{modelFile}' is not found");
                        app.ShowHelp();
                        return(-1);
                    }

                    var imageFile = imageOption.Value();
                    if (!File.Exists(imageFile))
                    {
                        Console.WriteLine($"'{imageFile}' is not found");
                        app.ShowHelp();
                        return(-1);
                    }

                    var directory = directoryOption.Value();
                    if (!Directory.Exists(directory))
                    {
                        Console.WriteLine($"'{directory}' is not found");
                        app.ShowHelp();
                        return(-1);
                    }

                    using (var fr = FaceRecognition.Create(directory))
                        using (var image = FaceRecognition.LoadImageFile(imageFile))
                        {
                            var loc = fr.FaceLocations(image).FirstOrDefault();
                            if (loc == null)
                            {
                                Console.WriteLine("No face is detected");
                                return(0);
                            }

                            this.Demo(fr, modelFile, imageFile, image, loc);
                        }

                    return(0);
                });
            });

            return(app.Execute(args));
        }
Esempio n. 9
0
 // Default implementation: no-op. Derived classes override this to switch
 // their native network into evaluation mode before inference.
 protected virtual void SetEvalMode(int networkId, LossMulticlassLog net)
 {
 }
Esempio n. 10
0
        private void Train(Parameter parameter)
        {
            try
            {
                IList <Matrix <C> > trainingImages;
                IList <T>           trainingLabels;
                IList <Matrix <C> > testingImages;
                IList <T>           testingLabels;

                Logger.Info("Start load train images");
                Load(parameter.Dataset, "train", out trainingImages, out trainingLabels);
                Logger.Info($"Load train images: {trainingImages.Count}");

                Logger.Info("Start load test images");
                Load(parameter.Dataset, "test", out testingImages, out testingLabels);
                Logger.Info($"Load test images: {testingImages.Count}");
                Logger.Info("");

                // So with that out of the way, we can make a network instance.
                var networkId = SetupNetwork();

                using (var net = new LossMulticlassLog(networkId))
                    using (var solver = new Adam())
                        using (var trainer = new DnnTrainer <LossMulticlassLog>(net, solver))
                        {
                            var learningRate    = parameter.LearningRate;
                            var minLearningRate = parameter.MinLearningRate;
                            var miniBatchSize   = parameter.MiniBatchSize;
                            var baseName        = parameter.BaseName;
                            var epoch           = parameter.Epoch;
                            var validation      = parameter.Validation;

                            trainer.SetLearningRate(learningRate);
                            trainer.SetMinLearningRate(minLearningRate);
                            trainer.SetMiniBatchSize(miniBatchSize);
                            trainer.BeVerbose();
                            trainer.SetSynchronizationFile(baseName, 180);

                            // create array box
                            var trainingImagesCount = trainingImages.Count;
                            var trainingLabelsCount = trainingLabels.Count;

                            var maxIteration = (int)Math.Ceiling(trainingImagesCount / (float)miniBatchSize);
                            var imageBatches = new Matrix <C> [maxIteration][];
                            var labelBatches = new uint[maxIteration][];
                            for (var i = 0; i < maxIteration; i++)
                            {
                                if (miniBatchSize <= trainingImagesCount - i * miniBatchSize)
                                {
                                    imageBatches[i] = new Matrix <C> [miniBatchSize];
                                    labelBatches[i] = new uint[miniBatchSize];
                                }
                                else
                                {
                                    imageBatches[i] = new Matrix <C> [trainingImagesCount % miniBatchSize];
                                    labelBatches[i] = new uint[trainingLabelsCount % miniBatchSize];
                                }
                            }

                            using (var fs = new FileStream($"{baseName}.log", FileMode.Create, FileAccess.Write, FileShare.Write))
                                using (var sw = new StreamWriter(fs, Encoding.UTF8))
                                    for (var e = 0; e < epoch; e++)
                                    {
                                        var randomArray = Enumerable.Range(0, trainingImagesCount).OrderBy(i => Guid.NewGuid()).ToArray();
                                        var index       = 0;
                                        for (var i = 0; i < imageBatches.Length; i++)
                                        {
                                            var currentImages = imageBatches[i];
                                            var currentLabels = labelBatches[i];
                                            for (var j = 0; j < imageBatches[i].Length; j++)
                                            {
                                                var rIndex = randomArray[index];
                                                currentImages[j] = trainingImages[rIndex];
                                                currentLabels[j] = this.Cast(trainingLabels[rIndex]);
                                                index++;
                                            }
                                        }

                                        for (var i = 0; i < maxIteration; i++)
                                        {
                                            LossMulticlassLog.TrainOneStep(trainer, imageBatches[i], labelBatches[i]);
                                        }

                                        var lr   = trainer.GetLearningRate();
                                        var loss = trainer.GetAverageLoss();

                                        var trainLog = $"Epoch: {e}, learning Rate: {lr}, average loss: {loss}";
                                        Logger.Info(trainLog);
                                        sw.WriteLine(trainLog);

                                        if (e >= 0 && e % validation == 0)
                                        {
                                            var validationParameter = new ValidationParameter <T, C>
                                            {
                                                BaseName       = parameter.BaseName,
                                                Output         = parameter.Output,
                                                Trainer        = net,
                                                TrainingImages = trainingImages,
                                                TrainingLabels = trainingLabels,
                                                TestingImages  = testingImages,
                                                TestingLabels  = testingLabels,
                                                UseConsole     = true,
                                                SaveToXml      = true,
                                                OutputDiffLog  = true
                                            };

                                            Validation(validationParameter, out var trainAccuracy, out var testAccuracy);

                                            var validationLog = $"Epoch: {e}, train accuracy: {trainAccuracy}, test accuracy: {testAccuracy}";
                                            Logger.Info(validationLog);
                                            sw.WriteLine(validationLog);

                                            var name = this.GetBaseName(parameter.Epoch,
                                                                        parameter.LearningRate,
                                                                        parameter.MinLearningRate,
                                                                        parameter.MiniBatchSize);

                                            UpdateBestModelFile(net, testAccuracy, parameter.Output, name, "test");
                                            UpdateBestModelFile(net, trainAccuracy, parameter.Output, name, "train");
                                        }

                                        if (lr < minLearningRate)
                                        {
                                            Logger.Info($"Stop training: {lr} < {minLearningRate}");
                                            break;
                                        }
                                    }

                            // wait for training threads to stop
                            trainer.GetNet();
                            Logger.Info("done training");

                            net.Clean();
                            LossMulticlassLog.Serialize(net, $"{baseName}.tmp");

                            // Now let's run the training images through the network.  This statement runs all the
                            // images through it and asks the loss layer to convert the network's raw output into
                            // labels.  In our case, these labels are the numbers between 0 and 9.
                            var validationParameter2 = new ValidationParameter <T, C>
                            {
                                BaseName       = parameter.BaseName,
                                Output         = parameter.Output,
                                Trainer        = net,
                                TrainingImages = trainingImages,
                                TrainingLabels = trainingLabels,
                                TestingImages  = testingImages,
                                TestingLabels  = testingLabels,
                                UseConsole     = true,
                                SaveToXml      = true,
                                OutputDiffLog  = true
                            };

                            Validation(validationParameter2, out _, out _);

                            // clean up tmp files
                            Clean(parameter.Output);
                        }
            }
            catch (Exception e)
            {
                Logger.Error(e.Message);
            }
        }
Esempio n. 11
0
        /// <summary>
        /// Entry point for the dlib LeNet MNIST example: loads the MNIST dataset from the
        /// folder given as the single command-line argument, trains a LeNet-style network
        /// with mini-batch SGD, saves the cleaned model, and reports train/test accuracy.
        /// </summary>
        /// <param name="args">Expects exactly one element: the MNIST dataset folder.</param>
        /// <returns>0 on success (or handled failure), 1 when the argument is missing.</returns>
        private static int Main(string[] args)
        {
            if (args.Length != 1)
            {
                Console.WriteLine("This example needs the MNIST dataset to run!");
                Console.WriteLine("You can get MNIST from http://yann.lecun.com/exdb/mnist/");
                Console.WriteLine("Download the 4 files that comprise the dataset, decompress them, and");
                Console.WriteLine("put them in a folder.  Then give that folder as input to this program.");
                return(1);
            }

            try
            {
                // MNIST is broken into two parts, a training set of 60000 images and a test set of
                // 10000 images.  Each image is labeled so that we know what hand written digit is
                // depicted.  These next statements load the dataset into memory.
                IList <Matrix <byte> > trainingImages;
                IList <uint>           trainingLabels;
                IList <Matrix <byte> > testingImages;
                IList <uint>           testingLabels;
                Dlib.LoadMNISTDataset(args[0], out trainingImages, out trainingLabels, out testingImages, out testingLabels);

                // A network definition has 3 parts: the loss layer, a stack of computational
                // layers, and an input layer.  Here loss_multiclass_log means multiclass
                // classification where the number of network outputs (10) is the number of
                // possible labels, and the largest output is the predicted digit.
                // The constant passed to the constructor selects the predefined network
                // architecture registered under that id in this wrapper.
                using (var net = new LossMulticlassLog(3))
                {
                    // Train it on the MNIST data with mini-batch stochastic gradient descent,
                    // starting from a learning rate of 0.01.
                    using (var trainer = new DnnTrainer <LossMulticlassLog>(net))
                    {
                        trainer.SetLearningRate(0.01);
                        trainer.SetMinLearningRate(0.00001);
                        trainer.SetMiniBatchSize(128);
                        trainer.BeVerbose();
                        // DNN training can take a long time, so persist trainer state to
                        // "mnist_sync" every 20 seconds.  On restart, set_synchronization_file
                        // reloads the saved state and training resumes where it left off.
                        trainer.SetSynchronizationFile("mnist_sync", 20);
                        // Runs SGD at the current learning rate until the loss stops decreasing,
                        // then divides the rate by 10 and repeats, stopping once the rate drops
                        // below the minimum set above (or the max epoch count is reached).
                        LossMulticlassLog.Train(trainer, trainingImages, trainingLabels);

                        // The trainer leaves transient per-batch state (e.g. layer outputs) in
                        // the net; "cleaning" drops it so the serialized file is smaller.
                        net.Clean();
                        LossMulticlassLog.Serialize(net, "mnist_network.dat");
                        // To recall the network from disk later:
                        // deserialize("mnist_network.dat") >> net;

                        // Classify both splits and print num_right / num_wrong / accuracy.
                        // Since MNIST is an easy dataset, we should see at least 99% test accuracy.
                        EvaluateAccuracy(net, trainingImages, trainingLabels, "training");
                        EvaluateAccuracy(net, testingImages, testingLabels, "testing");

                        // Network parameters can also be exported to XML, e.g. for conversion
                        // with dlib's tools/convert_dlib_nets_to_caffe.
                        Dlib.NetToXml(net, "lenet.xml");
                    }
                }
            }
            catch (Exception e)
            {
                Console.WriteLine(e.Message);
            }

            return(0);
        }

        /// <summary>
        /// Runs <paramref name="images"/> through <paramref name="net"/>, compares the
        /// predicted labels against <paramref name="labels"/>, and prints the right/wrong
        /// counts and accuracy prefixed with <paramref name="name"/> ("training"/"testing").
        /// </summary>
        private static void EvaluateAccuracy(LossMulticlassLog net, IList <Matrix <byte> > images, IList <uint> labels, string name)
        {
            // The loss layer converts the raw network outputs into labels (digits 0-9).
            using (var predictedLabels = net.Operator(images))
            {
                var numRight = 0;
                var numWrong = 0;
                for (var i = 0; i < images.Count; ++i)
                {
                    if (predictedLabels[i] == labels[i])
                    {
                        ++numRight;
                    }
                    else
                    {
                        ++numWrong;
                    }
                }

                Console.WriteLine($"{name} num_right: {numRight}");
                Console.WriteLine($"{name} num_wrong: {numWrong}");
                Console.WriteLine($"{name} accuracy:  {numRight / (double)(numRight + numWrong)}");
            }
        }
Esempio n. 12
0
        /// <summary>
        /// Entry point for the dlib inception MNIST example: loads the MNIST dataset from
        /// the folder given as the single command-line argument, trains the predefined
        /// inception network, saves the cleaned model, and reports train/test accuracy.
        /// </summary>
        /// <param name="args">Expects exactly one element: the MNIST dataset folder.</param>
        private static void Main(string[] args)
        {
            try
            {
                // This example is going to run on the MNIST dataset.
                if (args.Length != 1)
                {
                    Console.WriteLine("This example needs the MNIST dataset to run!");
                    Console.WriteLine("You can get MNIST from http://yann.lecun.com/exdb/mnist/");
                    Console.WriteLine("Download the 4 files that comprise the dataset, decompress them, and");
                    Console.WriteLine("put them in a folder.  Then give that folder as input to this program.");
                    return;
                }

                Dlib.LoadMNISTDataset(args[0],
                                      out var trainingImages,
                                      out var trainingLabels,
                                      out var testingImages,
                                      out var testingLabels);


                // Make an instance of our inception network (the parameterless constructor
                // selects the predefined inception architecture in this wrapper).
                using (var net = new LossMulticlassLog())
                {
                    Console.WriteLine($"The net has {net.NumLayers} layers in it.");
                    Console.WriteLine(net);

                    Console.WriteLine("Traning NN...");
                    using (var trainer = new DnnTrainer <LossMulticlassLog>(net))
                    {
                        trainer.SetLearningRate(0.01);
                        trainer.SetMinLearningRate(0.00001);
                        // Fixed: was SetMinBatchSize, which does not match the DnnTrainer API
                        // used by the other examples in this file (SetMiniBatchSize).
                        trainer.SetMiniBatchSize(128);
                        trainer.BeVerbose();
                        // Persist trainer state to "inception_sync" every 20 seconds so an
                        // interrupted run resumes instead of restarting from scratch.
                        trainer.SetSynchronizationFile("inception_sync", 20);
                        // Train the network.  This might take a few minutes...
                        LossMulticlassLog.Train(trainer, trainingImages, trainingLabels);

                        // The trainer leaves transient per-batch state (e.g. layer outputs) in
                        // the net; "cleaning" drops it so the serialized file is smaller.
                        net.Clean();
                        LossMulticlassLog.Serialize(net, "mnist_network_inception.dat");
                        // To recall the network from disk later:
                        // deserialize("mnist_network_inception.dat") >> net;


                        // Run the training images through the network; the loss layer converts
                        // the raw outputs into labels (digits 0-9).
                        // NOTE(review): this example indexes the dataset with .Length while the
                        // sibling example uses IList<T>.Count — confirm which overload of
                        // LoadMNISTDataset this was written against.
                        using (var predictedLabels = net.Operator(trainingImages))
                        {
                            var numRight = 0;
                            var numWrong = 0;
                            // And then let's see if it classified them correctly.
                            for (var i = 0; i < trainingImages.Length; ++i)
                            {
                                if (predictedLabels[i] == trainingLabels[i])
                                {
                                    ++numRight;
                                }
                                else
                                {
                                    ++numWrong;
                                }
                            }

                            Console.WriteLine($"training num_right: {numRight}");
                            Console.WriteLine($"training num_wrong: {numWrong}");
                            Console.WriteLine($"training accuracy:  {numRight / (double)(numRight + numWrong)}");

                            // Let's also see if the network can correctly classify the testing images.
                            // Since MNIST is an easy dataset, we should see 99% accuracy.
                            using (var predictedLabels2 = net.Operator(testingImages))
                            {
                                numRight = 0;
                                numWrong = 0;
                                for (var i = 0; i < testingImages.Length; ++i)
                                {
                                    if (predictedLabels2[i] == testingLabels[i])
                                    {
                                        ++numRight;
                                    }
                                    else
                                    {
                                        ++numWrong;
                                    }
                                }

                                Console.WriteLine($"testing num_right: {numRight}");
                                Console.WriteLine($"testing num_wrong: {numWrong}");
                                Console.WriteLine($"testing accuracy:  {numRight / (double)(numRight + numWrong)}");
                            }
                        }
                    }
                }
            }
            catch (Exception e)
            {
                Console.WriteLine(e);
            }
        }