Example #1
        /// <summary>
        /// Creates the learner
        /// </summary>
        protected override ILearningAlgorithm<BoardState> CreateLearner()
        {
            double alpha       = 0.05;
            double gamma       = 0.1;
            int    stopDecayAt = (int)(0.4 * this.Environment.Config.MaxEpisodes);

            double epsilon         = 0.4;
            var    selectionPolicy = new EGreedy(
                epsilon,
                this.Environment.Config.Random,
                DecayHelpers.ConstantDecay(0, stopDecayAt, epsilon, 0));

            //double tau = 200;
            //var selectionPolicy = new Softmax(
            //	tau,
            //	this.Environment.Config.Random,
            //	DecayHelpers.ConstantDecay(0, stopDecayAt, tau, 0));

            //return QLearning<BoardState>.New(
            //	this.boardSize * this.boardSize,
            //	selectionPolicy,
            //	alpha,
            //	gamma,
            //	this.Environment.Config.Random);
            return Sarsa<BoardState>.New(
                this.boardSize * this.boardSize,
                selectionPolicy,
                alpha,
                gamma,
                this.Environment.Config.Random);
        }
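DecayHelpers.ConstantDecay above belongs to the surrounding project and is not shown in this listing. Judging from its arguments, it produces a schedule that lowers epsilon from 0.4 down to 0 between episode 0 and stopDecayAt and then holds it. A minimal sketch of such a helper, assuming it returns a System.Func<int, double> mapping an episode index to the current value, might look like this:

        // Hypothetical sketch only; the real DecayHelpers in the project may use a
        // different signature or decay shape.
        public static class DecayHelpersSketch
        {
            // Decays linearly from 'fromValue' at 'startEpisode' to 'toValue' at
            // 'stopEpisode', then stays at 'toValue'.
            public static Func<int, double> ConstantDecay(
                int startEpisode, int stopEpisode, double fromValue, double toValue)
            {
                return episode =>
                {
                    if (episode <= startEpisode) return fromValue;
                    if (episode >= stopEpisode)  return toValue;

                    double t = (episode - startEpisode) / (double)(stopEpisode - startEpisode);
                    return fromValue + t * (toValue - fromValue);
                };
            }
        }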
Example #2
        // On "Start" learning button click
        private void startLearningButton_Click(object sender, EventArgs e)
        {
            // get settings
            GetSettings();
            ShowSettings();

            iterationBox.Text = string.Empty;

            // destroy algorithms
            qLearning = null;
            sarsa     = null;

            if (algorithmCombo.SelectedIndex == 0)
            {
                // create new QLearning algorithm's instance
                qLearning    = new QLearning(256, 4, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate)));
                workerThread = new Thread(new ThreadStart(QLearningThread));
            }
            else
            {
                // create new Sarsa algorithm's instance
                sarsa        = new Sarsa(256, 4, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate)));
                workerThread = new Thread(new ThreadStart(SarsaThread));
            }

            // disable all settings controls except "Stop" button
            EnableControls(false);

            // run worker thread
            needToStop = false;
            workerThread.Start();
        }
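QLearningThread and SarsaThread are not part of this listing. A minimal sketch of what the SarsaThread body could look like, assuming helpers such as GetState, DoAction, and IsGoal plus learningIterations and the agent start coordinates as members of the same form (none of which appear above):

        // Sketch only; GetState, DoAction, IsGoal, learningIterations and the
        // agentStart fields are assumed members of the surrounding form.
        private void SarsaThread()
        {
            for (int episode = 0; (episode < learningIterations) && !needToStop; episode++)
            {
                int x = agentStartX, y = agentStartY;

                int    state  = GetState(x, y);
                int    action = sarsa.GetAction(state);
                double reward = DoAction(ref x, ref y, action);

                while (!needToStop && !IsGoal(x, y))
                {
                    int nextState  = GetState(x, y);
                    int nextAction = sarsa.GetAction(nextState);

                    // on-policy update: uses the action that will actually be taken next
                    sarsa.UpdateState(state, action, reward, nextState, nextAction);

                    reward = DoAction(ref x, ref y, nextAction);
                    state  = nextState;
                    action = nextAction;
                }

                if (!needToStop)
                {
                    // terminal update once the goal is reached
                    sarsa.UpdateState(state, action, reward);
                }
            }
        }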
Example #3
    private void StartLearning_OnClick()
    {
        // reset learning class values
        _qLearning      = null;
        _sarsa          = null;
        _qLearning_FDGS = null;

        if (References.LearningAlgorithm.value == 0)
        {
            // create new QLearning algorithm's instance
            _qLearning    = new QLearning(256, 4, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate)));
            _workerThread = new Thread(new ThreadStart(QLearningThread));
        }
        else if (References.LearningAlgorithm.value == 1)
        {
            // create new Sarsa algorithm's instance
            _sarsa        = new Sarsa(256, 4, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate)));
            _workerThread = new Thread(new ThreadStart(SarsaThread));
        }
        else
        {
            // init QLearn
            _qLearning_FDGS = new QLearning_FDGS(actions, _agentStopX, _agentStopY, _map, new TabuSearchExploration(actions, new EpsilonGreedyExploration(Convert.ToDouble(explorationRate))));
            _workerThread   = new Thread(new ThreadStart(QLearningThread_FDGS));
        }

        // disable all settings controls except "Stop" button
        References.EnableControls(false);

        // run worker thread
        _needToStop = false;
        _workerThread.Start();

        Debug.Log("Learning started. Please wait until training is finished.");
    }
Example #4
        // On "Stop" button click
        private void stopButton_Click(object sender, EventArgs e)
        {
            if (workerThread != null)
            {
                // signal the worker thread to stop and wait for it to finish;
                // pumping the message loop with DoEvents() while joining keeps the UI
                // responsive and avoids a deadlock if the worker is blocked in a
                // Control.Invoke call while posting its results back to the form
                needToStop = true;
                while (!workerThread.Join(100))
                {
                    Application.DoEvents();
                }
                workerThread = null;
            }

            // reset learning class values
            qLearning = null;
            sarsa     = null;
        }
Example #5
        public static void Main(string[] args)
        {
            System.Console.OutputEncoding = System.Text.Encoding.Unicode;

            System.Console.WriteLine("Choose algorithm. Monte-Carlo or TD?");
            Policy policy;
            string read = System.Console.ReadLine().ToLower();

            if (read.StartsWith("m"))
            {
                Program.FileName = "MC" + Program.FileName;
                policy           = new MonteCarlo();
            }
            else if (read.StartsWith("t"))
            {
                Program.FileName = "TD" + Program.FileName;
                policy           = new Sarsa();
            }
            else if (read.StartsWith("b"))
            {
                Program.FileName = "BS" + Program.FileName;
                policy           = new BackwardSarsa();
            }
            else
            {
                return;
            }
            bool shouldPrint = read.EndsWith("p");

            Policy.Initialize();

            System.Console.WriteLine("Train or Play?");
            read = System.Console.ReadLine();

            if (read.ToLower() == "t")
            {
                Mode = Mode.Train;
                System.DateTime then = System.DateTime.Now;
                int             lastTimeActionUpdated = 0;
                for (int i = 0; i < 100000000; i++)
                {
                    Episode e = new Episode(policy);
                    if (e.Play())
                    {
                        lastTimeActionUpdated = i;
                    }
                    else if (i - lastTimeActionUpdated > Threshold)
                    {
                        System.Console.WriteLine("No need for more episodes, policy is optimal. Number of episode when action was updated last time: "
                                                 + lastTimeActionUpdated);
                        break;
                    }
                }
                System.Console.WriteLine((System.DateTime.Now - then).TotalSeconds);
            }
            else if (read.ToLower() == "p")
            {
                Mode = Mode.Play;
                while (read.ToLower() == "p")
                {
                    Episode e = new Episode(policy);
                    e.Play();
                    e.Print();
                    System.Console.WriteLine("Train or Play?");
                    read = System.Console.ReadLine();
                }
            }

            if (shouldPrint)
            {
                policy.Print();
            }
            Policy.FlushToDisk();
            System.Console.WriteLine("Total reward: " + reward);
        }
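Policy, Episode, Mode, FileName, Threshold, and reward are defined elsewhere in that project. The rough shapes the program appears to rely on, inferred only from the calls above (a sketch, not the actual definitions):

        // Inferred sketch of the dependencies; the real classes will differ.
        public enum Mode { Train, Play }

        public abstract class Policy
        {
            public static void Initialize()  { /* prepare shared value tables */ }
            public static void FlushToDisk() { /* persist learned values */ }
            public abstract void Print();
        }

        public class MonteCarlo    : Policy { public override void Print() { } }
        public class Sarsa         : Policy { public override void Print() { } }
        public class BackwardSarsa : Policy { public override void Print() { } }

        public class Episode
        {
            public Episode(Policy policy) { }

            // returns true when the episode caused any action value to change
            public bool Play() { return false; }

            public void Print() { }
        }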
Example #6
        public void learn_test()
        {
            #region doc_main
            // Fix the random number generator
            Accord.Math.Random.Generator.Seed = 0;

            // In this example, we will be using the Sarsa algorithm
            // to make a robot learn how to navigate a map. The map
            // is shown below, where a 1 denotes a wall and 0 denotes
            // areas where the robot can navigate:
            //
            int[,] map =
            {
                { 1, 1, 1, 1, 1, 1, 1, 1, 1 },
                { 1, 1, 0, 0, 0, 0, 0, 0, 1 },
                { 1, 1, 0, 0, 0, 1, 1, 0, 1 },
                { 1, 0, 0, 1, 0, 0, 0, 0, 1 },
                { 1, 0, 0, 1, 1, 1, 1, 0, 1 },
                { 1, 0, 0, 1, 1, 0, 0, 0, 1 },
                { 1, 1, 0, 1, 0, 0, 0, 0, 1 },
                { 1, 1, 0, 1, 0, 1, 1, 0, 1 },
                { 1, 1, 1, 1, 1, 1, 1, 1, 1 },
            };

            // Now, we define the initial and target points from which the
            // robot will be spawned and where it should go, respectively:
            int agentStartX = 1;
            int agentStartY = 4;

            int agentStopX = 7;
            int agentStopY = 4;

            // The robot is able to sense the environment through 8 sensors
            // that capture whether the robot is near a wall or not. Based
            // on the robot's current location, the sensors will return an
            // integer number representing which sensors have detected walls:

            Func<int, int, int> getState = (int x, int y) =>
            {
                int c1 = (map[y - 1, x - 1] != 0) ? 1 : 0;
                int c2 = (map[y - 1, x + 0] != 0) ? 1 : 0;
                int c3 = (map[y - 1, x + 1] != 0) ? 1 : 0;
                int c4 = (map[y + 0, x + 1] != 0) ? 1 : 0;
                int c5 = (map[y + 1, x + 1] != 0) ? 1 : 0;
                int c6 = (map[y + 1, x + 0] != 0) ? 1 : 0;
                int c7 = (map[y + 1, x - 1] != 0) ? 1 : 0;
                int c8 = (map[y + 0, x - 1] != 0) ? 1 : 0;

                return c1 | (c2 << 1) | (c3 << 2) | (c4 << 3) | (c5 << 4) | (c6 << 5) | (c7 << 6) | (c8 << 7);
            };

            // The actions are the possible directions the robot can go:
            //
            //   - case 0: go to north (up)
            //   - case 1: go to east (right)
            //   - case 2: go to south (down)
            //   - case 3: go to west (left)
            //

            int    learningIterations = 1000;
            double explorationRate    = 0.5;
            double learningRate       = 0.5;

            double moveReward = 0;
            double wallReward = -1;
            double goalReward = 1;

            // The function below specifies how the robot should perform an action given its
            // current position and an action number. This will cause the robot to update its
            // current X and Y locations given the direction (above) it was instructed to go:
            Func<int, int, int, Tuple<double, int, int>> doAction = (int currentX, int currentY, int action) =>
            {
                // default reward is equal to moving reward
                double reward = moveReward;

                // moving direction
                int dx = 0, dy = 0;

                switch (action)
                {
                case 0:             // go to north (up)
                    dy = -1;
                    break;

                case 1:             // go to east (right)
                    dx = 1;
                    break;

                case 2:             // go to south (down)
                    dy = 1;
                    break;

                case 3:             // go to west (left)
                    dx = -1;
                    break;
                }

                int newX = currentX + dx;
                int newY = currentY + dy;

                // check the agent's new coordinates: bounds first, then the map lookup,
                // so the array is never indexed outside of the world
                if ((newX < 0) || (newX >= map.Columns()) || (newY < 0) || (newY >= map.Rows()) || (map[newY, newX] != 0))
                {
                    // we hit a wall or tried to step outside of the world
                    reward = wallReward;
                }
                else
                {
                    currentX = newX;
                    currentY = newY;

                    // check if we found the goal
                    if ((currentX == agentStopX) && (currentY == agentStopY))
                    {
                        reward = goalReward;
                    }
                }

                return Tuple.Create(reward, currentX, currentY);
            };


            // After defining all those functions, we create a new Sarsa algorithm:
            var explorationPolicy = new EpsilonGreedyExploration(explorationRate);
            var tabuPolicy        = new TabuSearchExploration(4, explorationPolicy);
            var sarsa             = new Sarsa(256, 4, tabuPolicy);

            // current coordinates of the agent
            int agentCurrentX = -1;
            int agentCurrentY = -1;

            bool needToStop = false;
            int  iteration  = 0;

            // loop
            while ((!needToStop) && (iteration < learningIterations))
            {
                // set exploration rate for this iteration
                explorationPolicy.Epsilon = explorationRate - ((double)iteration / learningIterations) * explorationRate;

                // set learning rate for this iteration
                sarsa.LearningRate = learningRate - ((double)iteration / learningIterations) * learningRate;

                // clear tabu list
                tabuPolicy.ResetTabuList();

                // reset agent's coordinates to the starting position
                agentCurrentX = agentStartX;
                agentCurrentY = agentStartY;

                // steps performed by agent to get to the goal
                int steps = 1;

                // previous state and action
                int previousState  = getState(agentCurrentX, agentCurrentY);
                int previousAction = sarsa.GetAction(previousState);

                // update agent's current position and get his reward
                var    r      = doAction(agentCurrentX, agentCurrentY, previousAction);
                double reward = r.Item1;
                agentCurrentX = r.Item2;
                agentCurrentY = r.Item3;

                while ((!needToStop) && ((agentCurrentX != agentStopX) || (agentCurrentY != agentStopY)))
                {
                    steps++;

                    // set tabu action
                    tabuPolicy.SetTabuAction((previousAction + 2) % 4, 1);

                    // get agent's next state
                    int nextState = getState(agentCurrentX, agentCurrentY);

                    // get agent's next action
                    int nextAction = sarsa.GetAction(nextState);

                    // do learning of the agent - update his Q-function
                    sarsa.UpdateState(previousState, previousAction, reward, nextState, nextAction);

                    // update agent's new position and get his reward
                    r             = doAction(agentCurrentX, agentCurrentY, nextAction);
                    reward        = r.Item1;
                    agentCurrentX = r.Item2;
                    agentCurrentY = r.Item3;

                    previousState  = nextState;
                    previousAction = nextAction;
                }

                if (!needToStop)
                {
                    // update Q-function if terminal state was reached
                    sarsa.UpdateState(previousState, previousAction, reward);
                }

                iteration++;
            }

            // The end position for the robot will be (7, 4):
            int finalPosX = agentCurrentX; // 7
            int finalPosY = agentCurrentY; // 4
            #endregion

            Assert.AreEqual(7, finalPosX);
            Assert.AreEqual(4, finalPosY);
        }
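As a closing note, the same walkthrough maps onto the off-policy QLearning class used in Examples #2 and #3. A hedged sketch of the differences (the exact overloads should be verified against the installed Accord/AForge version):

            // Q-learning variant of the loop above (sketch). Q-learning is off-policy:
            // its update bootstraps from the greedy value of the next state, so it
            // takes the next state but no next action.
            var qlearning = new QLearning(256, 4, tabuPolicy);

            // ... inside the episode loop, in place of the Sarsa update:
            // qlearning.UpdateState(previousState, previousAction, reward, nextState);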