Example #1
0
        /// <summary>
        /// Builds a world in which the wall, pit, goal and player are each placed on a
        /// random square, re-drawing all four until no two share a square.
        /// </summary>
        /// <returns>A new <see cref="GridWorld"/> with every object placed at random.</returns>
        public static GridWorld RandomState()
        {
            var world = new GridWorld
            {
                WorldState = new int[GridSize, GridSize, GridDepth]
            };

            Location wall, pit, goal, player;

            // Re-roll every position (always in the same order) until all four are distinct.
            while (true)
            {
                wall   = Util.RandLoc(0, GridWorld.GridSize);
                pit    = Util.RandLoc(0, GridWorld.GridSize);
                goal   = Util.RandLoc(0, GridWorld.GridSize);
                player = Util.RandLoc(0, GridWorld.GridSize);

                var anyClash = wall == pit || wall == goal || wall == player
                            || pit == goal || pit == player
                            || goal == player;

                if (!anyClash)
                {
                    break;
                }
            }

            world.WorldState[wall.X, wall.Y, WallLayer] = 1;
            world.WorldState[pit.X, pit.Y, PitLayer]    = 1;
            world.WorldState[goal.X, goal.Y, GoalLayer] = 1;

            world.PlayerLocation = player;
            world.WorldState[player.X, player.Y, PlayerLayer] = 1;

            return world;
        }
Example #2
0
        /// <summary>
        /// Builds the fixed starting layout: player at (0,1), wall at (2,2),
        /// pit at (1,1) and goal at (3,3).
        /// </summary>
        /// <returns>A new <see cref="GridWorld"/> in the standard configuration.</returns>
        public static GridWorld StandardState()
        {
            var world = new GridWorld
            {
                WorldState     = new int[GridSize, GridSize, GridDepth],
                PlayerLocation = new Location(0, 1)
            };

            // One marker per object, each on its own layer of the state tensor.
            var grid = world.WorldState;
            grid[0, 1, PlayerLayer] = 1;
            grid[2, 2, WallLayer]   = 1;
            grid[1, 1, PitLayer]    = 1;
            grid[3, 3, GoalLayer]   = 1;

            return world;
        }
Example #3
0
        /// <summary>
        /// Creates an untrained brain whose network maps <paramref name="numInputs"/> state
        /// features through one ReLU hidden layer to <paramref name="numActions"/> regression outputs.
        /// </summary>
        /// <param name="numInputs">Size of the flattened state vector fed to the network.</param>
        /// <param name="numActions">Number of actions, i.e. output neurons.</param>
        public Brain(int numInputs, int numActions)
        {
            CreatedDate  = DateTime.Now;
            TrainingTime = TimeSpan.Zero;

            NumInputs  = numInputs;
            NumActions = numActions;

            // Hidden layer width: midpoint between input and output sizes.
            var hiddenNeurons = (numInputs + numActions) / 2;

            Net = new Net();
            Net.AddLayer(new InputLayer(1, 1, numInputs));
            Net.AddLayer(new FullyConnLayer(hiddenNeurons));
            Net.AddLayer(new ReluLayer());
            Net.AddLayer(new FullyConnLayer(numActions));
            Net.AddLayer(new RegressionLayer());

            World = GridWorld.StandardState();
        }
Example #4
0
        /// <summary>
        /// Builds a world with the wall at (2,2), pit at (1,1) and goal at (3,3), and the
        /// player placed on a random square that holds none of those objects.
        /// </summary>
        /// <returns>A new <see cref="GridWorld"/> with a randomly placed player.</returns>
        public static GridWorld RandomPlayerState()
        {
            var world = new GridWorld
            {
                WorldState = new int[GridSize, GridSize, GridDepth]
            };

            world.WorldState[2, 2, WallLayer] = 1;
            world.WorldState[1, 1, PitLayer]  = 1;
            world.WorldState[3, 3, GoalLayer] = 1;

            // Draw a candidate, then keep redrawing while it lands on an occupied square.
            Location spot = Util.RandLoc(0, GridWorld.GridSize);
            while (world.WorldState[spot.X, spot.Y, WallLayer] != 0
                || world.WorldState[spot.X, spot.Y, PitLayer] != 0
                || world.WorldState[spot.X, spot.Y, GoalLayer] != 0)
            {
                spot = Util.RandLoc(0, GridWorld.GridSize);
            }

            world.PlayerLocation = spot;
            world.WorldState[spot.X, spot.Y, PlayerLayer] = 1;

            return world;
        }
Example #5
0
        /// <summary>
        /// Trains the Q-network with experience replay: plays <paramref name="numGames"/> games,
        /// each starting from <see cref="GridWorld.RandomPlayerState"/>. Once the replay memory
        /// is full, every move trains one mini-batch sampled at random from that memory.
        /// </summary>
        /// <param name="numGames">Number of games to play.</param>
        /// <param name="batchSize">Mini-batch size; the replay memory holds <c>2 * batchSize</c> transitions.</param>
        /// <param name="initialRandomChance">Starting probability of taking a random (exploratory) action.</param>
        /// <param name="degradeRandomChance">When true, the random chance drops by <c>1 / numGames</c> after each game while it is above 0.05.</param>
        /// <param name="saveToFile">Optional path; when set, the brain is saved periodically and once more at the end.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is zero or negative.</exception>
        public void TrainWithExperienceReplay(int numGames, int batchSize, float initialRandomChance, bool degradeRandomChance = true, string saveToFile = null)
        {
            // A non-positive batch size leaves the replay memory with no slots, so the
            // first overwrite (replay[h]) would throw deep inside the game loop.
            if (batchSize <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be positive.");
            }

            var gamma  = 0.975f;          // discount applied to max_a Q(S', a)
            var buffer = batchSize * 2;   // replay memory capacity
            var h      = 0;               // ring-buffer cursor used once the memory is full

            // Each entry stores one transition as { S, A, R, S' }.
            var replay = new List<object[]>();

            _trainer = new SgdTrainer(Net)
            {
                LearningRate = 0.01, Momentum = 0.0, BatchSize = batchSize, L2Decay = 0.001
            };

            var startTime = DateTime.Now;

            for (var i = 0; i < numGames; i++)
            {
                World = GridWorld.RandomPlayerState();

                var gameRunning = true;
                while (gameRunning)
                {
                    // We are in state S: run the Q function on S to get Q values for all actions.
                    var state  = GetInputs();
                    var qVal   = Net.Forward(state);
                    var action = 0;

                    if (Util.Rnd.NextDouble() < initialRandomChance)
                    {
                        // Explore: choose a random action.
                        action = Util.Rnd.Next(NumActions);
                    }
                    else
                    {
                        // Exploit: choose the best action from Q(S, a).
                        action = MaxValueIndex(qVal);
                    }

                    // Take the action and observe the new state S'.
                    World.MovePlayer(action);
                    TotalTrainingMoves++;
                    var newState = GetInputs();

                    // Observe the reward and whether the game has ended.
                    var reward = World.GetReward();
                    gameRunning = !World.GameOver();

                    var transition = new[] { state, (object)action, (object)reward, newState };

                    if (replay.Count < buffer)
                    {
                        // Still filling the replay memory; no training until it is full.
                        replay.Add(transition);
                    }
                    else
                    {
                        // Memory is full: overwrite the next ring-buffer slot, then train on a random mini-batch.
                        h         = (h < buffer - 1) ? h + 1 : 0;
                        replay[h] = transition;

                        var batchInputValues  = new Volume[batchSize];
                        var batchOutputValues = new List<double>();

                        for (var b = 0; b < batchSize; b++)
                        {
                            var memory      = replay[Util.Rnd.Next(buffer)];
                            var oldState    = (Volume)memory[0];
                            var oldAction   = (int)memory[1];
                            // NOTE(review): unboxing to int requires World.GetReward() to return int.
                            var oldReward   = (int)memory[2];
                            var oldNewState = (Volume)memory[3];

                            // Start the target from the current prediction Q(S, .) so that only the
                            // stored action's output is pulled toward the updated value.
                            var y = GetValues(Net.Forward(oldState));

                            double updatedReward;
                            if (oldReward == GridWorld.ProgressScore)
                            {
                                // Non-terminal transition: bootstrap from max_a Q(S', a).
                                var maxQ = MaxValue(Net.Forward(oldNewState));
                                updatedReward = oldReward + (gamma * maxQ);
                            }
                            else
                            {
                                // Any other reward is treated as terminal: no bootstrapping.
                                updatedReward = oldReward;
                            }

                            // Target output for the action actually taken in this memory
                            // (not the action just taken in the live game).
                            y[oldAction] = updatedReward;

                            batchInputValues[b] = oldState;
                            batchOutputValues.AddRange(y);
                        }
                        Console.Write(".");

                        // Train on the whole mini-batch at once.
                        _trainer.Train(batchOutputValues.ToArray(), batchInputValues);
                        TotalLoss += _trainer.Loss;
                    }
                }
                Console.WriteLine($"{(World.GetReward() == GridWorld.WinScore ? " WON!" : string.Empty)}");
                Console.Write($"Game: {i + 1}");
                TotalTrainingGames++;

                // Periodic checkpoint: saves after games 1, 11, 21, ...
                if (!string.IsNullOrEmpty(saveToFile) && (i % 10 == 0))
                {
                    Util.SaveBrainToFile(this, saveToFile);
                }

                // Optionally: slowly reduce the chance of choosing a random action.
                if (degradeRandomChance && initialRandomChance > 0.05f)
                {
                    initialRandomChance -= (1f / numGames);
                }
            }
            var duration = (DateTime.Now - startTime);

            LastLoss      = _trainer.Loss;
            TrainingTime += duration;

            if (!string.IsNullOrEmpty(saveToFile))
            {
                Util.SaveBrainToFile(this, saveToFile);
            }

            // NOTE(review): TotalLoss is accumulated per mini-batch but divided by the move count.
            Console.WriteLine($"\nAvg loss: {TotalLoss / TotalTrainingMoves}. Last: {LastLoss}");
            Console.WriteLine($"Training duration: {duration}. Total: {TrainingTime}");
        }
Example #6
0
        /// <summary>
        /// Trains the Q-network online (one SGD step per move, batch size 1) by playing
        /// <paramref name="numGames"/> games that all start from <see cref="GridWorld.StandardState"/>.
        /// </summary>
        /// <param name="numGames">Number of games to play.</param>
        /// <param name="initialRandomChance">Starting probability of taking a random (exploratory) action;
        /// it drops by <c>1 / numGames</c> after each game while it is above 0.05.</param>
        public void Train(int numGames, float initialRandomChance)
        {
            var gamma = 0.9f;   // discount applied to max_a Q(S', a)

            _trainer = new SgdTrainer(Net)
            {
                LearningRate = 0.01, Momentum = 0.0, BatchSize = 1, L2Decay = 0.001
            };
            var startTime = DateTime.Now;

            for (var i = 0; i < numGames; i++)
            {
                World = GridWorld.StandardState();

                var gameRunning = true;
                var gameMoves   = 0;
                while (gameRunning)
                {
                    // We are in state S: run the Q function on S to get Q values for all actions.
                    var state = GetInputs();
                    var qVal  = Net.Forward(state);

                    // The training target starts as the current prediction Q(S, .), so only the
                    // taken action's output receives error. Captured before the next forward pass.
                    // NOTE(review): assumes GetValues returns an independent copy of the outputs.
                    var y = GetValues(qVal);

                    var action = 0;
                    if (Util.Rnd.NextDouble() < initialRandomChance)
                    {
                        // Explore: choose a random action.
                        action = Util.Rnd.Next(NumActions);
                    }
                    else
                    {
                        // Exploit: choose the best action from Q(S, a).
                        action = MaxValueIndex(qVal);
                    }

                    // Take the action and observe the new state S'.
                    World.MovePlayer(action);
                    gameMoves++;
                    TotalTrainingMoves++;
                    var newState = GetInputs();

                    // Observe the reward and whether the game has ended.
                    var reward = World.GetReward();
                    gameRunning = !World.GameOver();

                    double updatedReward;
                    if (gameRunning)
                    {
                        // Non-terminal state: bootstrap from max_a Q(S', a).
                        var maxQ = MaxValue(Net.Forward(newState));
                        updatedReward = (reward + (gamma * maxQ));
                    }
                    else
                    {
                        // Terminal state: the reward is the full return.
                        updatedReward = reward;
                        TotalTrainingGames++;
                        Console.WriteLine($"Game: {TotalTrainingGames}. Moves: {gameMoves}. {(reward == GridWorld.WinScore ? "WIN!" : "")}");
                    }

                    // Target output for the action actually taken.
                    y[action] = updatedReward;

                    // Feed back what the score should be for this action.
                    _trainer.Train(state, y);
                    TotalLoss += _trainer.Loss;
                }

                // Slowly reduce the chance of choosing a random action.
                if (initialRandomChance > 0.05f)
                {
                    initialRandomChance -= (1f / numGames);
                }
            }
            var duration = (DateTime.Now - startTime);

            LastLoss      = _trainer.Loss;
            TrainingTime += duration;

            Console.WriteLine($"Avg loss: {TotalLoss / TotalTrainingMoves}. Last: {LastLoss}");
            Console.WriteLine($"Training duration: {duration}. Total: {TrainingTime}");
        }
Example #7
0
        /// <summary>
        /// Console entry point: loads the brain from <c>BrainFile</c> (or creates a new one),
        /// trains it with experience replay and prints a before/after sample of the network
        /// output. It then lets the user watch the trained agent play games until input ends.
        /// </summary>
        /// <param name="args">Unused.</param>
        static void Main(string[] args)
        {
            // Move limit for a demo game; the "Too many moves!" message reports hitting it.
            const int maxMovesPerGame = 10;

            Console.WriteLine(" ----------------------- ");
            Console.WriteLine("| G.R.I.D --- W.O.R.L.D |");
            Console.WriteLine(" ----------------------- ");
            Console.WriteLine("Tutorial: http://outlace.com/Reinforcement-Learning-Part-3/\n");

            Brain brain;

            if (File.Exists(BrainFile))
            {
                brain = Util.ReadBrainFromFile(BrainFile);

                Console.WriteLine("Brain loaded...");
                Console.WriteLine($"Created: {brain.CreatedDate}. Training Time: {brain.TrainingTime} ({brain.TotalTrainingGames} games)");
                Console.WriteLine($"Avg loss: {brain.TotalLoss / brain.TotalTrainingMoves}. Last: {brain.LastLoss}");
            }
            else
            {
                var numInputs  = GridWorld.GridSize * GridWorld.GridSize * GridWorld.GridDepth;
                var numActions = 4;
                brain = new Brain(numInputs, numActions);
            }

            // Network output before this training run, for comparison.
            var initialOutput = brain.DisplayOutput(brain.GetInputs());

            // Alternative: online training without experience replay.
            //Console.WriteLine("Training...");
            //brain.Train(1000, 1f);

            Console.WriteLine("Batch Training...");
            brain.TrainWithExperienceReplay(3000, 32, 1f, true, BrainFile);

            // Sample output on a fresh random-player world.
            brain.World = GridWorld.RandomPlayerState();
            var trainedOutput = brain.DisplayOutput(brain.GetInputs());

            // Show results.
            Console.WriteLine(brain.World.DisplayGrid());
            Console.WriteLine($"Actions: ({_actionNames[0]} {_actionNames[1]} {_actionNames[2]} {_actionNames[3]})");
            Console.WriteLine($"Initial output: {initialOutput}");
            Console.WriteLine($"Sample output: {trainedOutput}");

            Console.WriteLine("\nBrain saved...\nPress enter to play some games...");
            if (Console.ReadLine() == null)
            {
                // Input closed (e.g. redirected stdin): nothing to play interactively.
                return;
            }

            // Play games until input ends.
            while (true)
            {
                Console.Clear();
                brain.World = GridWorld.RandomPlayerState();
                Console.WriteLine("Initial state:");
                Console.WriteLine(brain.World.DisplayGrid());

                var moves = 0;
                // Capped so a policy that never reaches a terminal square cannot loop forever.
                while (!brain.World.GameOver() && moves < maxMovesPerGame)
                {
                    var action = brain.GetNextAction();

                    Console.WriteLine($"\nMove: {++moves}. Taking action: {_actionNames[action]}");
                    brain.World.MovePlayer(action);
                    Console.WriteLine(brain.World.DisplayGrid());
                }

                // Check for a win first so a game won on the last allowed move is reported as a win.
                var won = brain.World.GetReward() == GridWorld.WinScore;
                if (!won && moves >= maxMovesPerGame)
                {
                    Console.WriteLine($"Game Over. Too many moves!");
                }
                else
                {
                    Console.WriteLine($"Game {(won ? "WON!" : "LOST! :(")}");
                }

                Console.WriteLine("\nPress enter to play another game...");
                if (Console.ReadLine() == null)
                {
                    break;
                }
            }
        }