示例#1
0
        /// <summary>
        /// Loads the ten digit glyphs from <c>Test_SimpleNumbers.file</c> and appends them
        /// to <c>Test_SimpleNumbers.testing</c>.
        /// </summary>
        /// <remarks>
        /// The image is read as ten 3x5-pixel glyphs placed side by side. Glyph <c>n</c>
        /// occupies columns <c>3n..3n+2</c> and rows <c>0..4</c>, and is labelled <c>n</c>.
        /// Only the red channel of each pixel is kept, stored row-major in a 15-byte array.
        /// If the file does not exist, "Missing file." is printed, the method waits for a
        /// key press, and it returns without adding anything.
        /// </remarks>
        private static void Load()
        {
            // Load data
            Console.WriteLine("Loading the dataset...");

            // new Bitmap(path) throws when the file is missing. That made the old
            // "Missing file." check (placed after loading) unreachable, so check up front.
            if (!System.IO.File.Exists(Test_SimpleNumbers.file))
            {
                Console.WriteLine("Missing file.");
                Console.ReadKey();
                return;
            }

            // Bitmap is IDisposable (it wraps a native GDI+ image). Release it once the
            // pixel values have been copied out.
            using (Bitmap bmp = new Bitmap(Test_SimpleNumbers.file))
            {
                for (int numberIdx = 0; numberIdx < 10; numberIdx++)
                {
                    byte[] data = new byte[3 * 5];
                    for (int y = 0; y < 5; y++)
                    {
                        for (int x = 0; x < 3; x++)
                        {
                            // Glyph numberIdx starts at column numberIdx * 3.
                            Color color = bmp.GetPixel(numberIdx * 3 + x, y);
                            data[x + y * 3] = color.R;
                        }
                    }

                    MnistEntry entry = new MnistEntry(data, numberIdx);
                    Test_SimpleNumbers.testing.Add(entry);
                }
            }
        }
示例#2
0
        /// <summary>
        /// Evaluates <paramref name="nn"/> on the leading entries of <c>testing</c>.
        /// It prints the expected and predicted label for each entry, then the overall score.
        /// </summary>
        /// <param name="nn">Network whose <c>FeedForward</c> output is scored.</param>
        /// <param name="tests">
        /// Number of entries to evaluate. -1 (the default) means every entry, and larger
        /// values are clamped to the size of <c>testing</c>.
        /// </param>
        private static void Test(NeuralNetwork nn, int tests = -1)
        {
            int available = testing.Count;
            int requested = tests == -1 ? available : tests;
            int total     = Math.Min(available, requested);

            int correct = 0;

            for (int idx = 0; idx < total; idx++)
            {
                Console.WriteLine($"Starting Test {idx}/{total}");

                MnistEntry sample = testing[idx];

                // Each sample is fed as a 3x5 input; the output vector is decoded back to a label.
                NNFeedData prediction = nn.FeedForward(new NNFeedData(3, 5, ConvertArray(sample.Image)));
                (int label, float value) best = ArrayToLabel(prediction.CopyData());

                correct += best.label == sample.Label ? 1 : 0;

                Console.WriteLine($"{sample.Label} | {best.label} ({best.value:F2})");
            }

            Console.WriteLine($"{correct}/{total}");
        }
示例#3
0
        /// <summary>
        /// Trains <paramref name="nn"/> by backpropagation over every entry of <c>training</c>.
        /// It prints a progress line and saves a snapshot after each batch.
        /// </summary>
        /// <param name="nn">Network trained in place.</param>
        /// <param name="learningRate">Learning rate handed to <c>NNBackpropagationData</c>.</param>
        /// <param name="runs">Number of times <c>nn.PropagateBackward</c> is called over the data.</param>
        /// <remarks>
        /// Snapshots are saved as <c>primary_{batchIndex}</c>. The same names are used in
        /// every run, so later runs presumably overwrite earlier snapshots (depends on <c>Save</c>).
        /// </remarks>
        private static void Train(NeuralNetwork nn, float learningRate, int runs)
        {
            Console.WriteLine("Preparing Training...");
            NNFeedData[] trainingInputData  = new NNFeedData[training.Count];
            NNFeedData[] trainingTargetData = new NNFeedData[training.Count];
            for (int i = 0; i < training.Count; i++)
            {
                MnistEntry entry = training[i];

                // Inputs are 28x28x1 images; targets are 10x1x1 vectors built by LabelToArray.
                trainingInputData[i]  = new NNFeedData(28, 28, 1, ConvertArray(entry.Image));
                trainingTargetData[i] = new NNFeedData(10, 1, 1, LabelToArray(entry.Label));
            }
            NNTrainingData trainingData = new NNPreloadedTrainingData(trainingInputData, trainingTargetData);

            // Error term (output - target) passed to backpropagation; presumably the cost derivative.
            NNBackpropagationData backpropagationData = new NNBackpropagationData(trainingData, learningRate, (o, t) => o - t);

            // Stopwatch is monotonic. DateTime.UtcNow can jump with clock adjustments.
            System.Diagnostics.Stopwatch batchTimer = new System.Diagnostics.Stopwatch();
            double totalTime  = 0; // ms spent in batches across all runs ("Passed")
            double runTime    = 0; // ms spent in batches of the current run (basis of "Avg")
            int    currentRun = 0;

            // Hours are printed from TotalHours, so spans of a day or more are not truncated.
            string FormatHms(TimeSpan span) => $"{(int)span.TotalHours:00}:{span:mm\\:ss}";

            // The callbacks only capture outer locals, so registering them once is enough.
            backpropagationData.BatchTrainingStartingCallback = (trainingDataIndex, trainingSets) => {
                batchTimer.Restart();
            };
            backpropagationData.BatchTrainingFinishedCallback = (trainingDataIndex, trainingSets) => {
                double batchTime = batchTimer.Elapsed.TotalMilliseconds;
                totalTime += batchTime;
                runTime   += batchTime;

                // trainingDataIndex is treated as 0-based, so trainingDataIndex + 1 batches of
                // this run are done. Averaging over runTime (not totalTime) keeps the average
                // correct after the first run.
                double avgTime = runTime / (trainingDataIndex + 1);

                // Batches left in this run, plus every batch of the runs still to come.
                // Evaluated left to right in double to avoid integer overflow.
                double remainingTime = avgTime * (trainingSets - trainingDataIndex - 1)
                                     + avgTime * (runs - currentRun - 1) * trainingSets;

                TimeSpan avgTSpan = TimeSpan.FromMilliseconds(avgTime);
                TimeSpan remTSpan = TimeSpan.FromMilliseconds(remainingTime);
                TimeSpan totTSpan = TimeSpan.FromMilliseconds(totalTime);
                // Whole seconds come from TotalSeconds, so averages of a minute or more keep their minutes.
                string   avgTS    = $"{(int)avgTSpan.TotalSeconds:00}.{avgTSpan:ffff}";
                string   remTS    = FormatHms(remTSpan);
                string   totTS    = FormatHms(totTSpan);

                Save(nn, $"primary_{trainingDataIndex}");

                Console.WriteLine($"Finished Training {trainingDataIndex}/{trainingSets} Passed:{totTS} Remaining:{remTS} Avg:{avgTS}");
            };

            Console.WriteLine("Starting Training...");
            for (currentRun = 0; currentRun < runs; currentRun++)
            {
                runTime = 0;
                nn.PropagateBackward(backpropagationData);
            }
        }