/// <summary>
/// Builds the background thread (<c>_t</c>) for a CNN classification job. This constructor only
/// creates the thread; it does not start it. When run, the thread does the following:
/// it trains the CNN on random batches from the feature/label rasters, classifies every pixel
/// of <paramref name="featureRasterLayer"/>, and saves the result as a grayscale PNG under
/// "&lt;current directory&gt;/tmp". It then raises <c>OnTaskComplete</c> with the file path.
/// </summary>
/// <param name="featureRasterLayer">Raster whose pixels are classified and used as training input.</param>
/// <param name="labelRasterLayer">Raster supplying training labels to the environment.</param>
/// <param name="epochs">Number of training iterations; each iteration trains on one random batch.</param>
/// <param name="model">Not used by this constructor.</param>
/// <param name="width">Input width passed to the CNN shape.</param>
/// <param name="height">Input height passed to the CNN shape.</param>
/// <param name="channel">Input channel count passed to the CNN shape.</param>
public JobCNNClassify(GRasterLayer featureRasterLayer, GRasterLayer labelRasterLayer, int epochs, int model, int width, int height, int channel)
{
    _t = new Thread(() =>
    {
        ImageClassifyEnv env = new ImageClassifyEnv(featureRasterLayer, labelRasterLayer);
        CNN cnn = new CNN(new int[] { channel, width, height }, env.ActionNum);

        // training phase
        Summary = "模型训练中";
        for (int i = 0; i < epochs; i++)
        {
            int batchSize = cnn.BatchSize;
            var (states, labels) = env.RandomEval(batchSize);
            double[][] inputX = new double[batchSize][];
            for (int j = 0; j < batchSize; j++)
            {
                inputX[j] = states[j];
            }
            cnn.Train(inputX, labels);
            // i + 1 so the reported progress reaches 1.0 after the last epoch
            Process = (double)(i + 1) / epochs;
        }

        // classification phase
        Summary = "分类应用中";
        IRasterLayerCursorTool pRasterLayerCursorTool = new GRasterLayerCursorTool();
        pRasterLayerCursorTool.Visit(featureRasterLayer);

        int xSize = featureRasterLayer.XSize;
        int ySize = featureRasterLayer.YSize;
        int totalPixels = xSize * ySize;
        int processed = 0;
        string fullFileName;

        // Bitmap wraps a GDI+ handle; dispose it once the file is written.
        using (Bitmap classificationBitmap = new Bitmap(xSize, ySize))
        {
            // apply the trained CNN to every pixel of the image
            for (int i = 0; i < xSize; i++)
            {
                for (int j = 0; j < ySize; j++)
                {
                    // normalized input values for pixel (i, j)
                    double[] normal = pRasterLayerCursorTool.PickNormalValue(i, j);
                    double[] action = cnn.Predict(normal);
                    // map the winning action index back to its raw gray value
                    int gray = env.RandomSeedKeys[NP.Argmax(action)];
                    // SetPixel writes one pixel directly; no per-pixel Pen/Brush allocations to leak
                    classificationBitmap.SetPixel(i, j, Color.FromArgb(gray, gray, gray));
                    Process = (double)(++processed) / totalPixels;
                }
            }

            // save the result under ./tmp, creating the folder if it is missing
            string tmpDirectory = Path.Combine(Directory.GetCurrentDirectory(), "tmp");
            Directory.CreateDirectory(tmpDirectory);
            fullFileName = Path.Combine(tmpDirectory, DateTime.UtcNow.ToFileTimeUtc() + ".png");
            classificationBitmap.Save(fullFileName);
        }

        // done
        Summary = "CNN训练分类完成";
        Complete = true;
        OnTaskComplete?.Invoke(Name, fullFileName);
    });
}
/// <summary>
/// Overfitting sanity check. Trains a fresh CNN for a fixed number of passes over the first
/// 100 MNIST testing images. It then prints each prediction next to its image, waiting for
/// a line of console input after each one.
/// </summary>
public static void CNN_OverfittingTest()
{
    const int passes = 10;

    CNN network = new CNN();
    DigitImage[] samples = MNIST_Parser.ReadFromFile(DataSet.Testing, 100);
    Stopwatch timer = new Stopwatch();
    timer.Start();

    for (int pass = 0; pass < passes; pass++)
    {
        // percentage depends only on the pass index, so compute it once per pass
        int percent = (int)(pass / (double)(passes - 1) * 100);

        foreach (DigitImage sample in samples)
        {
            Matrix[] input = { new Matrix(sample.pixels) };
            // one-hot target: 1 at the row of the digit's label
            Matrix target = new Matrix(10, 1);
            target[(int)sample.label, 0] = 1f;

            network.Train(input, target);
            ProgressBar(percent, pass, passes, network.GetError(), timer.ElapsedMilliseconds / 1000.0);
        }
    }

    foreach (DigitImage sample in samples)
    {
        Matrix prediction = network.Predict(new Matrix[] { new Matrix(sample.pixels) });
        Console.WriteLine(prediction.ToString());
        Console.WriteLine(sample.ToString());
        Console.ReadLine();
    }
}
/// <summary>
/// Smoke test for CNN-based classification. Trains on the feature/label rasters for a fixed
/// number of epochs and asserts the final loss is below 5. It then classifies a single pixel
/// (50, 50) and asserts the resulting land-cover value is non-negative.
/// </summary>
public void ClassificationByCNN()
{
    const int epochs = 100;

    var featureLayer = new GRasterLayer(featureFullFilename);
    var labelLayer = new GRasterLayer(trainFullFilename);
    // environment the agent draws training batches from
    IEnv env = new ImageClassifyEnv(featureLayer, labelLayer);
    // assumes the 18-dimensional input can be laid out as a 1-channel 3x6 image
    var cnn = new CNN(new[] { 1, 3, 6 }, env.ActionNum);

    double loss = 1.0;
    for (int epoch = 0; epoch < epochs; epoch++)
    {
        int batchSize = cnn.BatchSize;
        var (states, labels) = env.RandomEval(batchSize);
        var batch = new double[batchSize][];
        for (int k = 0; k < batchSize; k++)
        {
            batch[k] = states[k];
        }
        loss = cnn.Train(batch, labels);
    }
    // the loss usually ends up below 5
    Assert.IsTrue(loss < 5.0);

    // classify one pixel of the feature layer with the trained network
    IRasterLayerCursorTool cursor = new GRasterLayerCursorTool();
    cursor.Visit(featureLayer);
    double[] action = cnn.Predict(cursor.PickNormalValue(50, 50));
    int landCoverType = env.RandomSeedKeys[NP.Argmax(action)];

    // With so few epochs the predicted class is not stable, so only a range check is made.
    // In real use the value would be drawn to a bitmap at (i, j).
    Assert.IsTrue(landCoverType >= 0);
}
/// <summary>
/// Loads the MNIST training and testing sets and obtains a CNN from CreateCNN. The next step
/// depends on the dialog result:
/// - If it is 0, the network is trained for a fixed number of epochs. Each epoch uses a freshly
///   shuffled copy of the training set and is followed by an evaluation on the test set.
/// - Otherwise, training is skipped and the network is only evaluated once, with iteration
///   count -1.
/// </summary>
public static void CNN_Training()
{
    const int imageSize = 28;
    int trCount = 60000, tsCount = 10000;
    int totalEpoch = 10;
    bool predictionIsOn = true;
    Random random = new Random();

    DigitImage[] trainigDigitImagesDatas = MNIST_Parser.ReadFromFile(DataSet.Training, trCount);
    DigitImage[] testingDigitImagesDatas = MNIST_Parser.ReadFromFile(DataSet.Testing, tsCount);
    int training_count = trainigDigitImagesDatas.Length;

    int dialogResult;
    CNN cnn = CreateCNN(out dialogResult);
    Matrix[] input, targets;
    InitializeInputAndTarget(out input, out targets);
    Stopwatch stopwatch = new Stopwatch();
    stopwatch.Start();

    if (dialogResult != 0)
    {
        // non-zero dialog result: no training, just evaluate the network CreateCNN returned
        CNN_Testing(testingDigitImagesDatas, predictionIsOn, cnn, -1);
        return;
    }

    Console.WriteLine("System is getting trained...");
    // -1 means "not assigned yet"; capture the cursor row only once
    if (cursorTopTraining == -1)
    {
        cursorTopTraining = Console.CursorTop;
    }

    int totalSteps = training_count * totalEpoch;
    for (int epoch = 0; epoch < totalEpoch; epoch++)
    {
        // Fisher-Yates shuffle of a copy. OrderBy(random key) produced colliding keys,
        // and tied images kept their original relative order, biasing the shuffle.
        DigitImage[] digitImages = (DigitImage[])trainigDigitImagesDatas.Clone();
        for (int n = digitImages.Length - 1; n > 0; n--)
        {
            int swapIndex = random.Next(n + 1);
            DigitImage tmp = digitImages[n];
            digitImages[n] = digitImages[swapIndex];
            digitImages[swapIndex] = tmp;
        }

        double lossSum = 0.0;
        for (int i = 0; i < training_count; i++)
        {
            DigitImage image = digitImages[i];
            for (int j = 0; j < imageSize; j++)
            {
                for (int k = 0; k < imageSize; k++)
                {
                    input[0][j, k] = image.pixels[j][k];
                }
            }
            // scale raw pixel values from [0, 255] to [0, 1]
            input[0].Normalize(0f, 255f, 0f, 1f);
            cnn.Train(input, targets[image.label]);
            lossSum += cnn.GetError();

            int step = training_count * epoch + i;
            int val = Map(0, totalSteps, 0, 100, step);
            // shows the running mean loss of the current epoch
            ProgressBar(val, step, totalSteps, lossSum / (i + 1), stopwatch.ElapsedMilliseconds / 1000.0, cursorTopTraining);
        }
        CNN_Testing(testingDigitImagesDatas, predictionIsOn, cnn, epoch + 1);
    }
    Console.WriteLine("\nSystem has been trained.");
}
/// <summary>
/// Evaluates <paramref name="cnn"/> on <paramref name="digitImagesDatas"/>, showing a progress bar,
/// then prints the iteration count, elapsed time, accuracy and the correct/total count.
/// If accuracy is at least 95% and <paramref name="iterationCount"/> is not -1, the network is
/// saved as "{accuracy:F2}__{iterationCount}__{6 random chars A-F}.json".
/// </summary>
/// <param name="digitImagesDatas">Images to evaluate; each is read as a 28x28 pixel grid.</param>
/// <param name="predictionIsOn">
/// true: classify with Predict. false: run a Train step and read the last layer's output.
/// </param>
/// <param name="cnn">Network under evaluation.</param>
/// <param name="iterationCount">Epoch number to report; -1 disables saving.</param>
public static void CNN_Testing(DigitImage[] digitImagesDatas, bool predictionIsOn, CNN cnn, int iterationCount)
{
    const int imageSize = 28;
    const long errorRefreshMs = 500;

    Stopwatch stopwatch = new Stopwatch();
    stopwatch.Start();
    int testing_count = digitImagesDatas.Length;
    int correct_count = 0;
    double error = 0;
    Matrix[] input, targets;

    Console.WriteLine("System is getting tested. You will see the results when it is done...\n");
    // NOTE(review): 0 is treated as "not assigned yet" here, while CNN_Training uses -1 for
    // cursorTopTraining. Confirm the initial value of cursorTopTesting.
    if (cursorTopTesting == 0)
    {
        cursorTopTesting = Console.CursorTop;
    }
    InitializeInputAndTarget(out input, out targets);

    // The displayed error is refreshed at most every errorRefreshMs. The next deadline is measured
    // from "now", so after a stall the refresh does not fire on every iteration while it catches up.
    long nextErrorRefresh = stopwatch.ElapsedMilliseconds;
    for (int i = 0; i < testing_count; i++)
    {
        DigitImage image = digitImagesDatas[i];
        for (int j = 0; j < imageSize; j++)
        {
            for (int k = 0; k < imageSize; k++)
            {
                input[0][j, k] = image.pixels[j][k];
            }
        }
        // scale raw pixel values from [0, 255] to [0, 1]
        input[0].Normalize(0f, 255f, 0f, 1f);

        Matrix ans;
        if (predictionIsOn)
        {
            ans = cnn.Predict(input);
        }
        else
        {
            cnn.Train(input, targets[image.label]);
            ans = cnn.Layers[cnn.Layers.Length - 1].Output[0];
        }
        if (ans.GetMaxRowIndex() == image.label)
        {
            correct_count++;
        }

        long elapsedMs = stopwatch.ElapsedMilliseconds;
        if (elapsedMs > nextErrorRefresh)
        {
            error = cnn.GetError();
            nextErrorRefresh = elapsedMs + errorRefreshMs;
        }
        int val = Map(0, testing_count, 0, 100, i);
        ProgressBar(val, i, testing_count, error, elapsedMs / 1000.0, cursorTopTesting);
    }

    // Compute in double. The previous float expression turned exactly 95% (e.g. 9500/10000)
    // into 94.99999..., which failed the >= 95 save check. An empty set yields 0, not NaN.
    double accuracy = testing_count == 0 ? 0.0 : correct_count * 100.0 / testing_count;

    Console.WriteLine("\nIteration Count:" + iterationCount);
    Console.WriteLine("\nTime :" + (stopwatch.ElapsedMilliseconds / 1000.0).ToString("F4"));
    Console.WriteLine("\nAccuracy: %{0:F2}\n", accuracy);
    Console.WriteLine("Correct/All: {0}/{1}", correct_count, testing_count);
    cursorTopTesting = Console.CursorTop;

    if (accuracy >= 95 && iterationCount != -1)
    {
        string name = accuracy.ToString("F2") + "__" + iterationCount + "__";
        // six random letters in 'A'..'F' make the file name unique
        Random random = new Random();
        char[] suffix = new char[6];
        for (int c = 0; c < suffix.Length; c++)
        {
            suffix[c] = (char)random.Next('A', 'F' + 1);
        }
        cnn.SaveData(name + new string(suffix) + ".json");
    }
}