예제 #1
0
    /// <summary>
    /// Performs one SARSA step: executes the action belonging to the stored
    /// state/action pair, observes the reward and the resulting state, picks the
    /// follow-up Q-value through the strategy, updates the Q-table and finally
    /// shifts the "last" state/action bookkeeping forward.
    /// </summary>
    /// <param name="currentState">State handed to <c>firstRound</c>; the state
    /// reached after acting is read from <c>DetermineState()</c>.</param>
    /// <param name="strategie">Strategy used to look up and select Q-values.</param>
    private void sARSA(State currentState, Strategie strategie)
    {
        // NOTE(review): presumably initialises lastState/action/lastQValue on the
        // very first call - confirm against firstRound's implementation.
        firstRound(currentState, strategie);

        // Q(s, a) for the remembered state/action pair.
        QValue pendingQValue = strategie.getQValue(lastState, action, qValues);

        // Execute a, then observe r and s'.
        ExecuteAction(pendingQValue.getAction());
        reward = getReward();
        State observedState = DetermineState();

        // Q(s', a') chosen according to the strategy.
        QValue followUpQValue = strategie.getQValueForStrategie(observedState, qValues);

        // Update the Q-table with (s, a, s', a', r).
        qValues = SARSA.updateTable(
            lastState,
            lastQValue.getAction(),
            observedState,
            followUpQValue.getAction(),
            reward,
            qValues);

        // s <- s', a <- a'
        lastState  = observedState;
        lastQValue = followUpQValue;
        action     = followUpQValue.getAction();
    }
예제 #2
0
        /// <summary>
        /// Trains a SARSA agent (Pokémon 1) over <c>NumberOfBattles</c> battles against
        /// an opponent that picks one of its moves at random, and appends one
        /// "battleNumber, percentFinished" line per battle to a results file.
        /// </summary>
        /// <param name="args">Command-line arguments; not used.</param>
        private static void Main(string[] args)
        {
            /**Region for setting up SARSA function (and possibly parameters)**/

            #region SARSA Setup

            //Set up SARSA object
            var explorationPolicy = new EpsilonGreedyExploration(ExplorationRate);
            var numberOfStates    = 15 * 15 * 15 * 15;
            var numberOfActions   = Enum.GetValues(typeof(Type)).Length;
            var sarsa             = new SARSA(numberOfStates, numberOfActions, explorationPolicy);

            // State mapping: a four-digit base-15 number built from both Pokémon's
            // first and second type. A Pokémon with a single type repeats its first type.
            // NOTE(review): 15 is presumably the number of Type values - confirm.
            Func <Pokémon, Pokémon, long> getState = (pokémon1, pokémon2) =>
            {
                return
                    (15 * 15 * 15 * (long)pokémon1.Types[0] +
                     15 * 15 * (long)(pokémon1.Types.Count > 1 ? pokémon1.Types[1] : pokémon1.Types[0]) +
                     15 * (long)pokémon2.Types[0] +
                     1 * (long)(pokémon2.Types.Count > 1 ? pokémon2.Types[1] : pokémon2.Types[0]));
            };

            #endregion SARSA Setup

            // Use one path for both truncating and appending so that every run starts
            // with an empty results file instead of accumulating older runs.
            var resultsPath = $"PineappleExpress({ExplorationRate}_{LearningRate}).txt";
            using (var sw = new StreamWriter(resultsPath))
            {
                sw.Write("");
            }

            // One generator for the whole run: creating a new Random per move can
            // yield identical values when instances are seeded from the clock.
            var opponentRandom = new Random();

            /**Region for setting up the battle itself**/

            #region Battle Execution

            //For the specified number of battles, perform battles and update the policy
            for (var battleNumber = 0; battleNumber < NumberOfBattles; battleNumber++)
            {
                // Linearly decay the exploration rate towards zero over the run
                explorationPolicy.ExplorationRate =
                    ExplorationRate - (double)battleNumber / NumberOfBattles * ExplorationRate;

                // Linearly decay the learning rate towards zero over the run
                sarsa.LearningRate = LearningRate - (double)battleNumber / NumberOfBattles * LearningRate;

                //Prepare the Pokémon
                Pokémon pokemon1 = RentalPokémon.RentalPorygon;  //A pre-made Porygon
                Pokémon pokemon2 = RentalPokémon.RentalVenusaur; //A pre-made opponent

                long previousState  = -1;
                var  previousAction = -1;
                long currentState   = -1;
                var  nextAction     = -1;

                var reward    = 0.0;
                var firstTurn = true;

                double percentFinished = 0;

                //Battle loop
                while (!(pokemon1.IsFainted || pokemon2.IsFainted))
                {
                    //Shift states
                    currentState = getState(pokemon1, pokemon2);
                    var validTypes = pokemon1.Moves.Select(m => (int)m.AttackType).Distinct().ToList();
                    nextAction = sarsa.GetAction(currentState, validTypes);

                    // Update SARSA with the previous turn's transition; there is
                    // nothing to update yet on the first turn.
                    if (!firstTurn)
                    {
                        sarsa.UpdateState(previousState, previousAction, reward, currentState, nextAction);
                    }
                    else
                    {
                        firstTurn = false;
                    }

                    //Determine who moves first (the opponent wins speed ties)
                    var firstMover = pokemon1.Stats[Stat.Speed] > pokemon2.Stats[Stat.Speed] ? pokemon1 : pokemon2;

                    //Perform actions
                    if (pokemon1 == firstMover)
                    {
                        reward = UseAgentMove(pokemon1, pokemon2, nextAction);
                        if (!pokemon2.IsFainted)
                        {
                            pokemon2.Use(opponentRandom.Next(4), pokemon1);
                        }
                        else
                        {
                            // Bonus for knocking the opponent out
                            reward += 20;
                        }
                    }
                    else
                    {
                        pokemon2.Use(opponentRandom.Next(4), pokemon1);

                        if (!pokemon1.IsFainted)
                        {
                            reward = UseAgentMove(pokemon1, pokemon2, nextAction);
                        }
                        else
                        {
                            // The agent never got to act this turn, so do not carry
                            // the previous turn's reward into the final update.
                            reward = 0.0;
                        }
                    }

                    previousState   = currentState;
                    previousAction  = nextAction;
                    percentFinished = ((double)pokemon2.Stats[Stat.HP] - pokemon2.RemainingHealth) /
                                      pokemon2.Stats[Stat.HP];
                    Console.WriteLine($"{reward}");
                }

                // Final update for the last turn. Skipped when no turn was played,
                // because the state/action values would still be the -1 placeholders.
                if (!firstTurn)
                {
                    sarsa.UpdateState(previousState, previousAction, reward, currentState, nextAction);
                }

                if (pokemon1.IsFainted)
                {
                    Console.WriteLine("{0} (Pokémon 1) Fainted", pokemon1.Species.Name);
                }
                else
                {
                    Console.WriteLine("{0} (Pokémon 2) Fainted", pokemon2.Species.Name);
                }

                //Print score for graphing
                using (var sw = new StreamWriter(resultsPath, true))
                {
                    sw.WriteLine("{0}, {1}", battleNumber, percentFinished);
                }
            }

            #endregion Battle Execution
        }

        /// <summary>
        /// Lets the agent's Pokémon use a move of the given type on the defender and
        /// logs the move, the damage dealt and the type multiplier to the console.
        /// </summary>
        /// <param name="attacker">The agent's Pokémon (Pokémon 1).</param>
        /// <param name="defender">The opposing Pokémon (Pokémon 2).</param>
        /// <param name="action">The chosen action, cast to a move <c>Type</c>.</param>
        /// <returns>The value returned by <c>UseMoveOfType</c>, used as the reward.</returns>
        private static double UseAgentMove(Pokémon attacker, Pokémon defender, int action)
        {
            double damage = attacker.UseMoveOfType((Type)action, defender);
            Console.WriteLine("{0} (Pokémon 1) used a move of type {1}", attacker.Species.Name,
                              Enum.GetName(typeof(Type), (Type)action));
            Console.WriteLine("Did {0} damage. {1} (Pokémon 2) now has {2} health remaining)",
                              damage, defender.Species.Name, defender.RemainingHealth);
            Console.WriteLine(((Type)action).MultiplierOn(defender.Types.ToArray()));
            return damage;
        }
예제 #3
0
    /// <summary>
    /// Per-frame callback. When the selected game, algorithm or Sokoban level has
    /// changed, rebuilds the game instance for the chosen algorithm, clears the
    /// tiles, times <c>game.Start()</c> into <c>elapsedMs</c> and rewires the play
    /// button. Afterwards advances the current game, if there is one.
    /// </summary>
    void Update()
    {
        bool selectionChanged = selectedGame != oldGame
                                || selectedAlgo != oldAlgo
                                || selectedSokobanLevel != oldSokobanLevel;

        if (selectionChanged)
        {
            // Detach the previous game's click handler. On the first pass there is
            // no game yet, so check for that instead of catching the resulting exception.
            if (game != null)
            {
                btnComponent.onClick.RemoveListener(game.TaskOnClick);
            }
            else
            {
                Debug.Log("No listener attach to the play button.");
            }

            // Remember the current selection
            oldGame         = selectedGame;
            oldAlgo         = selectedAlgo;
            oldSokobanLevel = selectedSokobanLevel;

            // Select the RL algorithm (only its runtime type is used below);
            // anything not listed falls back to Q-learning.
            Base algoType;
            switch (selectedAlgo)
            {
                case Algo.MarkovPolicy:
                    algoType = new MarkovPolicy();
                    break;
                case Algo.MarkovValue:
                    algoType = new MarkovValue();
                    break;
                case Algo.MonteCarlo:
                    algoType = new MonteCarlo();
                    break;
                case Algo.SARSA:
                    algoType = new SARSA();
                    break;
                default:
                    algoType = new QLearning();
                    break;
            }

            // Select the open generic game type; anything not listed falls back to Sokoban.
            Type gameType;
            switch (selectedGame)
            {
                case Game.GridWorld:
                    gameType = typeof(GridWorld.GridWorld <>);
                    break;
                case Game.TicTacToe:
                    gameType = typeof(TicTacToe.TicTacToe <>);
                    break;
                default:
                    gameType = typeof(Sokoban.Sokoban <>);
                    break;
            }

            // Close the generic game type over the algorithm type and instantiate it
            var type = gameType.MakeGenericType(algoType.GetType());
            game = (IGame)Activator.CreateInstance(type);

            // Clear the screen by making every tile fully transparent
            GameObject[] gos = GameObject.FindGameObjectsWithTag("Tile");
            foreach (GameObject go in gos)
            {
                var tile = go.GetComponent <SpriteRenderer>();
                tile.color = new Color(0, 0, 0, 0);
            }
            // Hide the player by moving it to (50, 50, 0)
            goPlayer.transform.position = new Vector3(50, 50, 0);

            // Start the game and record how long Start() took
            var watch = System.Diagnostics.Stopwatch.StartNew();
            game.Start();
            watch.Stop();
            elapsedMs = watch.ElapsedMilliseconds;

            // Let the play button step the new game
            btnComponent.onClick.AddListener(game.TaskOnClick);
        }

        // No game exists until a selection change has been processed once.
        if (game != null)
        {
            game.Update();
        }
    }