/// <summary>
/// Loads the image data sets, trains a <c>NeuralNet</c> on them, prints the elapsed
/// training time and the test-set accuracy, and finally lists the misclassified test
/// samples with the highest loss (at most 100).
/// </summary>
public static void Run()
        {
            trainingData = ImageSample.LoadTrainingImages();   // 50,000 training images
            testingData  = ImageSample.LoadTestingImages();    // 10,000 testing images

            var net     = new NeuralNet(ImageWidthHeight * ImageWidthHeight, 20, 10);
            var trainer = new Trainer(net);

            // Time only the training call itself.
            var sw = Stopwatch.StartNew();
            trainer.Train(trainingData, testingData, learningRate: .01, epochs: 10);
            sw.Stop();

            var elapsedSeconds = sw.Elapsed.TotalSeconds;
            Console.WriteLine($"Training Elapsed: {elapsedSeconds} seconds = {elapsedSeconds / 60} minutes");

            // Materialize once: the results are counted and filtered below.
            var testInfos = GetImageTestInfo(new FiringNet(net), testingData).ToList();
            var failures  = testInfos
                .Where(testInfo => !testInfo.IsCorrect)
                .Select(testInfo => new { testInfo.ImageSample.Label, testInfo.TotalLoss, testInfo.OutputValues })
                .ToList();

            // Use floating-point arithmetic so the fractional part of the percentage is kept,
            // and guard against an empty test set (integer division by zero would throw).
            var accuracyPercent = testInfos.Count == 0
                ? 0.0
                : (testInfos.Count - failures.Count) * 100.0 / testInfos.Count;
            Console.WriteLine($"Test set accuracy: {accuracyPercent:F2}%");

            Console.WriteLine("Failures with highest loss");
            foreach (var failure in failures.OrderByDescending(item => item.TotalLoss).Take(100))
            {
                Console.WriteLine($"Label: {failure.Label}, Predicted: {HelperMethods.IndexOfMax(failure.OutputValues)}, TotalLoss: {failure.TotalLoss}");
            }
        }
Beispiel #2
0
        /// <summary>
        /// Loads the image data sets, trains a <c>NeuralNet</c> on a background task, then
        /// dumps the misclassified test samples with the highest loss (at most 100).
        /// </summary>
        async Task Main()
        {
            trainingData = ImageSample.LoadTrainingImages();               // 50,000 training images
            testingData  = ImageSample.LoadTestingImages();                // 10,000 testing images

            var network        = new NeuralNet(ImageWidthHeight * ImageWidthHeight, 50, 10);
            var networkTrainer = new Trainer(network);

            // Run the training call via Task.Run and wait for it to finish.
            await Task.Run(() => networkTrainer.Train(trainingData, testingData, learningRate: .01, epochs: 10));

            // Deferred query: it is evaluated once, when PrintDump enumerates it.
            GetImageTestInfo(new FiringNet(network), testingData)
                .Where(info => !info.IsCorrect)
                .Select(info => new { info.Image, info.ImageSample.Label, info.TotalLoss, info.OutputValues })
                .OrderByDescending(failure => failure.TotalLoss)
                .Take(100)
                .PrintDump();
        }