コード例 #1
0
        /// <summary>
        /// Trains a DQN agent on the feature/label rasters for a very short run and then
        /// classifies a single pixel, asserting that the reported loss dropped below 1.0
        /// and that the predicted land-cover value is non-negative.
        /// </summary>
        public void ClassificationByDQN()
        {
            // Most recent loss reported by the learner; remains 1.0 if the event never fires.
            double lastLoss = 1.0;

            var features = new GRasterLayer(featureFullFilename);
            var labels   = new GRasterLayer(trainFullFilename);

            // Environment the agent explores, built from the two rasters.
            IEnv env = new ImageClassifyEnv(features, labels);
            DQN dqn = new DQN(env);

            // Deliberately tiny training budget (10) so the test finishes quickly.
            // Do not use so few training steps in real use.
            dqn.SetParameters(10, 0);

            // Capture the loss each time the learner reports progress.
            dqn.OnLearningLossEventHandler += (double loss, double totalReward, double accuracy, double progress, string epochesTime) =>
            {
                lastLoss = loss;
            };

            dqn.Learn();

            // In general the loss ends up below 1.
            Assert.IsTrue(lastLoss < 1.0);

            // Apply the trained agent to one pixel of the feature layer.
            IRasterLayerCursorTool cursor = new GRasterLayerCursorTool();
            cursor.Visit(features);

            double[] state      = cursor.PickNormalValue(50, 50);
            double[] action     = dqn.ChooseAction(state).action;
            int bestActionIndex = NP.Argmax(action);
            int landCoverType   = dqn.ActionToRawValue(bestActionIndex);

            // A caller would typically draw landCoverType into a bitmap at the pixel position.
            // The predicted class is not stable here because training is far too short.
            Assert.IsTrue(landCoverType >= 0);
        }
コード例 #2
0
 /// <summary>
 /// Creates a job that trains a CNN on samples drawn from the given rasters and then
 /// classifies every pixel of the feature raster into a grayscale PNG.
 /// The work is wrapped in a thread stored in <c>_t</c>; this constructor does not start it.
 /// </summary>
 /// <param name="featureRasterLayer">Raster used to build the environment and whose pixels are classified.</param>
 /// <param name="labelRasterLayer">Raster passed to the environment together with the feature raster.</param>
 /// <param name="epochs">Number of training iterations; one batch is drawn and trained per iteration.</param>
 /// <param name="model">Not used by this constructor; kept so existing callers keep compiling.</param>
 /// <param name="width">Second entry of the shape array passed to the CNN.</param>
 /// <param name="height">Third entry of the shape array passed to the CNN.</param>
 /// <param name="channel">First entry of the shape array passed to the CNN.</param>
 public JobCNNClassify(GRasterLayer featureRasterLayer, GRasterLayer labelRasterLayer, int epochs, int model, int width, int height, int channel)
 {
     _t = new Thread(() =>
     {
         ImageClassifyEnv env = new ImageClassifyEnv(featureRasterLayer, labelRasterLayer);
         CNN cnn = new CNN(new int[] { channel, width, height }, env.ActionNum);

         // Training phase.
         Summary = "模型训练中";
         for (int i = 0; i < epochs; i++)
         {
             int batchSize = cnn.BatchSize;
             var (states, labels) = env.RandomEval(batchSize);
             double[][] inputX = new double[batchSize][];
             for (int j = 0; j < batchSize; j++)
             {
                 inputX[j] = states[j];
             }
             // The returned loss is not reported anywhere by this job.
             cnn.Train(inputX, labels);
             Process = (double)i / epochs;
         }

         // Classification phase.
         Summary = "分类应用中";
         IRasterLayerCursorTool pRasterLayerCursorTool = new GRasterLayerCursorTool();
         pRasterLayerCursorTool.Visit(featureRasterLayer);

         int xSize       = featureRasterLayer.XSize;
         int ySize       = featureRasterLayer.YSize;
         int totalPixels = xSize * ySize;
         string fullFileName;

         // Bitmap, Graphics and the brush hold GDI+ resources, so all are disposed
         // deterministically instead of being left to the finalizer.
         using (Bitmap classificationBitmap = new Bitmap(xSize, ySize))
         {
             using (Graphics g = Graphics.FromImage(classificationBitmap))
             using (SolidBrush brush = new SolidBrush(Color.Black))
             {
                 int seed = 0;
                 // Apply the trained network to every pixel of the image.
                 for (int i = 0; i < xSize; i++)
                 {
                     for (int j = 0; j < ySize; j++)
                     {
                         // Normalized input values at this pixel.
                         double[] normal = pRasterLayerCursorTool.PickNormalValue(i, j);
                         double[] action = cnn.Predict(normal);
                         // Map the strongest output index back to its raw value.
                         int gray = env.RandomSeedKeys[NP.Argmax(action)];
                         // One brush is recolored per pixel rather than allocating a new one each time.
                         brush.Color = Color.FromArgb(gray, gray, gray);
                         g.FillRectangle(brush, new Rectangle(i, j, 1, 1));
                         // Report progress.
                         Process = (double)(seed++) / totalPixels;
                     }
                 }
             }

             // Save the result under the tmp folder, creating the folder first if it is missing.
             string tmpDirectory = Directory.GetCurrentDirectory() + @"\tmp\";
             Directory.CreateDirectory(tmpDirectory);
             fullFileName = tmpDirectory + DateTime.Now.ToFileTimeUtc() + ".png";
             classificationBitmap.Save(fullFileName);
         }

         // Done.
         Summary  = "CNN训练分类完成";
         Complete = true;
         OnTaskComplete?.Invoke(Name, fullFileName);
     });
 }
コード例 #3
0
 /// <summary>
 /// Creates a DQN classification job: trains a DQN agent on an environment built from the
 /// given rasters, then classifies every pixel of the feature raster into a grayscale PNG.
 /// The work is wrapped in a thread stored in <c>_t</c>; this constructor does not start it.
 /// </summary>
 /// <param name="featureRasterLayer">Raster used to build the environment and whose pixels are classified.</param>
 /// <param name="labelRasterLayer">Raster passed to the environment together with the feature raster.</param>
 /// <param name="epochs">Value forwarded to the DQN as its <c>epochs</c> parameter; defaults to 3000.</param>
 public JobDQNClassify(GRasterLayer featureRasterLayer, GRasterLayer labelRasterLayer, int epochs = 3000)
 {
     _t = new Thread(() =>
     {
         ImageClassifyEnv env = new ImageClassifyEnv(featureRasterLayer, labelRasterLayer);
         _dqn = new DQN(env);
         _dqn.SetParameters(epochs: epochs, gamma: _gamma);
         _dqn.OnLearningLossEventHandler += _dqn_OnLearningLossEventHandler;

         // Training phase.
         Summary = "模型训练中";
         _dqn.Learn();

         // Classification phase.
         Summary = "分类应用中";
         IRasterLayerCursorTool pRasterLayerCursorTool = new GRasterLayerCursorTool();
         pRasterLayerCursorTool.Visit(featureRasterLayer);

         int xSize       = featureRasterLayer.XSize;
         int ySize       = featureRasterLayer.YSize;
         int totalPixels = xSize * ySize;
         string fullFileName;

         // Bitmap, Graphics and the brush hold GDI+ resources, so all are disposed
         // deterministically instead of being left to the finalizer.
         using (Bitmap classificationBitmap = new Bitmap(xSize, ySize))
         {
             using (Graphics g = Graphics.FromImage(classificationBitmap))
             using (SolidBrush brush = new SolidBrush(Color.Black))
             {
                 int seed = 0;
                 for (int i = 0; i < xSize; i++)
                 {
                     for (int j = 0; j < ySize; j++)
                     {
                         // Normalized input values at this pixel.
                         double[] normal = pRasterLayerCursorTool.PickNormalValue(i, j);
                         // Only the action vector is needed; the second tuple element is discarded.
                         var (action, _) = _dqn.ChooseAction(normal);
                         // Map the strongest action index back to its raw value.
                         int gray = _dqn.ActionToRawValue(NP.Argmax(action));
                         // One brush is recolored per pixel rather than allocating a new one each time.
                         brush.Color = Color.FromArgb(gray, gray, gray);
                         g.FillRectangle(brush, new Rectangle(i, j, 1, 1));
                         // Report progress.
                         Process = (double)(seed++) / totalPixels;
                     }
                 }
             }

             // Save the result under the tmp folder, creating the folder first if it is missing.
             string tmpDirectory = Directory.GetCurrentDirectory() + @"\tmp\";
             Directory.CreateDirectory(tmpDirectory);
             fullFileName = tmpDirectory + DateTime.Now.ToFileTimeUtc() + ".png";
             classificationBitmap.Save(fullFileName);
         }

         // Done.
         Summary  = "DQN训练分类完成";
         Complete = true;
         OnTaskComplete?.Invoke(Name, fullFileName);
     });
 }
コード例 #4
0
        /// <summary>
        /// Trains a CNN on random batches drawn from the feature/label rasters and then
        /// predicts the class of a single pixel, asserting that the final training loss is
        /// below 5.0 and that the predicted land-cover value is non-negative.
        /// </summary>
        public void ClassificationByCNN()
        {
            // Number of training iterations.
            const int trainingEpochs = 100;
            // Loss returned by the most recent training step.
            double lastLoss = 1.0;

            var features = new GRasterLayer(featureFullFilename);
            var truth    = new GRasterLayer(trainFullFilename);

            // Environment that supplies random training batches.
            IEnv env = new ImageClassifyEnv(features, truth);

            // Assume the 18-dimensional input can be viewed as a 3x6 single-channel image.
            CNN cnn = new CNN(new int[] { 1, 3, 6 }, env.ActionNum);

            // Training: draw one batch per epoch and feed it to the network.
            for (int epoch = 0; epoch < trainingEpochs; epoch++)
            {
                int batchSize = cnn.BatchSize;
                var (samples, targets) = env.RandomEval(batchSize);

                double[][] batch = new double[batchSize][];
                for (int k = 0; k < batch.Length; k++)
                {
                    batch[k] = samples[k];
                }

                lastLoss = cnn.Train(batch, targets);
            }

            // In general the loss ends up below 5.
            Assert.IsTrue(lastLoss < 5.0);

            // Apply the trained network to one pixel of the feature layer.
            IRasterLayerCursorTool cursor = new GRasterLayerCursorTool();
            cursor.Visit(features);

            // Normalized input values at the probe pixel.
            double[] pixelInput = cursor.PickNormalValue(50, 50);
            double[] prediction = cnn.Predict(pixelInput);
            int bestIndex       = NP.Argmax(prediction);
            int landCoverType   = env.RandomSeedKeys[bestIndex];

            // A caller would typically draw landCoverType into a bitmap at the pixel position.
            // The predicted class is not stable here because training is far too short.
            Assert.IsTrue(landCoverType >= 0);
        }