public static void Run() { const string NetName = "net.dat"; var random = new Random(); var entryContainer = new EntryContainer(); #region Load Net var convInputWith = 11; // Will extract 11x11 area if (convInputWith % 2 == 0) { throw new ArgumentException("convInputWith must be odd"); } // Load IA or initialize new network if not found - Direction choice INet singleNet = NetExtension.LoadOrCreateNet(NetName, () => { var net = FluentNet.Create(convInputWith, convInputWith, 3) .Conv(3, 3, 16).Stride(2) .Tanh() .Conv(2, 2, 16) .Tanh() .FullyConn(100) .Relu() .FullyConn(5) .Softmax(5).Build(); return(net); }); #endregion #region Load data var hltFiles = Directory.EnumerateFiles(@"..\..\..\games\2609\", "*.hlt").ToList(); // erdman games downloaded with HltDownloader int total = hltFiles.Count; Console.WriteLine($"Loading {total} games..."); foreach (var file in hltFiles) { Console.WriteLine(total--); HltReader reader = new HltReader(file); var playerId = -1; var playerToCopy = reader.PlayerNames.FirstOrDefault(o => o.StartsWith("erdman")); if (playerToCopy == null) { Console.WriteLine("Player not found"); continue; } playerId = reader.PlayerNames.IndexOf(playerToCopy) + 1; var width = reader.Width; var height = reader.Height; int lastmoveCount = 1; for (var frame = 0; frame < reader.FrameCount - 1; frame++) { var currentFrame = reader.GetFrame(frame); var map = currentFrame.map; var moves = currentFrame.moves; var helper = new Helper(map, (ushort)playerId); bool foundInFrame = false; int moveCount = 0; // moves for (ushort x = 0; x < width; x++) { for (ushort y = 0; y < height; y++) { if (map[x, y].Owner == playerId) { foundInFrame = true; moveCount++; if (random.NextDouble() < 1.0 / lastmoveCount) { var convVolume = map.GetVolume(convInputWith, playerId, x, y); // Input var direction = moves[y][x]; // Output var entry1 = new Entry(new[] { convVolume }, direction, x, y, frame, file.GetHashCode()); entryContainer.Add(entry1); // Data augmentation var entry2 = new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.LeftRight) }, (int)Helper.FlipLeftRight((Direction)direction), x, y, frame, file.GetHashCode()); entryContainer.Add(entry2); var entry3 = new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.UpDown) }, (int)Helper.FlipUpDown((Direction)direction), x, y, frame, file.GetHashCode()); entryContainer.Add(entry3); var entry4 = new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.Both) }, (int)Helper.FlipBothWay((Direction)direction), x, y, frame, file.GetHashCode()); entryContainer.Add(entry4); } } } } lastmoveCount = moveCount; if (!foundInFrame) { // player has died break; } } } var length = entryContainer.Shuffle(); Console.WriteLine(entryContainer.Summary); #endregion #region Training var trainer = new AdamTrainer(singleNet) { BatchSize = 1024, LearningRate = 0.1, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8 }; var trainingScheme = new TrainingScheme(singleNet, trainer, entryContainer, "single"); bool save = true; double lastValidationAcc = 0.0; do { for (int i = 0; i < 1000; i++) { if (i > 5) { trainer.L2Decay = 0.001; } Console.WriteLine($"Epoch #{i + 1}"); if (i % 15 == 0) { trainer.LearningRate = Math.Max(trainer.LearningRate / 10.0, 0.00001); } trainingScheme.RunEpoch(); #region Save Nets if (save) { // Save if validation accuracy has improved if (trainingScheme.ValidationAccuracy > lastValidationAcc) { lastValidationAcc = trainingScheme.ValidationAccuracy; singleNet.SaveNet(NetName); } } #endregion if (Console.KeyAvailable) { break; } } } while (!Console.KeyAvailable); #endregion }
public static void Run() { var random = new Random(RandomUtilities.Seed); int normalInputWidth = 19; int earlyInputWidth = 19; int strongInputWidth = 19; string NetName = $"net.dat"; string NetName_early = $"net_early.dat"; string NetName_strong = $"net_strong.dat"; var entryContainer = new EntryContainer(); var entryContainer_early = new EntryContainer(); var entryContainer_strong = new EntryContainer(); #region Load Net INet singleNet = NetExtension.LoadOrCreateNet(NetName, () => { var net = FluentNet.Create(normalInputWidth, normalInputWidth, 3) .Conv(5, 5, 16).Stride(5).Pad(2) .Tanh() .Conv(3, 3, 16).Stride(1).Pad(1) .Tanh() .FullyConn(100) .Relu() .FullyConn(5) .Softmax(5).Build(); return(net); }); INet singleNet_early = NetExtension.LoadOrCreateNet(NetName_early, () => { var net = FluentNet.Create(earlyInputWidth, earlyInputWidth, 3) .Conv(5, 5, 16).Stride(5).Pad(2) .Tanh() .Conv(3, 3, 16).Stride(1).Pad(1) .Tanh() .FullyConn(100) .Relu() .FullyConn(5) .Softmax(5).Build(); return(net); }); INet singleNet_strong = NetExtension.LoadOrCreateNet(NetName_strong, () => { var net = FluentNet.Create(strongInputWidth, strongInputWidth, 3) .Conv(5, 5, 16).Stride(5).Pad(2) .Tanh() .Conv(3, 3, 16).Stride(1).Pad(1) .Tanh() .FullyConn(100) .Relu() .FullyConn(5) .Softmax(5).Build(); return(net); }); #endregion #region Load data var hltFiles = Directory.EnumerateFiles(@"..\..\..\games\2609\", "*.hlt").ToList(); // erdman games downloaded with HltDownloader int total = hltFiles.Count; Console.WriteLine($"Loading {total} games..."); foreach (var file in hltFiles) { Console.WriteLine(total--); HltReader reader = new HltReader(file); var playerId = -1; var playerToCopy = reader.PlayerNames.FirstOrDefault(o => o.StartsWith("erdman")); if (playerToCopy != null) { playerId = reader.PlayerNames.IndexOf(playerToCopy) + 1; } if (playerId != -1) { var width = reader.Width; var height = reader.Height; int lastmoveCount = 1; for (var frame = 0; frame < reader.FrameCount - 1; frame++) { bool earlyGame = lastmoveCount < 25; var currentFrame = reader.GetFrame(frame); var map = currentFrame.map; var moves = currentFrame.moves; var helper = new Helper(map, (ushort)playerId); bool foundInFrame = false; int moveCount = 0; // moves for (ushort x = 0; x < width; x++) { for (ushort y = 0; y < height; y++) { if (map[x, y].Owner == playerId) { bool strong = map[x, y].Strength > 200; foundInFrame = true; moveCount++; if ((earlyGame && random.NextDouble() < 1.5 / lastmoveCount) || (strong && random.NextDouble() < 1.5 / lastmoveCount) || random.NextDouble() < 1.0 / lastmoveCount) { var w = normalInputWidth; var container = entryContainer; if (earlyGame) { w = earlyInputWidth; container = entryContainer_early; } else if (strong) { w = strongInputWidth; container = entryContainer_strong; } var convVolume = map.GetVolume(w, playerId, x, y); var direction = moves[y][x]; var entry1 = new Entry(new[] { convVolume }, direction, x, y, frame, file.GetHashCode()); container.Add(entry1); var entry2 = new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.LeftRight) }, (int)Helper.FlipLeftRight((Direction)direction), x, y, frame, file.GetHashCode()); container.Add(entry2); var entry3 = new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.UpDown) }, (int)Helper.FlipUpDown((Direction)direction), x, y, frame, file.GetHashCode()); container.Add(entry3); var entry4 = new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.Both) }, (int)Helper.FlipBothWay((Direction)direction), x, y, frame, file.GetHashCode()); container.Add(entry4); } } } } lastmoveCount = moveCount; if (!foundInFrame) { // player has died break; } } } else { Console.WriteLine("not found"); } } var length = entryContainer.Shuffle(); Console.WriteLine("normal: " + entryContainer.Summary); length = entryContainer_early.Shuffle(); Console.WriteLine("early: " + entryContainer_early.Summary); length = entryContainer_strong.Shuffle(); Console.WriteLine("strong " + entryContainer_strong.Summary); #endregion #region Training var trainer = new AdamTrainer(singleNet) { BatchSize = 1024, LearningRate = 0.01, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8 }; var trainingScheme = new TrainingScheme(singleNet, trainer, entryContainer, "single"); var trainer_early = new AdamTrainer(singleNet_early) { BatchSize = 1024, LearningRate = 0.01, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8 }; var trainingScheme_early = new TrainingScheme(singleNet_early, trainer_early, entryContainer_early, "single_early"); var trainer_strong = new AdamTrainer(singleNet_strong) { BatchSize = 1024, LearningRate = 0.01, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8 }; var trainingScheme_strong = new TrainingScheme(singleNet_strong, trainer_strong, entryContainer_strong, "single_strong"); bool save = true; double lastValidationAcc = 0.0; double lastValidationAcc_early = 0.0; double lastValidationAcc_strong = 0.0; double lastTrainAcc = 0.0; double lastTrainAcc_early = 0.0; double lastTrainAcc_strong = 0.0; do { var normal = Task.Factory.StartNew(() => { for (int i = 0; i < 50; i++) { if (i > 5) { trainer.L2Decay = 0.05; } Console.WriteLine($"[normal] Epoch #{i + 1}"); if (i % 50 == 0) { trainer.LearningRate = Math.Max(trainer.LearningRate / 5.0, 0.00001); } trainingScheme.RunEpoch(); #region Save Nets if (save) { if (trainingScheme.ValidationAccuracy > lastValidationAcc) { lastValidationAcc = trainingScheme.ValidationAccuracy; lastTrainAcc = trainingScheme.TrainAccuracy; singleNet.SaveNet(NetName); } } #endregion if (Console.KeyAvailable) { break; } } }); var early = Task.Factory.StartNew(() => { for (int i = 0; i < 50; i++) { if (i > 5) { trainer_early.L2Decay = 0.05; } Console.WriteLine($"[early] Epoch #{i + 1}"); if (i % 50 == 0) { trainer_early.LearningRate = Math.Max(trainer_early.LearningRate / 5.0, 0.00001); } trainingScheme_early.RunEpoch(); #region Save Nets if (save) { if (trainingScheme_early.ValidationAccuracy > lastValidationAcc_early) { lastValidationAcc_early = trainingScheme_early.ValidationAccuracy; lastTrainAcc_early = trainingScheme_early.TrainAccuracy; singleNet_early.SaveNet(NetName_early); } } #endregion if (Console.KeyAvailable) { break; } } }); var strong = Task.Factory.StartNew(() => { for (int i = 0; i < 50; i++) { if (i > 5) { trainer_strong.L2Decay = 0.05; } Console.WriteLine($"[strong] Epoch #{i + 1}"); if (i % 50 == 0) { trainer_strong.LearningRate = Math.Max(trainer_strong.LearningRate / 5.0, 0.00001); } trainingScheme_strong.RunEpoch(); #region Save Nets if (save) { if (trainingScheme_strong.ValidationAccuracy > lastValidationAcc_strong) { lastValidationAcc_strong = trainingScheme_strong.ValidationAccuracy; lastTrainAcc_strong = trainingScheme_strong.TrainAccuracy; singleNet_strong.SaveNet(NetName_strong); } } #endregion if (Console.KeyAvailable) { break; } } }); Task.WaitAll(new[] { normal, strong, early }); }while (!Console.KeyAvailable); #endregion }