Пример #1
0
        /// <summary>
        /// Overfitting sanity check: trains a fresh <see cref="CNN"/> for several passes over the
        /// first 100 MNIST testing images, then prints each image's prediction followed by the
        /// image itself, waiting for Enter after every one.
        /// </summary>
        public static void CNN_OverfittingTest()
        {
            CNN cnn = new CNN();

            DigitImage[] digitImages = MNIST_Parser.ReadFromFile(DataSet.Testing, 100);

            Stopwatch stopwatch = new Stopwatch();
            stopwatch.Start();

            int iteration_count = 10;
            // Progress is measured in training steps (one step = one image in one pass), so the bar
            // advances within a pass and the percentage never divides by zero (the old formula
            // divided by iteration_count - 1, which is 0 for a single pass).
            int totalSteps = iteration_count * digitImages.Length;

            for (int i = 0; i < iteration_count; i++)
            {
                for (int j = 0; j < digitImages.Length; ++j)
                {
                    // NOTE(review): CNN_Training normalizes pixels to [0, 1] before calling Train;
                    // this test feeds the raw pixel values — confirm that is intended.
                    Matrix[] input = new Matrix[1];
                    input[0] = new Matrix(digitImages[j].pixels);

                    // Target has 10 entries with 1 at the digit's label index
                    // (presumably one-hot, assuming new Matrix(10, 1) is zero-filled — confirm).
                    Matrix target = new Matrix(10, 1);
                    target[(int)digitImages[j].label, 0] = 1f;

                    cnn.Train(input, target);
                    double error = cnn.GetError();

                    int step = i * digitImages.Length + j;
                    // totalSteps > 0 here because both loops are executing.
                    int val = (int)((step + 1) * 100L / totalSteps);
                    ProgressBar(val, step, totalSteps, error, stopwatch.ElapsedMilliseconds / 1000.0);
                }
            }

            // Show every prediction; Console.ReadLine pauses until the user presses Enter.
            for (int j = 0; j < digitImages.Length; ++j)
            {
                Matrix[] input = new Matrix[1];
                input[0] = new Matrix(digitImages[j].pixels);

                Matrix output = cnn.Predict(input);

                Console.WriteLine(output.ToString());
                Console.WriteLine(digitImages[j].ToString());
                Console.ReadLine();
            }
        }
        // Training images shown by Button_Next_Click, loaded once on first use instead of
        // re-reading 10000 images from disk on every click.
        private DigitImage[] _nextClickTrainingImages;

        /// <summary>
        /// Shows a random training digit that has not been shown yet.
        /// It draws an image index from <c>numPool</c> and loads that image into <c>input[0]</c>.
        /// It renders the image (inverted gray) into <c>bmp</c>, repaints <c>panel1</c>, and runs <c>Predict()</c>.
        /// Does nothing once the pool is exhausted.
        /// </summary>
        private void Button_Next_Click(object sender, EventArgs e)
        {
            if (numPool.Count == 0)
            {
                // Every image in the pool has already been shown; rnd.Next(0, 0) would return 0
                // and numPool[0] would throw.
                return;
            }

            if (_nextClickTrainingImages == null)
            {
                _nextClickTrainingImages = MNIST_Parser.ReadFromFile(DataSet.Training, 10000);
            }
            DigitImage[] digitImages = _nextClickTrainingImages;

            // Pick a random position in the pool, read the image index stored there, and remove
            // that *position*. Removing by the index value would delete the wrong entry (or throw
            // once the value exceeds the shrinking pool's Count).
            int poolPosition = rnd.Next(0, numPool.Count);
            int idx = numPool[poolPosition];
            numPool.RemoveAt(poolPosition);

            input[0] = new Matrix(digitImages[idx].pixels);

            // The 28x28 bitmap is only a source for the scaled copy, so dispose it afterwards
            // to release its GDI handle.
            using (Bitmap b = new Bitmap(28, 28))
            {
                for (int i = 0; i < b.Width; i++)
                {
                    for (int j = 0; j < b.Height; j++)
                    {
                        // Matrix is indexed [row, column] = [y, x]; invert so ink is dark.
                        byte gray = (byte)(255 - input[0][j, i]);
                        b.SetPixel(i, j, Color.FromArgb(gray, gray, gray));
                    }
                }

                bmp = new Bitmap(b, new Size(bmp.Width, bmp.Height));
            }

            panel1.Invalidate();

            Predict();
        }
Пример #3
0
        /// <summary>
        /// Loads the MNIST training (60000) and testing (10000) sets and builds a CNN via
        /// <c>CreateCNN</c>.
        /// When its dialog result is 0, trains for <c>totalEpoch</c> epochs. Each epoch runs over
        /// a freshly shuffled training set and is followed by a call to <c>CNN_Testing</c>.
        /// Otherwise it only runs <c>CNN_Testing</c> once with epoch <c>-1</c>.
        /// </summary>
        public static void CNN_Training()
        {
            int trCount = 60000, tsCount = 10000;

            int totalEpoch = 10;
            bool predictionIsOn = true;
            Random random = new Random();

            DigitImage[] trainingDigitImagesDatas = MNIST_Parser.ReadFromFile(DataSet.Training, trCount);
            DigitImage[] testingDigitImagesDatas = MNIST_Parser.ReadFromFile(DataSet.Testing, tsCount);
            int training_count = trainingDigitImagesDatas.Length;
            int dialogResult;

            CNN cnn = CreateCNN(out dialogResult);

            Matrix[] input, targets;
            InitializeInputAndTarget(out input, out targets);

            Stopwatch stopwatch = new Stopwatch();
            stopwatch.Start();

            if (dialogResult != 0)
            {
                // No training requested: evaluate the network as created.
                CNN_Testing(testingDigitImagesDatas, predictionIsOn, cnn, -1);
                return;
            }

            Console.WriteLine("System is getting trained...");
            // Record the progress bar's console row only the first time training runs.
            if (cursorTopTraining == -1)
            {
                cursorTopTraining = Console.CursorTop;
            }

            // Shuffle a private copy so the array returned by the parser keeps its order.
            DigitImage[] digitImages = (DigitImage[])trainingDigitImagesDatas.Clone();
            int totalSteps = training_count * totalEpoch;

            for (int epoch = 0; epoch < totalEpoch; epoch++)
            {
                double lossSum = 0.0;
                ShuffleInPlace(digitImages, random);

                for (int i = 0; i < training_count; i++)
                {
                    for (int j = 0; j < 28; j++)
                    {
                        for (int k = 0; k < 28; k++)
                        {
                            input[0][j, k] = digitImages[i].pixels[j][k];
                        }
                    }

                    // Scale raw pixel values from [0, 255] to [0, 1] before training.
                    input[0].Normalize(0f, 255f, 0f, 1f);
                    cnn.Train(input, targets[digitImages[i].label]);

                    lossSum += cnn.GetError();

                    int step = training_count * epoch + i;
                    int val = Map(0, totalSteps, 0, 100, step);
                    // Reported loss is the running mean over the current epoch.
                    ProgressBar(val, step, totalSteps, lossSum / (i + 1), stopwatch.ElapsedMilliseconds / 1000.0, cursorTopTraining);
                }
                CNN_Testing(testingDigitImagesDatas, predictionIsOn, cnn, epoch + 1);
            }

            Console.WriteLine("\nSystem has been trained.");
        }

        /// <summary>
        /// Fisher–Yates shuffle: permutes <paramref name="items"/> in place, uniformly and in O(n).
        /// It replaces the <c>OrderBy(random key)</c> trick, which is O(n log n), and in which
        /// duplicate random keys keep input order and bias the permutation.
        /// </summary>
        private static void ShuffleInPlace<T>(T[] items, Random random)
        {
            for (int i = items.Length - 1; i > 0; i--)
            {
                int j = random.Next(i + 1);
                T swap = items[i];
                items[i] = items[j];
                items[j] = swap;
            }
        }