/// <summary>
/// Creates a new agent with a fresh Q-value network and a standard GridWorld.
/// </summary>
/// <param name="numInputs">Size of the flattened state vector fed to the network.</param>
/// <param name="numActions">Number of discrete actions (one network output per action).</param>
public Brain(int numInputs, int numActions)
{
    CreatedDate = DateTime.Now;
    TrainingTime = new TimeSpan();
    NumInputs = numInputs;
    NumActions = numActions;

    // Q-network topology: input -> fully-connected hidden layer (ReLU) -> linear
    // regression output with one value per action. Hidden size is the mean of
    // input and output widths (integer division).
    var hiddenNeurons = (numInputs + numActions) / 2;
    Net = new Net();
    Net.AddLayer(new InputLayer(1, 1, numInputs));
    Net.AddLayer(new FullyConnLayer(hiddenNeurons));
    Net.AddLayer(new ReluLayer());
    Net.AddLayer(new FullyConnLayer(numActions));
    Net.AddLayer(new RegressionLayer());

    World = GridWorld.StandardState();
}
/// <summary>
/// Trains the Q-network by playing <paramref name="numGames"/> complete games of
/// GridWorld with an epsilon-greedy policy. The random-action chance decays
/// linearly from <paramref name="initialRandomChance"/> toward a floor of 0.05.
/// Accumulates loss/move/game statistics and logs a per-game summary.
/// </summary>
/// <param name="numGames">Number of complete games to play.</param>
/// <param name="initialRandomChance">Starting probability (epsilon) of taking a random action.</param>
public void Train(int numGames, float initialRandomChance)
{
    var gamma = 0.9f; // discount factor for bootstrapped future reward

    _trainer = new SgdTrainer(Net)
    {
        LearningRate = 0.01,
        Momentum = 0.0,
        BatchSize = 1,
        L2Decay = 0.001
    };

    // Stopwatch is monotonic; DateTime.Now subtraction can jump with DST/clock sync.
    var stopwatch = System.Diagnostics.Stopwatch.StartNew();
    for (var i = 0; i < numGames; i++)
    {
        World = GridWorld.StandardState();
        var gameRunning = true;
        var gameMoves = 0;
        while (gameRunning)
        {
            // We are in state S: run the Q function on S to get Q values for all actions.
            var state = GetInputs();
            var qVal = Net.Forward(state);

            // Epsilon-greedy action selection: explore with probability epsilon,
            // otherwise exploit the best known Q(S, a).
            var action = Util.Rnd.NextDouble() < initialRandomChance
                ? Util.Rnd.Next(NumActions)
                : MaxValueIndex(qVal);

            // Take the action, observe new state S'.
            World.MovePlayer(action);
            gameMoves++;
            TotalTrainingMoves++;
            var newState = GetInputs();

            // Observe reward and whether the game has ended.
            var reward = World.GetReward();
            gameRunning = !World.GameOver();

            // max_a' Q(S', a') for the bootstrap target.
            var newQ = Net.Forward(newState);
            var maxQ = MaxValue(newQ);

            // BUG FIX: the regression target must be the network's CURRENT
            // prediction for S (qVal) with only the chosen action's entry
            // replaced. The original built y from GetValues(newQ), which
            // trained every non-chosen action toward Q(S') values and
            // corrupted the targets on each step.
            var y = GetValues(qVal);

            double updatedReward;
            if (gameRunning)
            {
                // Non-terminal state: standard Q-learning bootstrap target.
                updatedReward = reward + (gamma * maxQ);
            }
            else
            {
                // Terminal state: target is the raw reward.
                // NOTE(review): reward == 10 assumes the win reward is exactly 10
                // and exactly representable — confirm against GridWorld.GetReward().
                updatedReward = reward;
                TotalTrainingGames++;
                Console.WriteLine($"Game: {TotalTrainingGames}. Moves: {gameMoves}. {(reward == 10 ? "WIN!" : "")}");
            }

            // Feed back what the score would be for the action actually taken.
            y[action] = updatedReward;
            _trainer.Train(state, y);
            TotalLoss += _trainer.Loss;
        }

        // Slowly reduce the chance of choosing a random action (epsilon decay).
        if (initialRandomChance > 0.05f)
        {
            initialRandomChance -= 1f / numGames;
        }
    }

    stopwatch.Stop();
    var duration = stopwatch.Elapsed;
    LastLoss = _trainer.Loss;
    TrainingTime += duration;
    Console.WriteLine($"Avg loss: {TotalLoss / TotalTrainingMoves}. Last: {LastLoss}");
    Console.WriteLine($"Training duration: {duration}. Total: {TrainingTime}");
}