예제 #1
0
        /// <summary>
        /// Builds the convolutional network for the game model and, when no file exists at
        /// <c>modelPath</c> yet, writes the freshly built network there as JSON.
        /// </summary>
        /// <param name="cfg">
        /// Supplies the input width/height, the number of image channels and the number of
        /// actions (the size of the final fully connected layer and softmax).
        /// </param>
        /// <returns>The newly built network.</returns>
        private INet<double> BuildModel(GameModelConfig cfg)
        {
            _logger.LogInformation("Building the model.");

            // Three conv -> pool -> ReLU stages, a 512-unit ReLU hidden layer,
            // then a softmax over the available actions.
            var model = FluentNet<double>.Create(cfg.InputWidth, cfg.InputHeight, cfg.ImageChannels)
                                         .Conv(8, 8, 32).Stride(4).Pad(2)
                                         .Pool(2, 2)
                                         .Relu()
                                         .Conv(4, 4, 64).Stride(2).Pad(2)
                                         .Pool(2, 2)
                                         .Relu()
                                         .Conv(3, 3, 64).Stride(1).Pad(2)
                                         .Pool(2, 2)
                                         .Relu()
                                         .FullyConn(512)
                                         .Relu()
                                         .FullyConn(cfg.Actions)
                                         .Softmax(cfg.Actions)
                                         .Build();

            // Only create the model file when it is missing; an existing file is never overwritten here.
            if (!File.Exists(modelPath))
            {
                File.WriteAllText(modelPath, model.ToJson());
            }

            _logger.LogInformation("Finished building the model.");
            return model;
        }
예제 #2
0
        public void CreateTest()
        {
            // Smoke test: assembling a minimal ReLU -> fully connected -> softmax network
            // over a 10x10x2 input through the fluent API must not throw.
            var builder = FluentNet.Create(10, 10, 2);
            var network = builder.Relu()
                                 .FullyConn(10)
                                 .Softmax(10)
                                 .Build();
        }
예제 #3
0
        /// <summary>
        /// MNIST demo: fetches and loads the training and testing sets into <c>training</c> and
        /// <c>testing</c>, builds a two-stage convolutional classifier into <c>net</c>, and then
        /// repeatedly feeds one training sample at a time to <c>Step</c> until a key is pressed.
        /// </summary>
        private void MnistDemo()
        {
            Directory.CreateDirectory(mnistFolder);

            var trainLabelPath = Path.Combine(mnistFolder, trainingLabelFile);
            var trainImagePath = Path.Combine(mnistFolder, trainingImageFile);
            var testLabelPath  = Path.Combine(mnistFolder, testingLabelFile);
            var testImagePath  = Path.Combine(mnistFolder, testingImageFile);

            // NOTE(review): presumably DownloadFile skips files that are already on disk — confirm.
            Console.WriteLine("Downloading Mnist training files...");
            DownloadFile(urlMnist + trainingLabelFile, trainLabelPath);
            DownloadFile(urlMnist + trainingImageFile, trainImagePath);
            Console.WriteLine("Downloading Mnist testing files...");
            DownloadFile(urlMnist + testingLabelFile, testLabelPath);
            DownloadFile(urlMnist + testingImageFile, testImagePath);

            Console.WriteLine("Loading the datasets...");
            this.training = MnistReader.Load(trainLabelPath, trainImagePath);
            this.testing  = MnistReader.Load(testLabelPath, testImagePath);

            // Both sets must be non-empty; otherwise report it, wait for a key and give up.
            bool datasetsLoaded = this.training.Count != 0 && this.testing.Count != 0;
            if (!datasetsLoaded)
            {
                Console.WriteLine("Missing Mnist training/testing files.");
                Console.ReadKey();
                return;
            }

            // 24x24x1 input -> conv 5x5x8 -> ReLU -> pool -> conv 5x5x16 -> ReLU -> pool -> 10-way softmax.
            var topology = FluentNet.Create(24, 24, 1)
                                    .Conv(5, 5, 8).Stride(1).Pad(2)
                                    .Relu()
                                    .Pool(2, 2).Stride(2)
                                    .Conv(5, 5, 16).Stride(1).Pad(2)
                                    .Relu()
                                    .Pool(3, 3).Stride(3)
                                    .FullyConn(10)
                                    .Softmax(10);
            this.net = topology.Build();

            this.trainer = new AdadeltaTrainer(this.net) { BatchSize = 20, L2Decay = 0.001 };

            Console.WriteLine("Convolutional neural network learning...[Press any key to stop]");
            do
            {
                this.Step(this.SampleTrainingInstance());
            }
            while (!Console.KeyAvailable);
        }
 /// <summary>
 /// Builds a small convolutional network for an 80x60 single-channel input with a
 /// 4-way softmax output.
 /// </summary>
 private FluentNet<double> BuildCNN()
 {
     // Two conv(5x5) -> ReLU -> pool(2x2) stages, a 40-unit tanh layer, then a 4-way softmax.
     var network = FluentNet<double>.Create(80, 60, 1)
         .Conv(5, 5, 7)
         .Relu()
         .Pool(2, 2)
         .Conv(5, 5, 5)
         .Relu()
         .Pool(2, 2)
         .FullyConn(40)
         .Tanh()
         .FullyConn(4)
         .Softmax(4)
         .Build();

     return network;
 }
예제 #5
0
        public void MergeTest()
        {
            // Two independent branches over the same 10x10x2 input shape, ending in
            // 10- and 20-unit fully connected layers respectively.
            var left  = FluentNet.Create(10, 10, 2).Relu().FullyConn(10);
            var right = FluentNet.Create(10, 10, 2).Relu().FullyConn(20);

            // Merge the branches and finish with a 5-way softmax.
            var merged = FluentNet.Merge(left, right)
                                  .FullyConn(5)
                                  .Softmax(5)
                                  .Build();

            // Two input volumes, presumably one per merged branch; the forward pass must not throw.
            var inputs = new[] { new Volume(10, 10, 2), new Volume(10, 10, 2) };
            merged.Forward(inputs);
        }
예제 #6
0
        /// <summary>
        /// MNIST demo: loads the datasets, builds a two-stage convolutional classifier in
        /// <c>_net</c>, and trains it with SGD (momentum 0.9) batch by batch. After every
        /// batch it prints loss, windowed train/test accuracy and timings, until a key is pressed.
        /// </summary>
        private void MnistDemo()
        {
            var datasets = new DataSets();
            if (!datasets.Load(100))
            {
                return;
            }

            // 24x24x1 input -> conv 5x5x8 -> ReLU -> pool -> conv 5x5x16 -> ReLU -> pool -> 10-way softmax.
            this._net = FluentNet<double>.Create(24, 24, 1)
                                         .Conv(5, 5, 8).Stride(1).Pad(2)
                                         .Relu()
                                         .Pool(2, 2).Stride(2)
                                         .Conv(5, 5, 16).Stride(1).Pad(2)
                                         .Relu()
                                         .Pool(3, 3).Stride(3)
                                         .FullyConn(10)
                                         .Softmax(10)
                                         .Build();

            this._trainer = new SgdTrainer<double>(this._net)
            {
                LearningRate = 0.01,
                BatchSize    = 20,
                L2Decay      = 0.001,
                Momentum     = 0.9
            };

            Console.WriteLine("Convolutional neural network learning...[Press any key to stop]");
            do
            {
                var trainBatch = datasets.Train.NextBatch(this._trainer.BatchSize);
                Train(trainBatch.Item1, trainBatch.Item2, trainBatch.Item3);

                var testBatch = datasets.Test.NextBatch(this._trainer.BatchSize);
                Test(testBatch.Item1, testBatch.Item3, this._testAccWindow);

                // NOTE(review): Average() throws on an empty sequence — assumes both windows are non-empty here.
                var loss          = this._trainer.Loss;
                var trainAccuracy = Math.Round(this._trainAccWindow.Items.Average() * 100.0, 2);
                var testAccuracy  = Math.Round(this._testAccWindow.Items.Average() * 100.0, 2);
                Console.WriteLine("Loss: {0} Train accuracy: {1}% Test accuracy: {2}%", loss, trainAccuracy, testAccuracy);

                var examplesSeen = this._stepCount;
                var forwardMs    = Math.Round(this._trainer.ForwardTimeMs, 2);
                var backwardMs   = Math.Round(this._trainer.BackwardTimeMs, 2);
                Console.WriteLine("Example seen: {0} Fwd: {1}ms Bckw: {2}ms", examplesSeen, forwardMs, backwardMs);
            }
            while (!Console.KeyAvailable);
        }
        public void FluentNetSerialization()
        {
            // Build a network with the fluent API.
            var original = FluentNet<double>.Create(24, 24, 1)
                                            .Conv(5, 5, 8).Stride(1).Pad(2)
                                            .Relu()
                                            .Pool(2, 2).Stride(2)
                                            .Conv(5, 5, 16).Stride(1).Pad(2)
                                            .Relu()
                                            .Pool(3, 3).Stride(3)
                                            .FullyConn(10)
                                            .Softmax(10)
                                            .Build();

            // Round-trip it through JSON.
            var serialized = original.ToJson();
            var restored   = SerializationExtensions.FromJson<double>(serialized);

            // Presumably the input layer plus the eight layers chained above.
            const int expectedLayerCount = 9;
            Assert.AreEqual(expectedLayerCount, restored.Layers.Count);
        }
예제 #8
0
        public MainWindow()
        {
            InitializeComponent();

            Init();
            InitPalette();
            StartNes();

            // NOTE(review): this network is built and then dropped when the constructor returns;
            // nothing stores or uses it. Presumably it was meant to be kept in a field — confirm intent.
            var network = FluentNet<double>.Create(WIDTH, HEIGHT, 4)
                                           .Conv(5, 5, 8).Stride(1).Pad(2)
                                           .Relu()
                                           .Pool(2, 2).Stride(2)
                                           .Conv(5, 5, 16).Stride(1).Pad(2)
                                           .Relu()
                                           .Pool(3, 3).Stride(3)
                                           .FullyConn(10)
                                           .Softmax(10)
                                           .Build();
        }
예제 #9
0
        public void FluentBinaryNetSerializerTest()
        {
            var original = FluentNet.Create(5, 5, 3)
                                    .Conv(2, 2, 16)
                                    .FullyConn(3)
                                    .Softmax(3)
                                    .Build();

            using (var buffer = new MemoryStream())
            {
                // Round-trip through the binary serializer using an in-memory stream.
                original.SaveBinary(buffer);
                buffer.Position = 0; // rewind so LoadBinary reads from the start

                var restored = SerializationExtensions.LoadBinary(buffer) as FluentNet;

                Assert.IsNotNull(restored);
                Assert.AreEqual(original.InputLayers.Count, restored.InputLayers.Count);

                // TODO: improve test — compare more than the input-layer count.
            }
        }
예제 #10
0
        /// <summary>
        /// Trains a single direction-choice network by imitating a player whose name starts with
        /// "erdman" in previously downloaded .hlt game replays.
        /// </summary>
        /// <remarks>
        /// The network is loaded from <c>net.dat</c> (or created if not found), trained with Adam in
        /// rounds of up to 1000 epochs, and saved back to <c>net.dat</c> whenever validation accuracy
        /// improves. Press any key to stop.
        /// </remarks>
        public static void Run()
        {
            const string NetName = "net.dat";

            var random         = new Random();
            var entryContainer = new EntryContainer();

            #region Load Net

            var convInputWith = 11; // Will extract 11x11 area

            // NOTE(review): presumably the extracted area is centred on the cell, which requires an odd width.
            if (convInputWith % 2 == 0)
            {
                throw new ArgumentException("convInputWith must be odd");
            }

            // Load the net or initialize a new one if not found - Direction choice
            INet singleNet = NetExtension.LoadOrCreateNet(NetName, () =>
            {
                var net = FluentNet.Create(convInputWith, convInputWith, 3)
                          .Conv(3, 3, 16).Stride(2)
                          .Tanh()
                          .Conv(2, 2, 16)
                          .Tanh()
                          .FullyConn(100)
                          .Relu()
                          .FullyConn(5)
                          .Softmax(5).Build();

                return net;
            });

            #endregion

            #region Load data

            // erdman games downloaded with HltDownloader. Path.Combine instead of a hard-coded
            // "..\..\..\games\2609\" so the path also resolves where '\' is not a directory separator.
            var gamesFolder = Path.Combine("..", "..", "..", "games", "2609");
            var hltFiles    = Directory.EnumerateFiles(gamesFolder, "*.hlt").ToList();
            int total       = hltFiles.Count;
            Console.WriteLine($"Loading {total} games...");

            foreach (var file in hltFiles)
            {
                Console.WriteLine(total--);
                HltReader reader = new HltReader(file);

                // Ordinal: the prefix is an identifier, so matching must not depend on the current culture.
                var playerToCopy = reader.PlayerNames.FirstOrDefault(o => o.StartsWith("erdman", StringComparison.Ordinal));

                if (playerToCopy == null)
                {
                    Console.WriteLine("Player not found");
                    continue;
                }

                // Map owners are compared against (name index + 1).
                // NOTE(review): presumably owner 0 is neutral — confirm against HltReader.
                var playerId = reader.PlayerNames.IndexOf(playerToCopy) + 1;

                // Identical for every entry taken from this file, so compute it once.
                var gameId = file.GetHashCode();

                var width  = reader.Width;
                var height = reader.Height;

                // Number of cells the player owned in the previous frame. Each owned cell is sampled
                // with probability 1 / lastmoveCount, i.e. roughly one sample per frame while the
                // territory size is stable.
                int lastmoveCount = 1;

                for (var frame = 0; frame < reader.FrameCount - 1; frame++)
                {
                    var currentFrame = reader.GetFrame(frame);
                    var map          = currentFrame.map;
                    var moves        = currentFrame.moves;

                    bool foundInFrame = false;
                    int  moveCount    = 0;

                    // moves
                    for (ushort x = 0; x < width; x++)
                    {
                        for (ushort y = 0; y < height; y++)
                        {
                            if (map[x, y].Owner == playerId)
                            {
                                foundInFrame = true;
                                moveCount++;

                                if (random.NextDouble() < 1.0 / lastmoveCount)
                                {
                                    var convVolume = map.GetVolume(convInputWith, playerId, x, y); // Input
                                    var direction  = moves[y][x];                                  // Output (moves is indexed [y][x])

                                    // The sample plus three flipped copies (data augmentation); the target
                                    // direction is flipped the same way as the input volume.
                                    entryContainer.Add(new Entry(new[] { convVolume }, direction, x, y, frame, gameId));
                                    entryContainer.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.LeftRight) }, (int)Helper.FlipLeftRight((Direction)direction), x, y, frame, gameId));
                                    entryContainer.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.UpDown) }, (int)Helper.FlipUpDown((Direction)direction), x, y, frame, gameId));
                                    entryContainer.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.Both) }, (int)Helper.FlipBothWay((Direction)direction), x, y, frame, gameId));
                                }
                            }
                        }
                    }

                    lastmoveCount = moveCount;

                    if (!foundInFrame)
                    {
                        // player has died
                        break;
                    }
                }
            }

            entryContainer.Shuffle();
            Console.WriteLine(entryContainer.Summary);

            #endregion

            #region Training

            var trainer = new AdamTrainer(singleNet)
            {
                BatchSize = 1024, LearningRate = 0.1, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8
            };
            var    trainingScheme    = new TrainingScheme(singleNet, trainer, entryContainer, "single");
            bool   save              = true;
            double lastValidationAcc = 0.0;

            do
            {
                for (int i = 0; i < 1000; i++)
                {
                    // Weight decay is switched on only after the first few epochs.
                    if (i > 5)
                    {
                        trainer.L2Decay = 0.001;
                    }

                    Console.WriteLine($"Epoch #{i + 1}");

                    // Step decay every 15 epochs, floored at 1e-5.
                    // NOTE(review): this also fires at i == 0, so the initial 0.1 becomes 0.01 before the
                    // first epoch runs (and again at the start of every outer round) — confirm intended.
                    if (i % 15 == 0)
                    {
                        trainer.LearningRate = Math.Max(trainer.LearningRate / 10.0, 0.00001);
                    }

                    trainingScheme.RunEpoch();

                    #region Save Nets

                    if (save)
                    {
                        // Save if validation accuracy has improved
                        if (trainingScheme.ValidationAccuracy > lastValidationAcc)
                        {
                            lastValidationAcc = trainingScheme.ValidationAccuracy;
                            singleNet.SaveNet(NetName);
                        }
                    }
                    #endregion

                    if (Console.KeyAvailable)
                    {
                        break;
                    }
                }
            } while (!Console.KeyAvailable);

            #endregion
        }
예제 #11
0
        /// <summary>
        /// Trains three direction-choice networks ("normal", "early" and "strong") by imitating a
        /// player whose name starts with "erdman" in previously downloaded .hlt game replays.
        /// </summary>
        /// <remarks>
        /// Samples go to the early-game net when the player owned fewer than 25 cells in the previous
        /// frame, otherwise to the strong net when the cell's strength is above 200, and otherwise to
        /// the normal net. The three nets are trained in parallel rounds of up to 50 epochs each, and
        /// each is saved whenever its validation accuracy improves. Press any key to stop.
        /// </remarks>
        public static void Run()
        {
            var random = new Random(RandomUtilities.Seed);

            int normalInputWidth = 19;
            int earlyInputWidth  = 19;
            int strongInputWidth = 19;

            string NetName        = "net.dat";
            string NetName_early  = "net_early.dat";
            string NetName_strong = "net_strong.dat";

            var entryContainer        = new EntryContainer();
            var entryContainer_early  = new EntryContainer();
            var entryContainer_strong = new EntryContainer();

            #region Load Net

            // All three nets share one architecture; only the input width and the file differ.
            INet singleNet = NetExtension.LoadOrCreateNet(NetName, () =>
            {
                var net = FluentNet.Create(normalInputWidth, normalInputWidth, 3)
                          .Conv(5, 5, 16).Stride(5).Pad(2)
                          .Tanh()
                          .Conv(3, 3, 16).Stride(1).Pad(1)
                          .Tanh()
                          .FullyConn(100)
                          .Relu()
                          .FullyConn(5)
                          .Softmax(5).Build();

                return net;
            });

            INet singleNet_early = NetExtension.LoadOrCreateNet(NetName_early, () =>
            {
                var net = FluentNet.Create(earlyInputWidth, earlyInputWidth, 3)
                          .Conv(5, 5, 16).Stride(5).Pad(2)
                          .Tanh()
                          .Conv(3, 3, 16).Stride(1).Pad(1)
                          .Tanh()
                          .FullyConn(100)
                          .Relu()
                          .FullyConn(5)
                          .Softmax(5).Build();

                return net;
            });

            INet singleNet_strong = NetExtension.LoadOrCreateNet(NetName_strong, () =>
            {
                var net = FluentNet.Create(strongInputWidth, strongInputWidth, 3)
                          .Conv(5, 5, 16).Stride(5).Pad(2)
                          .Tanh()
                          .Conv(3, 3, 16).Stride(1).Pad(1)
                          .Tanh()
                          .FullyConn(100)
                          .Relu()
                          .FullyConn(5)
                          .Softmax(5).Build();

                return net;
            });

            #endregion

            #region Load data

            // erdman games downloaded with HltDownloader. Path.Combine instead of a hard-coded
            // "..\..\..\games\2609\" so the path also resolves where '\' is not a directory separator.
            var gamesFolder = Path.Combine("..", "..", "..", "games", "2609");
            var hltFiles    = Directory.EnumerateFiles(gamesFolder, "*.hlt").ToList();
            int total       = hltFiles.Count;
            Console.WriteLine($"Loading {total} games...");

            foreach (var file in hltFiles)
            {
                Console.WriteLine(total--);
                HltReader reader = new HltReader(file);

                // Ordinal: the prefix is an identifier, so matching must not depend on the current culture.
                var playerToCopy = reader.PlayerNames.FirstOrDefault(o => o.StartsWith("erdman", StringComparison.Ordinal));

                if (playerToCopy == null)
                {
                    Console.WriteLine("not found");
                    continue;
                }

                // Map owners are compared against (name index + 1).
                // NOTE(review): presumably owner 0 is neutral — confirm against HltReader.
                var playerId = reader.PlayerNames.IndexOf(playerToCopy) + 1;

                // Identical for every entry taken from this file, so compute it once.
                var gameId = file.GetHashCode();

                var width  = reader.Width;
                var height = reader.Height;

                // Number of cells the player owned in the previous frame; drives both the sampling
                // probability and the early-game classification.
                int lastmoveCount = 1;

                for (var frame = 0; frame < reader.FrameCount - 1; frame++)
                {
                    bool earlyGame = lastmoveCount < 25;

                    var currentFrame = reader.GetFrame(frame);
                    var map          = currentFrame.map;
                    var moves        = currentFrame.moves;

                    bool foundInFrame = false;
                    int  moveCount    = 0;

                    // moves
                    for (ushort x = 0; x < width; x++)
                    {
                        for (ushort y = 0; y < height; y++)
                        {
                            if (map[x, y].Owner == playerId)
                            {
                                bool strong = map[x, y].Strength > 200;
                                foundInFrame = true;
                                moveCount++;

                                // Early-game and strong cells get an extra 1.5 / lastmoveCount chance of
                                // being sampled. The short-circuit order of the random draws is kept
                                // unchanged so seeded runs stay reproducible.
                                if ((earlyGame && random.NextDouble() < 1.5 / lastmoveCount)
                                    || (strong && random.NextDouble() < 1.5 / lastmoveCount)
                                    || random.NextDouble() < 1.0 / lastmoveCount)
                                {
                                    // Early game takes precedence over strong; everything else is "normal".
                                    var w         = normalInputWidth;
                                    var container = entryContainer;

                                    if (earlyGame)
                                    {
                                        w         = earlyInputWidth;
                                        container = entryContainer_early;
                                    }
                                    else if (strong)
                                    {
                                        w         = strongInputWidth;
                                        container = entryContainer_strong;
                                    }

                                    var convVolume = map.GetVolume(w, playerId, x, y);
                                    var direction  = moves[y][x]; // moves is indexed [y][x]

                                    // The sample plus three flipped copies (data augmentation); the target
                                    // direction is flipped the same way as the input volume.
                                    container.Add(new Entry(new[] { convVolume }, direction, x, y, frame, gameId));
                                    container.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.LeftRight) }, (int)Helper.FlipLeftRight((Direction)direction), x, y, frame, gameId));
                                    container.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.UpDown) }, (int)Helper.FlipUpDown((Direction)direction), x, y, frame, gameId));
                                    container.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.Both) }, (int)Helper.FlipBothWay((Direction)direction), x, y, frame, gameId));
                                }
                            }
                        }
                    }

                    lastmoveCount = moveCount;

                    if (!foundInFrame)
                    {
                        // player has died
                        break;
                    }
                }
            }

            entryContainer.Shuffle();
            Console.WriteLine("normal: " + entryContainer.Summary);
            entryContainer_early.Shuffle();
            Console.WriteLine("early: " + entryContainer_early.Summary);
            entryContainer_strong.Shuffle();
            Console.WriteLine("strong: " + entryContainer_strong.Summary);

            #endregion

            #region Training

            var trainer = new AdamTrainer(singleNet)
            {
                BatchSize = 1024, LearningRate = 0.01, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8
            };
            var trainingScheme = new TrainingScheme(singleNet, trainer, entryContainer, "single");

            var trainer_early = new AdamTrainer(singleNet_early)
            {
                BatchSize = 1024, LearningRate = 0.01, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8
            };
            var trainingScheme_early = new TrainingScheme(singleNet_early, trainer_early, entryContainer_early, "single_early");

            var trainer_strong = new AdamTrainer(singleNet_strong)
            {
                BatchSize = 1024, LearningRate = 0.01, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8
            };
            var trainingScheme_strong = new TrainingScheme(singleNet_strong, trainer_strong, entryContainer_strong, "single_strong");

            bool   save                     = true;
            double lastValidationAcc        = 0.0;
            double lastValidationAcc_early  = 0.0;
            double lastValidationAcc_strong = 0.0;

            // One training round (up to 50 epochs) for a single net. The net is saved whenever its
            // validation accuracy beats the best seen so far; the updated best is returned so it can
            // be carried into the next round.
            Func<string, AdamTrainer, TrainingScheme, INet, string, double, double> trainRound =
                (label, roundTrainer, scheme, roundNet, netName, bestValidationAcc) =>
                {
                    for (int i = 0; i < 50; i++)
                    {
                        // Weight decay is switched on only after the first few epochs.
                        if (i > 5)
                        {
                            roundTrainer.L2Decay = 0.05;
                        }

                        Console.WriteLine($"[{label}] Epoch #{i + 1}");

                        // With i < 50 this only fires at i == 0: the learning rate is divided by 5 once
                        // at the start of every round, floored at 1e-5.
                        if (i % 50 == 0)
                        {
                            roundTrainer.LearningRate = Math.Max(roundTrainer.LearningRate / 5.0, 0.00001);
                        }

                        scheme.RunEpoch();

                        if (save && scheme.ValidationAccuracy > bestValidationAcc)
                        {
                            bestValidationAcc = scheme.ValidationAccuracy;
                            roundNet.SaveNet(netName);
                        }

                        if (Console.KeyAvailable)
                        {
                            break;
                        }
                    }

                    return bestValidationAcc;
                };

            do
            {
                // The three nets have separate trainers, schemes and files, so their rounds run in parallel.
                var normal = Task.Run(() => trainRound("normal", trainer, trainingScheme, singleNet, NetName, lastValidationAcc));
                var early  = Task.Run(() => trainRound("early", trainer_early, trainingScheme_early, singleNet_early, NetName_early, lastValidationAcc_early));
                var strong = Task.Run(() => trainRound("strong", trainer_strong, trainingScheme_strong, singleNet_strong, NetName_strong, lastValidationAcc_strong));

                Task.WaitAll(normal, strong, early);

                // Carry each net's best validation accuracy into the next round.
                lastValidationAcc        = normal.Result;
                lastValidationAcc_early  = early.Result;
                lastValidationAcc_strong = strong.Result;
            } while (!Console.KeyAvailable);

            #endregion
        }