/// <summary>
/// Overfitting check: trains a freshly constructed <c>CNN</c> over and over on one small
/// batch of digit images, then prints the network output next to every image of that batch.
/// </summary>
/// <remarks>
/// The batch is requested from <c>DataSet.Testing</c> with a count argument of 100 and is
/// cycled through 10 times. After each printed prediction the method blocks on
/// <c>Console.ReadLine()</c> until a line is entered.
/// </remarks>
public static void CNN_OverfittingTest()
{
    int passCount = 10;

    CNN network = new CNN();
    DigitImage[] samples = MNIST_Parser.ReadFromFile(DataSet.Testing, 100);

    Stopwatch timer = new Stopwatch();
    timer.Start();

    // Training phase: every pass feeds each sample to the network once.
    for (int pass = 0; pass < passCount; pass++)
    {
        // The percentage depends only on the pass number: 0 on the first pass, 100 on the last.
        int percent = (int)(pass / (double)(passCount - 1) * 100);

        foreach (DigitImage sample in samples)
        {
            Matrix[] input = { new Matrix(sample.pixels) };

            // 10x1 target column with the entry at the label's row set to 1.
            Matrix target = new Matrix(10, 1);
            target[(int)sample.label, 0] = 1f;

            network.Train(input, target);
            double error = network.GetError();

            ProgressBar(percent, pass, passCount, error, timer.ElapsedMilliseconds / 1000.0);
        }
    }

    // Inspection phase: show the prediction and the image, then wait for the user.
    foreach (DigitImage sample in samples)
    {
        Matrix[] input = { new Matrix(sample.pixels) };
        Matrix prediction = network.Predict(input);

        Console.WriteLine(prediction.ToString());
        Console.WriteLine(sample.ToString());
        Console.ReadLine();
    }
}
/// <summary>
/// Click handler that picks a random entry from <c>numPool</c>, loads the corresponding
/// training image into <c>input[0]</c>, renders it into <c>bmp</c>, repaints
/// <c>panel1</c> and runs <c>Predict()</c>.
/// </summary>
/// <param name="sender">Event source (unused).</param>
/// <param name="e">Event data (unused).</param>
/// <remarks>
/// NOTE(review): <c>numPool</c> is presumably the list of image indices that have not been
/// shown yet - confirm against the code that fills it. When the pool is empty the handler
/// does nothing.
/// </remarks>
private void Button_Next_Click(object sender, EventArgs e)
{
    // Nothing left to draw; indexing an empty pool would throw.
    if (numPool.Count == 0)
    {
        return;
    }

    // NOTE(review): the image file is re-read on every click; caching the array in a field
    // would avoid that, but requires a change outside this method.
    DigitImage[] digitImages = MNIST_Parser.ReadFromFile(DataSet.Training, 10000);

    // Pick a random *position* in the pool, read the image index stored there, and remove
    // that same position. Removing at the stored value instead (as before) drops the wrong
    // entry once the pool has shrunk and eventually runs past the end of the list.
    int poolPosition = rnd.Next(0, numPool.Count);
    int idx = numPool[poolPosition];
    numPool.RemoveAt(poolPosition);

    input[0] = new Matrix(digitImages[idx].pixels);

    // The 28x28 bitmap is only a scratch surface; the scaled copy stored in bmp does not
    // reference it, so it can be disposed as soon as the copy exists.
    using (Bitmap b = new Bitmap(28, 28))
    {
        for (int i = 0; i < b.Width; i++)
        {
            for (int j = 0; j < b.Height; j++)
            {
                // Invert the stored value so larger pixel values are drawn darker.
                // The matrix is addressed [row, column], hence [j, i] for bitmap (x = i, y = j).
                byte shade = (byte)(255 - input[0][j, i]);
                b.SetPixel(i, j, Color.FromArgb(shade, shade, shade));
            }
        }

        // Scale the small image up to the size of the bitmap currently being displayed.
        bmp = new Bitmap(b, new Size(bmp.Width, bmp.Height));
    }

    panel1.Invalidate();
    Predict();
}
/// <summary>
/// Loads the training and testing image sets, obtains a network from <c>CreateCNN</c>, and
/// then either trains it for a fixed number of epochs (running <c>CNN_Testing</c> after each
/// epoch) or, when <c>CreateCNN</c> reports a non-zero result, only runs <c>CNN_Testing</c> once.
/// </summary>
/// <remarks>
/// During training the progress bar shows the running mean of <c>cnn.GetError()</c> over the
/// samples processed so far in the current epoch, together with the elapsed seconds.
/// </remarks>
public static void CNN_Training()
{
    int trCount = 60000, tsCount = 10000;
    int totalEpoch = 10;
    bool predictionIsOn = true;
    Random random = new Random();

    DigitImage[] trainingSet = MNIST_Parser.ReadFromFile(DataSet.Training, trCount);
    DigitImage[] testingSet = MNIST_Parser.ReadFromFile(DataSet.Testing, tsCount);
    int sampleCount = trainingSet.Length;

    int dialogResult;
    CNN cnn = CreateCNN(out dialogResult);

    Matrix[] input, targets;
    InitializeInputAndTarget(out input, out targets);

    Stopwatch stopwatch = new Stopwatch();
    stopwatch.Start();

    // Any non-zero result skips training and goes straight to a single test run.
    if (dialogResult != 0)
    {
        CNN_Testing(testingSet, predictionIsOn, cnn, -1);
        return;
    }

    Console.WriteLine("System is getting trained...");

    // Record the console row only while the shared field still holds its -1 marker.
    if (cursorTopTraining == -1)
    {
        cursorTopTraining = Console.CursorTop;
    }

    int totalSteps = sampleCount * totalEpoch;

    for (int epoch = 0; epoch < totalEpoch; epoch++)
    {
        double lossSum = 0.0;

        // Re-order the training images by a random sort key for this epoch.
        DigitImage[] shuffled = trainingSet.OrderBy(image => random.Next(sampleCount)).ToArray();

        for (int i = 0; i < sampleCount; i++)
        {
            DigitImage current = shuffled[i];

            // Copy the 28x28 pixel grid into the reusable input matrix.
            for (int row = 0; row < 28; row++)
            {
                for (int col = 0; col < 28; col++)
                {
                    input[0][row, col] = current.pixels[row][col];
                }
            }

            input[0].Normalize(0f, 255f, 0f, 1f);
            cnn.Train(input, targets[current.label]);

            // The error is sampled after every training step and accumulated.
            lossSum += cnn.GetError();

            int step = sampleCount * epoch + i;
            int val = Map(0, totalSteps, 0, 100, step);
            ProgressBar(val, step, totalSteps, lossSum / (i + 1), stopwatch.ElapsedMilliseconds / 1000.0, cursorTopTraining);
        }

        CNN_Testing(testingSet, predictionIsOn, cnn, epoch + 1);
    }

    Console.WriteLine("\nSystem has been trained.");
}