Example #1
        static void Main(string[] args)
        {
            Env env = new CartPoleEnv();
            //https://github.com/rlcode/reinforcement-learning/blob/master/2-cartpole/1-dqn/cartpole_dqn.py

            var valueFunc = new DQN(env.ObservationSpace.Shape, env.ActionSpace.NumberOfValues(), new[] { 24, 24 }, 0.001f, 0.99f, 32, new ExperienceReplay(2000))
            {
                TargetModelUpdateInterval = 1
            };

            Agent agent = new AgentQL("dqn_cartpole", env, valueFunc)
            {
                Verbose          = true,
                RewardOnDone     = -100,
                EpsilonDecayMode = EEpsilonDecayMode.EveryStep,
                EpsilonDecay     = 0.999f,
                WarmupSteps      = 1000
            };

            agent.Train(300, 500);
            Console.WriteLine($"Average reward {agent.Test(50, 300, 0)}");

            //while (!env.Step((int)env.ActionSpace.Sample()[0], out var nextState, out var reward))
            //{
            //    env.Render();
            //    Thread.Sleep(50);
            //}

            return;
        }
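The commented-out loop at the end of Example #1 hints at driving the environment directly with a random policy instead of a trained agent. A minimal standalone sketch of that usage, assuming the same Env/CartPoleEnv API used above (Step, Render, ActionSpace.Sample) and that the environment can be stepped right after construction (the actual API may require an explicit reset first):

    // Sketch only; requires using System.Threading for Thread.Sleep.
    Env env = new CartPoleEnv();
    // Step returns true when the episode is done; take random actions until then.
    while (!env.Step((int)env.ActionSpace.Sample()[0], out var nextState, out var reward))
    {
        env.Render();         // draw the current frame
        Thread.Sleep(50);     // slow the loop down so the render window stays watchable
    }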
Example #2
        static void Main(string[] args)
        {
            Env env = new LunarLanderEnv();

            var memory = new PriorityExperienceReplay(100000);

            var qFunc = new DQN(env.ObservationSpace.Shape, env.ActionSpace.NumberOfValues(), new[] { 256, 128 }, 0.0001f, 0.999f, 32, memory)
            {
                MemoryInterval            = 1,
                EnableDoubleDQN           = true,
                TargetModelUpdateInterval = 4000,
                TrainingEpochs            = 1
            };

            Agent agent = new AgentQL("dqn_lunarlander", env, qFunc)
            {
                WarmupSteps         = 5000,
                MaxEpsilon          = 1.0f,
                MinEpsilon          = 0.01f,
                EpsilonDecay        = 0.995f,
                TrainInterval       = 1,
                RewardClipping      = false,
                TrainRenderInterval = 10,
                Verbose             = true,
                RenderFreq          = 80,
            };

            agent.Train(1500, 1500);
            //agent.Load($"{agent.Name}_1500");
            agent.Test(100, 400, 2);
        }
Example #3
 public DQN(DQN other) : this(rysyPINVOKE.new_DQN__SWIG_5(DQN.getCPtr(other)), true)
 {
     if (rysyPINVOKE.SWIGPendingException.Pending)
     {
         throw rysyPINVOKE.SWIGPendingException.Retrieve();
     }
 }
Example #4
        public void ClassificationByDQN()
        {
            double _loss = 1.0;
            //
            GRasterLayer featureLayer = new GRasterLayer(featureFullFilename);
            GRasterLayer labelLayer   = new GRasterLayer(trainFullFilename);
            //create environment for agent exploring
            IEnv env = new ImageClassifyEnv(featureLayer, labelLayer);
            //create DQN algorithm
            DQN dqn = new DQN(env);

            //to keep this example fast, the number of training epochs is set to 10;
            //do not use so few epochs in real use.
            dqn.SetParameters(10, 0);
            //register event to get information while training
            dqn.OnLearningLossEventHandler += (double loss, double totalReward, double accuracy, double progress, string epochesTime) => { _loss = loss; };
            //start DQN algorithm learning
            dqn.Learn();
            //in general, loss is less than 1
            Assert.IsTrue(_loss < 1.0);
            //apply the DQN to classify featureLayer
            //pick a value
            IRasterLayerCursorTool pRasterLayerCursorTool = new GRasterLayerCursorTool();

            pRasterLayerCursorTool.Visit(featureLayer);
            //
            double[] state         = pRasterLayerCursorTool.PickNormalValue(50, 50);
            double[] action        = dqn.ChooseAction(state).action;
            int      landCoverType = dqn.ActionToRawValue(NP.Argmax(action));

            //use landCoverType as needed, e.g. draw it into a bitmap at position (i, j)
            //the classification results are not stable because the number of training epochs is so small.
            Assert.IsTrue(landCoverType >= 0);
        }
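The comment above leaves the drawing step open. A minimal sketch of writing the predicted class to a bitmap with System.Drawing; the grayscale mapping and output file name are illustrative assumptions here, and the JobDQNClassify example further down does the same thing for the whole raster:

    // Sketch only; requires using System.Drawing (and System for Math).
    using (var bitmap = new Bitmap(featureLayer.XSize, featureLayer.YSize))
    {
        int gray = Math.Max(0, Math.Min(255, landCoverType));       // clamp the class value into byte range
        bitmap.SetPixel(50, 50, Color.FromArgb(gray, gray, gray));  // same (50, 50) position picked above
        bitmap.Save("landcover_preview.png");                       // illustrative output path
    }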
Example #5
 private void BLearning()
 {
     while (true)
     {
         if (Historical.Count < 20000)
         {
             //not enough history collected yet; wait and check again
             Thread.Sleep(TimeSpan.FromMinutes(30));
             continue;
         }
         var correct = 0.0;
         var total   = 0.0;
         var options = new AgentOptions
         {
             Gamma                 = Tembo.Random(0.01, 0.99),
             Epsilon               = Tembo.Random(0.01, 0.75),
             Alpha                 = Tembo.Random(0.01, 0.99),
             ExperinceAddEvery     = Tembo.RandomInt(1, 10000),
             ExperienceSize        = 0,
             LearningSteps         = Tembo.RandomInt(1, 10),
             HiddenUnits           = Tembo.RandomInt(100000, 100000000),
             ErrorClamp            = Tembo.Random(0.01, 1.0),
             AdaptiveLearningSteps = true
         };
         var agent = new DQN(dqnAgent.NumberOfStates, dqnAgent.NumberOfActions, options);
         for (var i = 0; i < Historical.Count; i++)
         {
             var spi    = Historical.ElementAt(i);
             var action = agent.Act(spi.Value.Values);
             if (action == spi.Value.Output)
             {
                 correct += 1;
                 agent.Learn(1);
             }
             else
             {
                 agent.Learn(-1);
             }
             total += 1;
         }
         var winrate = (correct / total) * 100;
         if (winrate > WinRate)
         {
             CN.Log($"NEW AGENT DISCOVERED --> WINRATE {winrate.ToString("p")}, CLASS: {AgentName}", 2);
             Save();
             dqnAgent = agent;
             WinRate  = winrate;
         }
     }
 }
Example #6
 public TemboDQN(string agentName, int ns, int na, AgentOptions options, bool backgroundLearning = false)
 {
     AgentName          = agentName;
     BackgroundLearning = backgroundLearning;
     dqnAgent           = new DQN(ns, na, options);
     Candlesticks       = new List <Candlestick>();
     Historical         = new Dictionary <string, State>();
     ForeverLearner     = new Thread(BLearning)
     {
         IsBackground = true,
         Priority     = ThreadPriority.BelowNormal //just in case this agent collective is plugged into a bigger collective running higher-priority threads
     };
     if (BackgroundLearning)
     {
         ForeverLearner.Start();
     }
 }
Example #7
 /// <summary>
 /// DQN classification task
 /// </summary>
 /// <param name="featureRasterLayer">raster layer providing the input features</param>
 /// <param name="labelRasterLayer">raster layer providing the training labels</param>
 /// <param name="epochs">number of training epochs</param>
 public JobDQNClassify(GRasterLayer featureRasterLayer, GRasterLayer labelRasterLayer, int epochs = 3000)
 {
     _t = new Thread(() =>
     {
         ImageClassifyEnv env = new ImageClassifyEnv(featureRasterLayer, labelRasterLayer);
         _dqn = new DQN(env);
         _dqn.SetParameters(epochs: epochs, gamma: _gamma);
         _dqn.OnLearningLossEventHandler += _dqn_OnLearningLossEventHandler;
         //training
         Summary = "模型训练中";
         _dqn.Learn();
         //classification
         Summary = "分类应用中";
         IRasterLayerCursorTool pRasterLayerCursorTool = new GRasterLayerCursorTool();
         pRasterLayerCursorTool.Visit(featureRasterLayer);
         Bitmap classificationBitmap = new Bitmap(featureRasterLayer.XSize, featureRasterLayer.YSize);
         Graphics g      = Graphics.FromImage(classificationBitmap);
         int seed        = 0;
         int totalPixels = featureRasterLayer.XSize * featureRasterLayer.YSize;
         for (int i = 0; i < featureRasterLayer.XSize; i++)
         {
             for (int j = 0; j < featureRasterLayer.YSize; j++)
             {
                 //get normalized input raw value
                 double[] normal = pRasterLayerCursorTool.PickNormalValue(i, j);
                 var (action, q) = _dqn.ChooseAction(normal);
                 //convert action to raw byte value
                 int gray         = _dqn.ActionToRawValue(NP.Argmax(action));
                 Color c = Color.FromArgb(gray, gray, gray);
                 //dispose the brush each iteration to avoid leaking GDI handles
                 using (var brush = new SolidBrush(c))
                 {
                     g.FillRectangle(brush, new Rectangle(i, j, 1, 1));
                 }
                 //report progress
                 Process = (double)(seed++) / totalPixels;
             }
         }
         //save result
         string fullFileName = Directory.GetCurrentDirectory() + @"\tmp\" + DateTime.Now.ToFileTimeUtc() + ".png";
         classificationBitmap.Save(fullFileName);
         //complete
         Summary  = "DQN训练分类完成";
         Complete = true;
         OnTaskComplete?.Invoke(Name, fullFileName);
     });
 }
Example #8
 internal static global::System.Runtime.InteropServices.HandleRef getCPtr(DQN obj)
 {
     return((obj == null) ? new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero) : obj.swigCPtr);
 }