public static GridWorld RandomState()
{
    var w = new GridWorld();
    w.WorldState = new int[GridSize, GridSize, GridDepth];

    Location wall;
    Location pit;
    Location goal;
    Location player;
    var invalid = false;

    // Keep trying to get 4 unique random positions...
    do
    {
        wall = Util.RandLoc(0, GridWorld.GridSize);
        pit = Util.RandLoc(0, GridWorld.GridSize);
        goal = Util.RandLoc(0, GridWorld.GridSize);
        player = Util.RandLoc(0, GridWorld.GridSize);

        invalid = (wall == pit || wall == goal || wall == player ||
                   pit == goal || pit == player || goal == player);
    } while (invalid);

    w.WorldState[wall.X, wall.Y, WallLayer] = 1;
    w.WorldState[pit.X, pit.Y, PitLayer] = 1;
    w.WorldState[goal.X, goal.Y, GoalLayer] = 1;

    w.PlayerLocation = player;
    w.WorldState[player.X, player.Y, PlayerLayer] = 1;

    return w;
}
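RandomState depends on a Location type with value equality (so the == checks above behave as expected) and a Util.RandLoc helper, neither of which appears in this listing. A minimal sketch of what they might look like follows; the details are assumptions, not code from the project.

// Hypothetical sketch: the real Location and Util.RandLoc are not shown in this listing.
public struct Location
{
    public int X;
    public int Y;

    public Location(int x, int y) { X = x; Y = y; }

    // Value equality so positions can be compared with == above.
    public static bool operator ==(Location a, Location b) => a.X == b.X && a.Y == b.Y;
    public static bool operator !=(Location a, Location b) => !(a == b);
    public override bool Equals(object obj) => obj is Location l && this == l;
    public override int GetHashCode() => X * 397 + Y;
}

public static partial class Util
{
    public static readonly Random Rnd = new Random();

    // Random grid coordinate in [min, max) on both axes.
    public static Location RandLoc(int min, int max)
        => new Location(Rnd.Next(min, max), Rnd.Next(min, max));
}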
public static GridWorld StandardState()
{
    var w = new GridWorld();
    w.WorldState = new int[GridSize, GridSize, GridDepth];

    w.PlayerLocation = new Location(0, 1);
    w.WorldState[0, 1, PlayerLayer] = 1;
    w.WorldState[2, 2, WallLayer] = 1;
    w.WorldState[1, 1, PitLayer] = 1;
    w.WorldState[3, 3, GoalLayer] = 1;

    return w;
}
public Brain(int numInputs, int numActions)
{
    CreatedDate = DateTime.Now;
    TrainingTime = new TimeSpan();
    NumInputs = numInputs;
    NumActions = numActions;

    // network
    var layer1N = (numInputs + numActions) / 2;

    Net = new Net();
    Net.AddLayer(new InputLayer(1, 1, numInputs));
    Net.AddLayer(new FullyConnLayer(layer1N));
    Net.AddLayer(new ReluLayer());
    Net.AddLayer(new FullyConnLayer(numActions));
    Net.AddLayer(new RegressionLayer());

    World = GridWorld.StandardState();
}
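The Brain feeds the grid to the network through GetInputs(), which isn't shown here. Presumably it flattens the GridSize x GridSize x GridDepth world state into a single 1 x 1 x numInputs volume to match the InputLayer above. A rough sketch, assuming an older ConvNetSharp-style Volume with a Volume(width, height, depth, initialValue) constructor and a Set(x, y, depth, value) method:

// Sketch only: the Volume constructor and Set accessor used here are assumptions.
public Volume GetInputs()
{
    // Flatten the GridSize x GridSize x GridDepth state into a 1 x 1 x NumInputs volume.
    var input = new Volume(1, 1, NumInputs, 0.0);
    var i = 0;

    for (var x = 0; x < GridWorld.GridSize; x++)
    {
        for (var y = 0; y < GridWorld.GridSize; y++)
        {
            for (var d = 0; d < GridWorld.GridDepth; d++)
            {
                input.Set(0, 0, i++, World.WorldState[x, y, d]);
            }
        }
    }

    return input;
}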
public static GridWorld RandomPlayerState()
{
    var w = new GridWorld();
    w.WorldState = new int[GridSize, GridSize, GridDepth];

    w.WorldState[2, 2, WallLayer] = 1;
    w.WorldState[1, 1, PitLayer] = 1;
    w.WorldState[3, 3, GoalLayer] = 1;

    Location p;
    bool invalid = false;

    // Keep trying until the player lands on an empty square...
    do
    {
        p = Util.RandLoc(0, GridWorld.GridSize);
        invalid = (w.WorldState[p.X, p.Y, WallLayer] != 0 ||
                   w.WorldState[p.X, p.Y, PitLayer] != 0 ||
                   w.WorldState[p.X, p.Y, GoalLayer] != 0);
    } while (invalid);

    w.PlayerLocation = p;
    w.WorldState[p.X, p.Y, PlayerLayer] = 1;

    return w;
}
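The GridWorld members used by the training code (the grid constants, MovePlayer, GetReward, GameOver, and the score constants WinScore / ProgressScore) are not part of this listing. The fragment below is only a guess at how they could fit together; the sizes, layer indices, scores, and action ordering are assumptions (a 4 x 4 grid is implied by the goal at [3, 3], and WinScore = 10 matches the win check in Train further down).

// Rough sketch with assumed values; the real GridWorld members are not in this listing.
public class GridWorld
{
    public const int GridSize = 4;
    public const int GridDepth = 4;

    // One layer per object type in the int[GridSize, GridSize, GridDepth] state.
    public const int PlayerLayer = 0;
    public const int WallLayer = 1;
    public const int PitLayer = 2;
    public const int GoalLayer = 3;

    // Assumed scores; WinScore = 10 matches the win check used during training below.
    public const int WinScore = 10;
    public const int PitScore = -10;
    public const int ProgressScore = -1;

    public Location PlayerLocation;
    public int[,,] WorldState;
    public int Moves;

    public void MovePlayer(int action)
    {
        // Assumed action ordering: 0 = up, 1 = down, 2 = left, 3 = right.
        var x = PlayerLocation.X;
        var y = PlayerLocation.Y;
        if (action == 0) y -= 1;
        if (action == 1) y += 1;
        if (action == 2) x -= 1;
        if (action == 3) x += 1;

        Moves++;

        // Ignore moves that leave the grid or walk into the wall.
        if (x < 0 || x >= GridSize || y < 0 || y >= GridSize) return;
        if (WorldState[x, y, WallLayer] != 0) return;

        WorldState[PlayerLocation.X, PlayerLocation.Y, PlayerLayer] = 0;
        PlayerLocation = new Location(x, y);
        WorldState[x, y, PlayerLayer] = 1;
    }

    public int GetReward()
    {
        if (WorldState[PlayerLocation.X, PlayerLocation.Y, PitLayer] != 0) return PitScore;
        if (WorldState[PlayerLocation.X, PlayerLocation.Y, GoalLayer] != 0) return WinScore;
        return ProgressScore; // every non-terminal move costs a point
    }

    public bool GameOver()
    {
        // The game ends on a terminal square or after too many moves.
        return GetReward() != ProgressScore || Moves >= 10;
    }
}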
public void TrainWithExperienceReplay(int numGames, int batchSize, float initialRandomChance, bool degradeRandomChance = true, string saveToFile = null)
{
    var gamma = 0.975f;
    var buffer = batchSize * 2;
    var h = 0;

    //# Stores tuples of (S, A, R, S')
    var replay = new List<object[]>();

    _trainer = new SgdTrainer(Net)
    {
        LearningRate = 0.01,
        Momentum = 0.0,
        BatchSize = batchSize,
        L2Decay = 0.001
    };

    var startTime = DateTime.Now;
    var batches = 0;

    for (var i = 0; i < numGames; i++)
    {
        World = GridWorld.RandomPlayerState();
        var gameMoves = 0;
        double updatedReward;
        var gameRunning = true;

        while (gameRunning)
        {
            //# We are in state S
            //# Let's run our Q function on S to get Q values for all possible actions
            var state = GetInputs();
            var qVal = Net.Forward(state);

            var action = 0;
            if (Util.Rnd.NextDouble() < initialRandomChance)
            {
                //# Choose random action
                action = Util.Rnd.Next(NumActions);
            }
            else
            {
                //# Choose best action from Q(s,a) values
                action = MaxValueIndex(qVal);
            }

            //# Take action, observe new state S'
            World.MovePlayer(action);
            gameMoves++;
            TotalTrainingMoves++;
            var newState = GetInputs();

            //# Observe reward, limit turns
            var reward = World.GetReward();
            gameRunning = !World.GameOver();

            //# Experience replay storage
            if (replay.Count < buffer)
            {
                replay.Add(new[] { state, (object)action, (object)reward, newState });
            }
            else
            {
                h = (h < buffer - 1) ? h + 1 : 0;
                replay[h] = new[] { state, (object)action, (object)reward, newState };
                batches++;

                var batchInputValues = new Volume[batchSize];
                var batchOutputValues = new List<double>();

                //# Randomly sample our experience replay memory
                for (var b = 0; b < batchSize; b++)
                {
                    var memory = replay[Util.Rnd.Next(buffer)];
                    var oldState = (Volume)memory[0];
                    var oldAction = (int)memory[1];
                    var oldReward = (int)memory[2];
                    var oldNewState = (Volume)memory[3];

                    //# Get max_Q(S',a)
                    var newQ = Net.Forward(oldNewState);
                    var y = GetValues(newQ);
                    var maxQ = MaxValue(newQ);

                    if (oldReward == GridWorld.ProgressScore)
                    {
                        //# Non-terminal state
                        updatedReward = (oldReward + (gamma * maxQ));
                    }
                    else
                    {
                        //# Terminal state
                        updatedReward = oldReward;
                    }

                    //# Target output for the action taken in the sampled memory
                    y[oldAction] = updatedReward;

                    //# Store batched states
                    batchInputValues[b] = oldState;
                    batchOutputValues.AddRange(y);
                }

                Console.Write(".");

                //# Train in batches with multiple scores and actions
                _trainer.Train(batchOutputValues.ToArray(), batchInputValues);
                TotalLoss += _trainer.Loss;
            }
        }

        Console.WriteLine($"{(World.GetReward() == GridWorld.WinScore ? " WON!" : string.Empty)}");
        Console.Write($"Game: {i + 1}");
        TotalTrainingGames++;

        // Save every 10 games...
        if (!string.IsNullOrEmpty(saveToFile) && (i % 10 == 0))
        {
            Util.SaveBrainToFile(this, saveToFile);
        }

        //# Optionally: slowly reduce the chance of choosing a random action
        if (degradeRandomChance && initialRandomChance > 0.05f)
        {
            initialRandomChance -= (1f / numGames);
        }
    }

    var duration = (DateTime.Now - startTime);
    LastLoss = _trainer.Loss;
    TrainingTime += duration;

    if (!string.IsNullOrEmpty(saveToFile))
    {
        Util.SaveBrainToFile(this, saveToFile);
    }

    Console.WriteLine($"\nAvg loss: {TotalLoss / TotalTrainingMoves}. Last: {LastLoss}");
    Console.WriteLine($"Training duration: {duration}. Total: {TrainingTime}");
}
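Both training loops build their regression target from the usual Q-learning update: y[a] = r + gamma * max_a' Q(s', a') for a non-terminal move, and simply y[a] = r at a terminal one. They rely on three small helpers, GetValues, MaxValue, and MaxValueIndex, to read the Q-values out of the network's output volume; these aren't shown, but could look roughly like this, again assuming a ConvNetSharp-style Get(x, y, depth) accessor.

// Sketch only: assumes the output Volume exposes one Q-value per action via Get(x, y, depth).
private double[] GetValues(Volume output)
{
    var values = new double[NumActions];
    for (var a = 0; a < NumActions; a++)
    {
        values[a] = output.Get(0, 0, a);
    }
    return values;
}

private int MaxValueIndex(Volume output)
{
    // Index of the highest-scoring action.
    var values = GetValues(output);
    var best = 0;
    for (var a = 1; a < values.Length; a++)
    {
        if (values[a] > values[best]) best = a;
    }
    return best;
}

private double MaxValue(Volume output) => GetValues(output)[MaxValueIndex(output)];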
public void Train(int numGames, float initialRandomChance)
{
    var gamma = 0.9f;

    _trainer = new SgdTrainer(Net)
    {
        LearningRate = 0.01,
        Momentum = 0.0,
        BatchSize = 1,
        L2Decay = 0.001
    };

    var startTime = DateTime.Now;

    for (var i = 0; i < numGames; i++)
    {
        World = GridWorld.StandardState();
        double updatedReward;
        var gameRunning = true;
        var gameMoves = 0;

        while (gameRunning)
        {
            //# We are in state S
            //# Let's run our Q function on S to get Q values for all possible actions
            var state = GetInputs();
            var qVal = Net.Forward(state);

            var action = 0;
            if (Util.Rnd.NextDouble() < initialRandomChance)
            {
                //# Choose random action
                action = Util.Rnd.Next(NumActions);
            }
            else
            {
                //# Choose best action from Q(s,a) values
                action = MaxValueIndex(qVal);
            }

            //# Take action, observe new state S'
            World.MovePlayer(action);
            gameMoves++;
            TotalTrainingMoves++;
            var newState = GetInputs();

            //# Observe reward
            var reward = World.GetReward();
            gameRunning = !World.GameOver();

            //# Get max_Q(S',a)
            var newQ = Net.Forward(newState);
            var y = GetValues(newQ);
            var maxQ = MaxValue(newQ);

            if (gameRunning)
            {
                //# Non-terminal state
                updatedReward = (reward + (gamma * maxQ));
            }
            else
            {
                //# Terminal state
                updatedReward = reward;
                TotalTrainingGames++;
                Console.WriteLine($"Game: {TotalTrainingGames}. Moves: {gameMoves}. {(reward == GridWorld.WinScore ? "WIN!" : "")}");
            }

            //# Target output
            y[action] = updatedReward;

            //# Feedback what the score would be for this action
            _trainer.Train(state, y);
            TotalLoss += _trainer.Loss;
        }

        //# Slowly reduce the chance of choosing a random action
        if (initialRandomChance > 0.05f)
        {
            initialRandomChance -= (1f / numGames);
        }
    }

    var duration = (DateTime.Now - startTime);
    LastLoss = _trainer.Loss;
    TrainingTime += duration;

    Console.WriteLine($"Avg loss: {TotalLoss / TotalTrainingMoves}. Last: {LastLoss}");
    Console.WriteLine($"Training duration: {duration}. Total: {TrainingTime}");
}
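Playing a finished game only needs the greedy half of the policy. GetNextAction, used in Main below, isn't included in the listing; a minimal version would simply forward the current state and pick the highest-scoring action, using only calls that already appear above.

// Minimal greedy-policy sketch: pick the action with the highest predicted Q-value.
public int GetNextAction()
{
    var state = GetInputs();
    var qVal = Net.Forward(state);
    return MaxValueIndex(qVal);
}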
static void Main(string[] args)
{
    Console.WriteLine(" ----------------------- ");
    Console.WriteLine("| G.R.I.D --- W.O.R.L.D |");
    Console.WriteLine(" ----------------------- ");
    Console.WriteLine("Tutorial: http://outlace.com/Reinforcement-Learning-Part-3/\n");

    Brain brain;
    if (File.Exists(BrainFile))
    {
        brain = Util.ReadBrainFromFile(BrainFile);
        Console.WriteLine("Brain loaded...");
        Console.WriteLine($"Created: {brain.CreatedDate}. Training Time: {brain.TrainingTime} ({brain.TotalTrainingGames} games)");
        Console.WriteLine($"Avg loss: {brain.TotalLoss / brain.TotalTrainingMoves}. Last: {brain.LastLoss}");
    }
    else
    {
        var numInputs = GridWorld.GridSize * GridWorld.GridSize * GridWorld.GridDepth;
        var numActions = 4;
        brain = new Brain(numInputs, numActions);
    }

    // Initial output:
    var initialOutput = brain.DisplayOutput(brain.GetInputs());

    //Console.WriteLine("Training...");
    //brain.Train(1000, 1f);

    Console.WriteLine("Batch Training...");
    brain.TrainWithExperienceReplay(3000, 32, 1f, true, BrainFile);

    // Sample output:
    brain.World = GridWorld.RandomPlayerState();
    var trainedOutput = brain.DisplayOutput(brain.GetInputs());

    // Show results:
    Console.WriteLine(brain.World.DisplayGrid());
    Console.WriteLine($"Actions: ({_actionNames[0]} {_actionNames[1]} {_actionNames[2]} {_actionNames[3]})");
    Console.WriteLine($"Initial output: {initialOutput}");
    Console.WriteLine($"Sample output: {trainedOutput}");
    Console.WriteLine("\nBrain saved...\nPress enter to play some games...");
    Console.ReadLine();

    // Play some games:
    do
    {
        Console.Clear();
        brain.World = GridWorld.RandomPlayerState();
        Console.WriteLine("Initial state:");
        Console.WriteLine(brain.World.DisplayGrid());

        var moves = 0;
        while (!brain.World.GameOver())
        {
            var action = brain.GetNextAction();
            Console.WriteLine($"\nMove: {++moves}. Taking action: {_actionNames[action]}");
            brain.World.MovePlayer(action);
            Console.WriteLine(brain.World.DisplayGrid());
        }

        if (moves >= 10)
        {
            Console.WriteLine("Game Over. Too many moves!");
        }
        else
        {
            Console.WriteLine($"Game {(brain.World.GetReward() == GridWorld.WinScore ? "WON!" : "LOST! :(")}");
        }

        Console.WriteLine("\nPress enter to play another game...");
        Console.ReadLine();
    } while (true);
}
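Finally, Util.SaveBrainToFile and Util.ReadBrainFromFile are referenced but never shown. One straightforward possibility, assuming Brain and the ConvNetSharp Net it holds are marked [Serializable], is plain binary serialization; this is a sketch, not the project's actual persistence code.

// Sketch only: assumes Brain (and the network it contains) is [Serializable].
using System.IO;
using System.Runtime.Serialization.Formatters.Binary;

public static partial class Util
{
    public static void SaveBrainToFile(Brain brain, string path)
    {
        using (var stream = File.Create(path))
        {
            new BinaryFormatter().Serialize(stream, brain);
        }
    }

    public static Brain ReadBrainFromFile(string path)
    {
        using (var stream = File.OpenRead(path))
        {
            return (Brain)new BinaryFormatter().Deserialize(stream);
        }
    }
}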