Ejemplo n.º 1
0
        /// <summary>
        /// Runs the network returned by <c>CreateCNN</c> over the Kaggle testing set (read via
        /// <c>CSV_Helper.Read(DataSet.Testing)</c>) and hands a submission CSV to <c>CSV_Helper.Write</c>.
        /// The CSV has the header "ImageId,Label" followed by one "index,predictedRow" line per image,
        /// where index is 1-based and predictedRow is <c>ans.GetMaxRowIndex()</c> of the prediction.
        /// </summary>
        /// <remarks>
        /// Progress is drawn with <c>ProgressBar</c> at <c>cursorTopTesting</c>. The displayed error is
        /// refreshed from <c>cnn.GetError()</c> at most every 500 ms; since this loop only predicts,
        /// what that value means here depends on CNN internals (NOTE(review): confirm it is meaningful).
        /// </remarks>
        public static void Kaggle_Test()
        {
            DigitImage[] testing = CSV_Helper.Read(DataSet.Testing);

            // The out value is required by CreateCNN's signature but is not used here.
            int dialogResult;
            CNN cnn = CreateCNN(out dialogResult);

            DigitImage[] digitImagesDatas = testing;

            Stopwatch stopwatch = new Stopwatch();

            stopwatch.Start();

            int    testing_count = digitImagesDatas.Length;
            double timeLimit = 0, error = 0;

            Matrix[] input, targets;

            Console.WriteLine("System is getting tested. You will see the results when it is done...\n");
            if (cursorTopTesting == 0)
            {
                cursorTopTesting = Console.CursorTop;
            }

            InitializeInputAndTarget(out input, out targets);
            timeLimit = stopwatch.ElapsedMilliseconds;

            // A StringBuilder replaces the previous `string +=` in the loop, which re-copied the
            // whole CSV for every row (quadratic in the number of test images).
            System.Text.StringBuilder submissionCSV = new System.Text.StringBuilder();
            submissionCSV.Append("ImageId,Label").Append(Environment.NewLine);

            for (int i = 0; i < testing_count; i++)
            {
                // Copy the 28x28 pixel grid into the reusable input matrix.
                for (int j = 0; j < 28; j++)
                {
                    for (int k = 0; k < 28; k++)
                    {
                        input[0][j, k] = digitImagesDatas[i].pixels[j][k];
                    }
                }

                input[0].Normalize(0f, 255f, 0f, 1f);

                Matrix ans = cnn.Predict(input);

                submissionCSV.Append(i + 1)
                             .Append(',')
                             .Append(ans.GetMaxRowIndex())
                             .Append(Environment.NewLine);

                if (stopwatch.ElapsedMilliseconds > timeLimit)
                {
                    // every 0.5 sec update error
                    timeLimit += 500;
                    error      = cnn.GetError();
                }

                int val = Map(0, testing_count, 0, 100, i);
                ProgressBar(val, i, testing_count, error, stopwatch.ElapsedMilliseconds / 1000.0, cursorTopTesting);
            }

            CSV_Helper.Write(submissionCSV.ToString());
        }
Ejemplo n.º 2
0
 /// <summary>
 /// Creates (but does not start) the worker thread <c>_t</c>. When run, the thread does the following:
 /// it trains a CNN for <paramref name="epochs"/> batches sampled from an <c>ImageClassifyEnv</c> built
 /// over the two layers; it classifies every pixel of <paramref name="featureRasterLayer"/>; it writes the
 /// result as a grayscale PNG into a "tmp" folder under the current directory; it then raises
 /// <c>OnTaskComplete(Name, path)</c>. <c>Summary</c>, <c>Process</c> and <c>Complete</c> are updated as it goes.
 /// </summary>
 /// <param name="featureRasterLayer">Raster whose pixels are sampled for training and then classified.</param>
 /// <param name="labelRasterLayer">Raster supplying training labels to the environment.</param>
 /// <param name="epochs">Number of training batches.</param>
 /// <param name="model">Not used by the visible code. NOTE(review): confirm whether it is intentionally ignored.</param>
 /// <param name="width">Passed to the CNN input shape as <c>{ channel, width, height }</c>.</param>
 /// <param name="height">Passed to the CNN input shape as <c>{ channel, width, height }</c>.</param>
 /// <param name="channel">Passed to the CNN input shape as <c>{ channel, width, height }</c>.</param>
 public JobCNNClassify(GRasterLayer featureRasterLayer, GRasterLayer labelRasterLayer, int epochs, int model, int width, int height, int channel)
 {
     _t = new Thread(() => {
         ImageClassifyEnv env = new ImageClassifyEnv(featureRasterLayer, labelRasterLayer);
         CNN cnn = new CNN(new int[] { channel, width, height }, env.ActionNum);
         // training
         Summary = "模型训练中";
         for (int i = 0; i < epochs; i++)
         {
             int batchSize       = cnn.BatchSize;
             var(states, labels) = env.RandomEval(batchSize);
             double[][] inputX   = new double[batchSize][];
             for (int j = 0; j < batchSize; j++)
             {
                 inputX[j] = states[j];
             }
             // The returned loss is not needed here.
             cnn.Train(inputX, labels);
             Process = (double)i / epochs;
         }
         // classify
         Summary = "分类应用中";
         IRasterLayerCursorTool pRasterLayerCursorTool = new GRasterLayerCursorTool();
         pRasterLayerCursorTool.Visit(featureRasterLayer);
         string fullFileName;
         // Bitmap, Graphics and brushes wrap GDI+ handles; dispose them deterministically.
         // The original leaked a Pen and a SolidBrush for every pixel.
         using (Bitmap classificationBitmap = new Bitmap(featureRasterLayer.XSize, featureRasterLayer.YSize))
         {
             using (Graphics g = Graphics.FromImage(classificationBitmap))
             {
                 int seed        = 0;
                 int totalPixels = featureRasterLayer.XSize * featureRasterLayer.YSize;
                 // apply the trained network to every pixel of the image
                 for (int i = 0; i < featureRasterLayer.XSize; i++)
                 {
                     for (int j = 0; j < featureRasterLayer.YSize; j++)
                     {
                         // get normalized input raw value
                         double[] normal = pRasterLayerCursorTool.PickNormalValue(i, j);
                         double[] action = cnn.Predict(normal);
                         // convert action index to the raw byte value used as gray level
                         int gray = env.RandomSeedKeys[NP.Argmax(action)];
                         // draw in the background and report progress
                         Color c = Color.FromArgb(gray, gray, gray);
                         using (SolidBrush brush = new SolidBrush(c))
                         {
                             g.FillRectangle(brush, new Rectangle(i, j, 1, 1));
                         }
                         // report progress
                         Process = (double)(seed++) / totalPixels;
                     }
                 }
             }
             // Save the result into <current directory>/tmp. Path.Combine keeps this portable,
             // and CreateDirectory ensures Save does not fail when the folder is missing.
             string tmpDirectory = Path.Combine(Directory.GetCurrentDirectory(), "tmp");
             Directory.CreateDirectory(tmpDirectory);
             fullFileName = Path.Combine(tmpDirectory, DateTime.Now.ToFileTimeUtc() + ".png");
             classificationBitmap.Save(fullFileName);
         }
         // complete
         Summary  = "CNN训练分类完成";
         Complete = true;
         OnTaskComplete?.Invoke(Name, fullFileName);
     });
 }
Ejemplo n.º 3
0
        /// <summary>
        /// Overfitting sanity check. A fresh <c>CNN</c> is trained for 10 passes over a small sample
        /// read with <c>MNIST_Parser.ReadFromFile(DataSet.Testing, 100)</c>. Then, for each sample, the
        /// prediction and the image are printed, waiting for a line of console input between samples.
        /// </summary>
        /// <remarks>
        /// Pixels are passed to <c>Matrix</c> unnormalized here. NOTE(review): other MNIST paths
        /// normalize with <c>Normalize(0f, 255f, 0f, 1f)</c>; confirm this difference is intended.
        /// </remarks>
        public static void CNN_OverfittingTest()
        {
            const int Passes = 10;

            CNN network = new CNN();

            DigitImage[] samples = MNIST_Parser.ReadFromFile(DataSet.Testing, 100);

            Stopwatch timer = new Stopwatch();
            timer.Start();

            for (int pass = 0; pass < Passes; pass++)
            {
                // Progress depends only on the pass index, so compute it once per pass.
                int percent = (int)(pass / (double)(Passes - 1) * 100);

                foreach (DigitImage sample in samples)
                {
                    Matrix[] features = { new Matrix(sample.pixels) };

                    // Target column vector with a 1 at the label's row.
                    Matrix expected = new Matrix(10, 1);
                    expected[(int)sample.label, 0] = 1f;

                    network.Train(features, expected);

                    ProgressBar(percent, pass, Passes, network.GetError(), timer.ElapsedMilliseconds / 1000.0);
                }
            }

            foreach (DigitImage sample in samples)
            {
                Matrix prediction = network.Predict(new Matrix[] { new Matrix(sample.pixels) });

                Console.WriteLine(prediction.ToString());
                Console.WriteLine(sample.ToString());
                Console.ReadLine();
            }
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Smoke test for CNN-based raster classification. It trains a CNN on 100 batches drawn from an
        /// <c>ImageClassifyEnv</c> built over the feature and label rasters, and asserts the final batch
        /// loss is below 5. It then classifies the feature pixel at (50, 50) and asserts the mapped class
        /// value is non-negative.
        /// </summary>
        public void ClassificationByCNN()
        {
            const int TrainingEpochs = 100;

            GRasterLayer featureLayer = new GRasterLayer(featureFullFilename);
            GRasterLayer labelLayer   = new GRasterLayer(trainFullFilename);

            // Environment the network samples training batches from.
            IEnv env = new ImageClassifyEnv(featureLayer, labelLayer);

            // The author assumes the 18-dimensional sample can be viewed as a 1 x 3 x 6 image.
            CNN cnn = new CNN(new[] { 1, 3, 6 }, env.ActionNum);

            double lastLoss = 1.0;
            for (int epoch = 0; epoch < TrainingEpochs; epoch++)
            {
                int size = cnn.BatchSize;
                var (states, labels) = env.RandomEval(size);

                double[][] batch = new double[size][];
                for (int row = 0; row < size; row++)
                {
                    batch[row] = states[row];
                }

                lastLoss = cnn.Train(batch, labels);
            }

            // Loose bound: the author notes the loss is generally below 5.
            Assert.IsTrue(lastLoss < 5.0);

            // Classify a single pixel of the feature layer with the trained network.
            IRasterLayerCursorTool cursorTool = new GRasterLayerCursorTool();
            cursorTool.Visit(featureLayer);

            double[] pixelFeatures = cursorTool.PickNormalValue(50, 50);
            double[] scores        = cnn.Predict(pixelFeatures);
            int landCoverType      = env.RandomSeedKeys[NP.Argmax(scores)];

            // With so few epochs the predicted class is not stable, so only a weak check is made.
            // A caller could draw landCoverType into a bitmap at the pixel's position.
            Assert.IsTrue(landCoverType >= 0);
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Runs <paramref name="cnn"/> over <paramref name="digitImagesDatas"/> and prints iteration
        /// count, elapsed time, accuracy and the correct/total count. When accuracy is at least 95% and
        /// <paramref name="iterationCount"/> is not -1, the network is saved via <c>cnn.SaveData</c>. The
        /// file name is "{accuracy:F2}__{iterationCount}__{6 random letters A-F}.json".
        /// </summary>
        /// <param name="digitImagesDatas">Images to run. Each 28x28 pixel grid is copied into the input
        /// matrix and passed through <c>Normalize(0f, 255f, 0f, 1f)</c>.</param>
        /// <param name="predictionIsOn">true: only <c>cnn.Predict</c> is called. false: <c>cnn.Train</c> is
        /// called with <c>targets[label]</c>, and the last layer's first output is used as the answer.</param>
        /// <param name="cnn">Network to evaluate (and train, when <paramref name="predictionIsOn"/> is false).</param>
        /// <param name="iterationCount">Reported in the output and embedded in the save name; -1 disables saving.</param>
        public static void CNN_Testing(DigitImage[] digitImagesDatas, bool predictionIsOn, CNN cnn, int iterationCount)
        {
            Stopwatch stopwatch = new Stopwatch();

            stopwatch.Start();

            int    testing_count = digitImagesDatas.Length;
            int    correct_count = 0;
            double timeLimit = 0, error = 0;

            Matrix[] input, targets;

            Console.WriteLine("System is getting tested. You will see the results when it is done...\n");
            if (cursorTopTesting == 0)
            {
                cursorTopTesting = Console.CursorTop;
            }

            InitializeInputAndTarget(out input, out targets);
            timeLimit = stopwatch.ElapsedMilliseconds;

            for (int i = 0; i < testing_count; i++)
            {
                // Copy the 28x28 pixel grid into the reusable input matrix.
                for (int j = 0; j < 28; j++)
                {
                    for (int k = 0; k < 28; k++)
                    {
                        input[0][j, k] = digitImagesDatas[i].pixels[j][k];
                    }
                }

                input[0].Normalize(0f, 255f, 0f, 1f);

                Matrix ans;
                if (predictionIsOn)
                {
                    ans = cnn.Predict(input);
                }
                else
                {
                    cnn.Train(input, targets[digitImagesDatas[i].label]);
                    ans = cnn.Layers[cnn.Layers.Length - 1].Output[0];
                }

                if (ans.GetMaxRowIndex() == digitImagesDatas[i].label)
                {
                    correct_count++;
                }

                if (stopwatch.ElapsedMilliseconds > timeLimit)
                {
                    // every 0.5 sec update error
                    timeLimit += 500;
                    error      = cnn.GetError();
                }

                int val = Map(0, testing_count, 0, 100, i);
                ProgressBar(val, i, testing_count, error, stopwatch.ElapsedMilliseconds / 1000.0, cursorTopTesting);
            }

            // Compute in double. The previous float form (correct_count * 1f / testing_count) turned an
            // exact 95% (e.g. 9500/10000) into 94.9999988..., so the ">= 95" save check below failed.
            // An empty data set yields 0 instead of NaN.
            double accuracy = testing_count == 0 ? 0.0 : correct_count * 100.0 / testing_count;

            Console.WriteLine("\nIteration Count:" + iterationCount);
            Console.WriteLine("\nTime :" + (stopwatch.ElapsedMilliseconds / 1000.0).ToString("F4"));
            Console.WriteLine("\nAccuracy: %{0:F2}\n", accuracy);
            Console.WriteLine("Correct/All: {0}/{1}", correct_count, testing_count);

            cursorTopTesting = Console.CursorTop;

            if (accuracy >= 95 && iterationCount != -1)
            {
                // Random 6-letter suffix (A-F) to avoid overwriting earlier saves with the same accuracy.
                string name   = accuracy.ToString("F2") + "__" + iterationCount + "__";
                Random random = new Random();
                int    length = 6;
                for (int i = 0; i < length; i++)
                {
                    char c = (char)random.Next('A', 'F' + 1);
                    name += c;
                }
                cnn.SaveData(name + ".json");
            }
        }