/// <summary>
/// Builds the convolutional network described by <paramref name="cfg"/> and, when no file
/// exists yet at <c>modelPath</c>, writes the freshly built network there as JSON.
/// </summary>
/// <param name="cfg">
/// Supplies the input dimensions (<c>InputWidth</c>, <c>InputHeight</c>, <c>ImageChannels</c>)
/// and the number of output classes (<c>Actions</c>).
/// </param>
/// <returns>The newly built network.</returns>
/// <exception cref="ArgumentNullException"><paramref name="cfg"/> is null.</exception>
private INet<double> BuildModel(GameModelConfig cfg)
{
    if (cfg == null)
    {
        throw new ArgumentNullException(nameof(cfg));
    }

    _logger.LogInformation("Building the model.");

    var model = FluentNet<double>
        .Create(cfg.InputWidth, cfg.InputHeight, cfg.ImageChannels)
        .Conv(8, 8, 32).Stride(4).Pad(2)
        .Pool(2, 2)
        .Relu()
        .Conv(4, 4, 64).Stride(2).Pad(2)
        .Pool(2, 2)
        .Relu()
        .Conv(3, 3, 64).Stride(1).Pad(2)
        .Pool(2, 2)
        .Relu()
        .FullyConn(512)
        .Relu()
        .FullyConn(cfg.Actions)
        .Softmax(cfg.Actions)
        .Build();

    // Only seeds the model file: an existing file is left untouched.
    if (!File.Exists(modelPath))
    {
        File.WriteAllText(modelPath, model.ToJson());
    }

    _logger.LogInformation("Finished building the model.");
    return model;
}
public void CreateTest()
{
    // Smoke test: a 10x10x2 input -> ReLU -> 10-unit fully-connected -> 10-class softmax
    // network must build without throwing.
    var input = FluentNet.Create(10, 10, 2);
    var network = input.Relu().FullyConn(10).Softmax(10).Build();

    // The forward pass is commented out in this test; only construction is exercised.
    //net.Forward(new Volume(10, 10, 2));
}
/// <summary>
/// Console MNIST demo: fetches the four MNIST files into <c>mnistFolder</c>, loads the training
/// and testing sets, builds a small convolutional network and trains it with Adadelta, one
/// sampled training instance per step, until a key is pressed.
/// </summary>
private void MnistDemo()
{
    Directory.CreateDirectory(mnistFolder);

    string trainLabelsPath = Path.Combine(mnistFolder, trainingLabelFile);
    string trainImagesPath = Path.Combine(mnistFolder, trainingImageFile);
    string testLabelsPath = Path.Combine(mnistFolder, testingLabelFile);
    string testImagesPath = Path.Combine(mnistFolder, testingImageFile);

    // Download Mnist files if needed (any skip-if-present logic lives in DownloadFile).
    Console.WriteLine("Downloading Mnist training files...");
    DownloadFile(urlMnist + trainingLabelFile, trainLabelsPath);
    DownloadFile(urlMnist + trainingImageFile, trainImagesPath);

    Console.WriteLine("Downloading Mnist testing files...");
    DownloadFile(urlMnist + testingLabelFile, testLabelsPath);
    DownloadFile(urlMnist + testingImageFile, testImagesPath);

    Console.WriteLine("Loading the datasets...");
    this.training = MnistReader.Load(trainLabelsPath, trainImagesPath);
    this.testing = MnistReader.Load(testLabelsPath, testImagesPath);

    // Bail out (after waiting for a key) when either set came back empty.
    bool eitherSetEmpty = this.training.Count == 0 || this.testing.Count == 0;
    if (eitherSetEmpty)
    {
        Console.WriteLine("Missing Mnist training/testing files.");
        Console.ReadKey();
        return;
    }

    // Two conv -> ReLU -> pool stages, then a 10-way softmax classifier over 24x24x1 input.
    this.net = FluentNet.Create(24, 24, 1)
        .Conv(5, 5, 8).Stride(1).Pad(2)
        .Relu()
        .Pool(2, 2).Stride(2)
        .Conv(5, 5, 16).Stride(1).Pad(2)
        .Relu()
        .Pool(3, 3).Stride(3)
        .FullyConn(10)
        .Softmax(10)
        .Build();

    this.trainer = new AdadeltaTrainer(this.net)
    {
        BatchSize = 20,
        L2Decay = 0.001,
    };

    Console.WriteLine("Convolutional neural network learning...[Press any key to stop]");
    do
    {
        this.Step(this.SampleTrainingInstance());
    }
    while (!Console.KeyAvailable);
}
/// <summary>
/// Builds a small CNN for 80x60 single-channel input: two conv/ReLU/pool stages followed by a
/// 40-unit tanh layer and a 4-class softmax output.
/// </summary>
/// <returns>The built network.</returns>
private FluentNet<double> BuildCNN()
{
    var network = FluentNet<double>.Create(80, 60, 1)
        // feature extraction
        .Conv(5, 5, 7).Relu().Pool(2, 2)
        .Conv(5, 5, 5).Relu().Pool(2, 2)
        // classifier head
        .FullyConn(40).Tanh()
        .FullyConn(4).Softmax(4)
        .Build();

    return network;
}
public void MergeTest()
{
    // Two independent branches over the same 10x10x2 input shape, differing only in FC width.
    var left = FluentNet.Create(10, 10, 2).Relu().FullyConn(10);
    var right = FluentNet.Create(10, 10, 2).Relu().FullyConn(20);

    // Merge the branches and put a shared 5-class head on top.
    var merged = FluentNet.Merge(left, right)
        .FullyConn(5)
        .Softmax(5)
        .Build();

    // Two input volumes are supplied, matching the two branches (presumably one per input layer).
    var inputs = new[] { new Volume(10, 10, 2), new Volume(10, 10, 2) };
    merged.Forward(inputs);
}
/// <summary>
/// Console MNIST demo: loads the datasets, builds a small convolutional network and trains it
/// with momentum SGD, printing loss, windowed train/test accuracy and timing after every
/// batch, until a key is pressed.
/// </summary>
private void MnistDemo()
{
    var datasets = new DataSets();
    if (!datasets.Load(100))
    {
        return;
    }

    // Two conv -> ReLU -> pool stages, then a 10-way softmax classifier over 24x24x1 input.
    this._net = FluentNet<double>.Create(24, 24, 1)
        .Conv(5, 5, 8).Stride(1).Pad(2)
        .Relu()
        .Pool(2, 2).Stride(2)
        .Conv(5, 5, 16).Stride(1).Pad(2)
        .Relu()
        .Pool(3, 3).Stride(3)
        .FullyConn(10)
        .Softmax(10)
        .Build();

    this._trainer = new SgdTrainer<double>(this._net)
    {
        LearningRate = 0.01,
        BatchSize = 20,
        L2Decay = 0.001,
        Momentum = 0.9
    };

    Console.WriteLine("Convolutional neural network learning...[Press any key to stop]");

    do
    {
        var trainBatch = datasets.Train.NextBatch(this._trainer.BatchSize);
        Train(trainBatch.Item1, trainBatch.Item2, trainBatch.Item3);

        var testBatch = datasets.Test.NextBatch(this._trainer.BatchSize);
        Test(testBatch.Item1, testBatch.Item3, this._testAccWindow);

        // Accuracies are window averages expressed as percentages with two decimals.
        var loss = this._trainer.Loss;
        var trainAccuracy = Math.Round(this._trainAccWindow.Items.Average() * 100.0, 2);
        var testAccuracy = Math.Round(this._testAccWindow.Items.Average() * 100.0, 2);
        Console.WriteLine("Loss: {0} Train accuracy: {1}% Test accuracy: {2}%", loss, trainAccuracy, testAccuracy);

        var forwardMs = Math.Round(this._trainer.ForwardTimeMs, 2);
        var backwardMs = Math.Round(this._trainer.BackwardTimeMs, 2);
        Console.WriteLine("Example seen: {0} Fwd: {1}ms Bckw: {2}ms", this._stepCount, forwardMs, backwardMs);
    }
    while (!Console.KeyAvailable);
}
public void FluentNetSerialization()
{
    // Build a fluent network, round-trip it through JSON, and check the layer count survives.
    var original = FluentNet<double>.Create(24, 24, 1)
        .Conv(5, 5, 8).Stride(1).Pad(2)
        .Relu()
        .Pool(2, 2).Stride(2)
        .Conv(5, 5, 16).Stride(1).Pad(2)
        .Relu()
        .Pool(3, 3).Stride(3)
        .FullyConn(10)
        .Softmax(10)
        .Build();

    var roundTripped = SerializationExtensions.FromJson<double>(original.ToJson());

    // 9 presumably = input + 2 x (conv, relu, pool) + fully-connected + softmax.
    Assert.AreEqual(9, roundTripped.Layers.Count);
}
public MainWindow()
{
    InitializeComponent();
    Init();
    InitPalette();
    StartNes();

    // NOTE(review): this network is built and then dropped; it is not stored in any field.
    // Confirm whether it was meant to be kept.
    var network = FluentNet<double>.Create(WIDTH, HEIGHT, 4)
        .Conv(5, 5, 8).Stride(1).Pad(2)
        .Relu()
        .Pool(2, 2).Stride(2)
        .Conv(5, 5, 16).Stride(1).Pad(2)
        .Relu()
        .Pool(3, 3).Stride(3)
        .FullyConn(10)
        .Softmax(10)
        .Build();
}
public void FluentBinaryNetSerializerTest()
{
    // 5x5x3 input -> 2x2 conv with 16 filters -> 3-unit fully-connected -> 3-class softmax.
    var net = FluentNet.Create(5, 5, 3)
        .Conv(2, 2, 16)
        .FullyConn(3)
        .Softmax(3)
        .Build();

    using (var stream = new MemoryStream())
    {
        // Binary round trip: write, rewind, read back.
        net.SaveBinary(stream);
        stream.Position = 0;

        var deserialized = SerializationExtensions.LoadBinary(stream) as FluentNet;

        Assert.IsNotNull(deserialized);
        Assert.AreEqual(net.InputLayers.Count, deserialized.InputLayers.Count);
        // TODO: compare more than the input-layer count
    }
}
/// <summary>
/// Imitation-learning driver for a single direction-choice network.
/// It loads the network from <c>net.dat</c>, or creates it when absent, and builds a training set
/// from recorded <c>.hlt</c> games of the first player whose name starts with "erdman". It then
/// trains the network with Adam until a key is pressed, saving it whenever validation accuracy
/// improves.
/// </summary>
public static void Run()
{
    const string NetName = "net.dat";
    var random = new Random();
    var entryContainer = new EntryContainer();

    #region Load Net

    // Side length of the square area extracted around each owned cell (11x11).
    var convInputWidth = 11;
    if (convInputWidth % 2 == 0)
    {
        // Presumably must be odd so the area can be centred on the cell (see map.GetVolume).
        throw new InvalidOperationException($"{nameof(convInputWidth)} must be odd");
    }

    // Load the network from disk, or build a fresh one if the file is not found - Direction choice.
    INet singleNet = NetExtension.LoadOrCreateNet(NetName, () =>
        FluentNet.Create(convInputWidth, convInputWidth, 3)
            .Conv(3, 3, 16).Stride(2)
            .Tanh()
            .Conv(2, 2, 16)
            .Tanh()
            .FullyConn(100)
            .Relu()
            .FullyConn(5)
            .Softmax(5)
            .Build());

    #endregion

    #region Load data

    // erdman games downloaded with HltDownloader
    var hltFiles = Directory.EnumerateFiles(@"..\..\..\games\2609\", "*.hlt").ToList();
    int total = hltFiles.Count;
    Console.WriteLine($"Loading {total} games...");

    foreach (var file in hltFiles)
    {
        Console.WriteLine(total--);

        HltReader reader = new HltReader(file);
        var playerToCopy = reader.PlayerNames.FirstOrDefault(o => o.StartsWith("erdman", StringComparison.Ordinal));
        if (playerToCopy == null)
        {
            Console.WriteLine("Player not found");
            continue;
        }

        // Player id is the index in PlayerNames plus one.
        var playerId = reader.PlayerNames.IndexOf(playerToCopy) + 1;

        var width = reader.Width;
        var height = reader.Height;

        // Loop-invariant per file: the same id is attached to every entry of this game.
        // NOTE(review): string.GetHashCode is randomized per process on .NET Core, so this id is
        // only stable within one run - confirm Entry does not rely on it across runs.
        var gameId = file.GetHashCode();

        // Number of player-owned cells in the previous frame. Each owned cell is sampled with
        // probability 1 / lastMoveCount, i.e. roughly one sampled cell per frame.
        int lastMoveCount = 1;

        for (var frame = 0; frame < reader.FrameCount - 1; frame++)
        {
            var currentFrame = reader.GetFrame(frame);
            var map = currentFrame.map;
            var moves = currentFrame.moves;

            int moveCount = 0;

            for (ushort x = 0; x < width; x++)
            {
                for (ushort y = 0; y < height; y++)
                {
                    if (map[x, y].Owner != playerId)
                    {
                        continue;
                    }

                    moveCount++;

                    if (random.NextDouble() < 1.0 / lastMoveCount)
                    {
                        var convVolume = map.GetVolume(convInputWidth, playerId, x, y); // Input
                        var direction = moves[y][x]; // Output

                        entryContainer.Add(new Entry(new[] { convVolume }, direction, x, y, frame, gameId));

                        // Data augmentation: mirrored inputs paired with the correspondingly mirrored direction.
                        entryContainer.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.LeftRight) }, (int)Helper.FlipLeftRight((Direction)direction), x, y, frame, gameId));
                        entryContainer.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.UpDown) }, (int)Helper.FlipUpDown((Direction)direction), x, y, frame, gameId));
                        entryContainer.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.Both) }, (int)Helper.FlipBothWay((Direction)direction), x, y, frame, gameId));
                    }
                }
            }

            lastMoveCount = moveCount;
            if (moveCount == 0)
            {
                // player has died
                break;
            }
        }
    }

    entryContainer.Shuffle();
    Console.WriteLine(entryContainer.Summary);

    #endregion

    #region Training

    var trainer = new AdamTrainer(singleNet) { BatchSize = 1024, LearningRate = 0.1, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8 };
    var trainingScheme = new TrainingScheme(singleNet, trainer, entryContainer, "single");
    bool save = true;
    double lastValidationAcc = 0.0;

    do
    {
        for (int i = 0; i < 1000; i++)
        {
            if (i > 5)
            {
                trainer.L2Decay = 0.001;
            }

            Console.WriteLine($"Epoch #{i + 1}");

            // Divide the learning rate by 10 every 15 epochs, floored at 1e-5.
            // NOTE(review): this also fires at i == 0, so the configured 0.1 is never used;
            // the first epoch runs at 0.01.
            if (i % 15 == 0)
            {
                trainer.LearningRate = Math.Max(trainer.LearningRate / 10.0, 0.00001);
            }

            trainingScheme.RunEpoch();

            // Save if validation accuracy has improved.
            if (save && trainingScheme.ValidationAccuracy > lastValidationAcc)
            {
                lastValidationAcc = trainingScheme.ValidationAccuracy;
                singleNet.SaveNet(NetName);
            }

            if (Console.KeyAvailable)
            {
                break;
            }
        }
    }
    while (!Console.KeyAvailable);

    #endregion
}
/// <summary>
/// Imitation-learning driver that trains three specialised direction-choice networks in parallel:
/// <list type="bullet">
/// <item>"early": samples taken while the player owned fewer than 25 cells in the previous frame.</item>
/// <item>"strong": cells with strength above 200.</item>
/// <item>"normal": everything else.</item>
/// </list>
/// Training data comes from recorded <c>.hlt</c> games of the first player whose name starts with
/// "erdman". Each network is saved to its own file whenever its validation accuracy improves.
/// Runs until a key is pressed.
/// </summary>
public static void Run()
{
    var random = new Random(RandomUtilities.Seed);

    int normalInputWidth = 19;
    int earlyInputWidth = 19;
    int strongInputWidth = 19;

    string netName = "net.dat";
    string netNameEarly = "net_early.dat";
    string netNameStrong = "net_strong.dat";

    var entryContainer = new EntryContainer();
    var entryContainer_early = new EntryContainer();
    var entryContainer_strong = new EntryContainer();

    #region Load Net

    // All three networks share one architecture; only the (square) input width may differ.
    Func<int, FluentNet> createNet = inputWidth => FluentNet.Create(inputWidth, inputWidth, 3)
        .Conv(5, 5, 16).Stride(5).Pad(2)
        .Tanh()
        .Conv(3, 3, 16).Stride(1).Pad(1)
        .Tanh()
        .FullyConn(100)
        .Relu()
        .FullyConn(5)
        .Softmax(5)
        .Build();

    INet singleNet = NetExtension.LoadOrCreateNet(netName, () => createNet(normalInputWidth));
    INet singleNet_early = NetExtension.LoadOrCreateNet(netNameEarly, () => createNet(earlyInputWidth));
    INet singleNet_strong = NetExtension.LoadOrCreateNet(netNameStrong, () => createNet(strongInputWidth));

    #endregion

    #region Load data

    // erdman games downloaded with HltDownloader
    var hltFiles = Directory.EnumerateFiles(@"..\..\..\games\2609\", "*.hlt").ToList();
    int total = hltFiles.Count;
    Console.WriteLine($"Loading {total} games...");

    foreach (var file in hltFiles)
    {
        Console.WriteLine(total--);

        HltReader reader = new HltReader(file);
        var playerToCopy = reader.PlayerNames.FirstOrDefault(o => o.StartsWith("erdman", StringComparison.Ordinal));
        if (playerToCopy == null)
        {
            Console.WriteLine("not found");
            continue;
        }

        // Player id is the index in PlayerNames plus one.
        var playerId = reader.PlayerNames.IndexOf(playerToCopy) + 1;

        var width = reader.Width;
        var height = reader.Height;

        // Loop-invariant per file: the same id is attached to every entry of this game.
        // NOTE(review): string.GetHashCode is randomized per process on .NET Core - confirm
        // Entry does not rely on it across runs.
        var gameId = file.GetHashCode();

        // Number of player-owned cells in the previous frame (drives the sampling probability).
        int lastMoveCount = 1;

        for (var frame = 0; frame < reader.FrameCount - 1; frame++)
        {
            bool earlyGame = lastMoveCount < 25;
            var currentFrame = reader.GetFrame(frame);
            var map = currentFrame.map;
            var moves = currentFrame.moves;

            int moveCount = 0;

            for (ushort x = 0; x < width; x++)
            {
                for (ushort y = 0; y < height; y++)
                {
                    if (map[x, y].Owner != playerId)
                    {
                        continue;
                    }

                    bool strong = map[x, y].Strength > 200;
                    moveCount++;

                    // Early-game and strong cells get extra sampling chances. Each clause draws its
                    // own random number, in this order, and evaluation short-circuits.
                    if ((earlyGame && random.NextDouble() < 1.5 / lastMoveCount) ||
                        (strong && random.NextDouble() < 1.5 / lastMoveCount) ||
                        random.NextDouble() < 1.0 / lastMoveCount)
                    {
                        var w = normalInputWidth;
                        var container = entryContainer;
                        if (earlyGame)
                        {
                            w = earlyInputWidth;
                            container = entryContainer_early;
                        }
                        else if (strong)
                        {
                            w = strongInputWidth;
                            container = entryContainer_strong;
                        }

                        var convVolume = map.GetVolume(w, playerId, x, y); // Input
                        var direction = moves[y][x]; // Output

                        container.Add(new Entry(new[] { convVolume }, direction, x, y, frame, gameId));

                        // Data augmentation: mirrored inputs paired with the correspondingly mirrored direction.
                        container.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.LeftRight) }, (int)Helper.FlipLeftRight((Direction)direction), x, y, frame, gameId));
                        container.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.UpDown) }, (int)Helper.FlipUpDown((Direction)direction), x, y, frame, gameId));
                        container.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.Both) }, (int)Helper.FlipBothWay((Direction)direction), x, y, frame, gameId));
                    }
                }
            }

            lastMoveCount = moveCount;
            if (moveCount == 0)
            {
                // player has died
                break;
            }
        }
    }

    entryContainer.Shuffle();
    Console.WriteLine("normal: " + entryContainer.Summary);
    entryContainer_early.Shuffle();
    Console.WriteLine("early: " + entryContainer_early.Summary);
    entryContainer_strong.Shuffle();
    Console.WriteLine("strong: " + entryContainer_strong.Summary);

    #endregion

    #region Training

    // Same Adam hyper-parameters for all three networks.
    Func<INet, AdamTrainer> createTrainer = model => new AdamTrainer(model)
    {
        BatchSize = 1024,
        LearningRate = 0.01,
        Beta1 = 0.9,
        Beta2 = 0.99,
        Eps = 1e-8
    };

    var trainer = createTrainer(singleNet);
    var trainingScheme = new TrainingScheme(singleNet, trainer, entryContainer, "single");
    var trainer_early = createTrainer(singleNet_early);
    var trainingScheme_early = new TrainingScheme(singleNet_early, trainer_early, entryContainer_early, "single_early");
    var trainer_strong = createTrainer(singleNet_strong);
    var trainingScheme_strong = new TrainingScheme(singleNet_strong, trainer_strong, entryContainer_strong, "single_strong");

    bool save = true;

    // Returns one training round (up to 50 epochs) for a single network. Each call creates its
    // own best-validation-accuracy variable, which the returned delegate keeps across rounds.
    Func<string, AdamTrainer, TrainingScheme, INet, string, Action> createTrainingRound =
        (label, roundTrainer, scheme, model, fileName) =>
        {
            double bestValidationAcc = 0.0;
            return () =>
            {
                for (int i = 0; i < 50; i++)
                {
                    if (i > 5)
                    {
                        roundTrainer.L2Decay = 0.05;
                    }

                    Console.WriteLine($"[{label}] Epoch #{i + 1}");

                    // With 50 epochs per round, i % 50 == 0 only holds for i == 0: the learning rate
                    // is divided by 5 once at the start of every round (floored at 1e-5).
                    if (i % 50 == 0)
                    {
                        roundTrainer.LearningRate = Math.Max(roundTrainer.LearningRate / 5.0, 0.00001);
                    }

                    scheme.RunEpoch();

                    // Save if validation accuracy has improved.
                    if (save && scheme.ValidationAccuracy > bestValidationAcc)
                    {
                        bestValidationAcc = scheme.ValidationAccuracy;
                        model.SaveNet(fileName);
                    }

                    if (Console.KeyAvailable)
                    {
                        break;
                    }
                }
            };
        };

    var normalRound = createTrainingRound("normal", trainer, trainingScheme, singleNet, netName);
    var earlyRound = createTrainingRound("early", trainer_early, trainingScheme_early, singleNet_early, netNameEarly);
    var strongRound = createTrainingRound("strong", trainer_strong, trainingScheme_strong, singleNet_strong, netNameStrong);

    // The three networks are independent, so each round trains them concurrently.
    do
    {
        var normalTask = Task.Run(normalRound);
        var earlyTask = Task.Run(earlyRound);
        var strongTask = Task.Run(strongRound);
        Task.WaitAll(normalTask, strongTask, earlyTask);
    }
    while (!Console.KeyAvailable);

    #endregion
}