/// <summary>
                /// Computes one error value per item of the training dataset, feeding each item through the
                /// network's multithreaded update before measuring it.
                /// </summary>
                /// <param name="adam">Trainer that owns the target network and the training dataset.</param>
                /// <returns>
                /// An array sized by <c>TrainingDataset.Size</c>, filled in enumeration order with half the sum of the
                /// squared values returned by <c>GetMeanSquareError</c> for each item.
                /// </returns>
                public static double[] GetCurrentErrorMultitherading(this AdamTrainer adam)
                {
                    var sampleErrors = new double[adam.TrainingDataset.Size];
                    var index        = 0;

                    foreach (IOMetaDataSetItem <double[]> sample in adam.TrainingDataset)
                    {
                        // Feed the sample in, then read the network's resulting OutputValues.
                        adam.TargetNetwork.UpdatePositiveMultitherading(sample.DataIn);

                        double[] deviations = adam.GetMeanSquareError(adam.TargetNetwork.OutputValues, sample.DataOut);

                        double squareSum = 0.0;
                        foreach (double deviation in deviations)
                        {
                            squareSum += deviation * deviation;
                        }

                        sampleErrors[index++] = 0.5 * squareSum;
                    }

                    return sampleErrors;
                }
                /// <summary>
                /// Trains the target network with Adam, processing the neurons of each layer in parallel.
                /// </summary>
                /// <param name="adam">Trainer that owns the target network, the training dataset and the Adam optimizer.</param>
                /// <param name="Iteration">Maximum number of passes over the training dataset.</param>
                /// <param name="MinimunError">Error threshold compared with the average dataset error after each pass.</param>
                /// <param name="AutoExit">When true, stop once the average error is less than or equal to <paramref name="MinimunError"/>.</param>
                /// <returns>
                /// A lazily evaluated sequence yielding the average dataset error after each pass; training only
                /// advances while the caller enumerates it.
                /// </returns>
                /// <exception cref="AggregateException">
                /// Wraps the <see cref="Exception"/> ("Gradient is NaN") thrown inside <c>Parallel.For</c> when a weight gradient is NaN.
                /// </exception>
                public static IEnumerable <double> TrainMultitherading(this AdamTrainer adam, int Iteration, double MinimunError, bool AutoExit)
                {
                    // Adam moment state per weight:
                    // double[] -- synapses of one neuron
                    // List<double[]> -- neurons of one layer
                    // List<List<double[]>> -- layers of the network
                    List <List <double[]> > MSummary = new List <List <double[]> >();
                    List <List <double[]> > NSummary = new List <List <double[]> >();

                    adam.InitializeNetworkArray(ref MSummary);
                    adam.InitializeNetworkArray(ref NSummary);

                    // Adam moment state per bias:
                    // double[] -- biases of one layer
                    // List<double[]> -- layers of the network
                    List <double[]> biasMSummary = new List <double[]>();
                    List <double[]> biasNSummary = new List <double[]>();

                    adam.InitializeBiasArray(ref biasMSummary);
                    adam.InitializeBiasArray(ref biasNSummary);

                    adam.adamDecent.ClearCorrectAccumulate();

                    for (int iteration = 0; iteration < Iteration; iteration++)
                    {
                        foreach (IOMetaDataSetItem <double[]> item in adam.TrainingDataset)
                        {
                            // New weights/biases are collected here and written back only after the whole
                            // backward pass, so every layer works with the weights used in the forward pass.
                            List <List <double[]> > newWeights = new List <List <double[]> >();
                            adam.InitializeNetworkArray(ref newWeights);

                            List <double[]> newBias = new List <double[]>();
                            adam.InitializeBiasArray(ref newBias);

                            // Update the network with the sample
                            adam.TargetNetwork.UpdatePositiveMultitherading(item.DataIn);

                            // Error handed from the current layer to the layer before it
                            double[] nexterrors = new double[0];

                            for (int layer = adam.TargetNetwork.Layers - 1; layer >= 0; layer--)
                            {
                                // Number of incoming synapses per neuron and number of neurons in this layer
                                int SynapseCount = (layer == 0) ? (adam.TargetNetwork.Inputs) : (adam.TargetNetwork.Neurons[layer - 1].Length);
                                int NeuronsCount = adam.TargetNetwork.Neurons[layer].Length;

                                // Output layer starts from the network error; other layers use the propagated error
                                double[] errors = (layer == adam.TargetNetwork.Layers - 1)
                                    ? adam.GetMeanSquareError(adam.TargetNetwork.OutputValues, item.DataOut)
                                    : nexterrors;

                                // next error = Σ(Ws * e). Accumulated sequentially BEFORE the parallel loop:
                                // every neuron adds into the same slots, so doing this inside Parallel.For is an
                                // unsynchronized read-modify-write race. No weight has been updated yet at this
                                // point, so the values match what each neuron saw before its own update.
                                // NOTE(review): the propagated error omits this layer's f'(Is) factor, while
                                // textbook backpropagation passes Σ(Ws * e * f'(Is)). Kept as-is; confirm intent.
                                nexterrors = new double[SynapseCount];
                                for (int neuronIndex = 0; neuronIndex < NeuronsCount; neuronIndex++)
                                {
                                    double[] neuronWeights = adam.TargetNetwork.Neurons[layer][neuronIndex].Weights;
                                    double   neuronError   = errors[neuronIndex];
                                    for (int synapseIndex = 0; synapseIndex < SynapseCount; synapseIndex++)
                                    {
                                        nexterrors[synapseIndex] += neuronWeights[synapseIndex] * neuronError;
                                    }
                                }

                                // Weight gradients and Adam step, one neuron per parallel work item.
                                // NOTE(review): adamDecent.Update runs concurrently here; confirm it keeps no
                                // shared mutable state beyond the per-call summaries passed in.
                                Parallel.For(0, NeuronsCount, delegate(int neuron)
                                {
                                    Neuron n = adam.TargetNetwork.Neurons[layer][neuron];

                                    double[] weights    = n.Weights;
                                    double[] gradients  = new double[SynapseCount];
                                    double   derivative = n.TransferFunction.Derivatives(n.MiddleValue);

                                    for (int synapse = 0; synapse < SynapseCount; synapse++)
                                    {
                                        double prev = (layer == 0) ? (adam.TargetNetwork.InputNeurons[synapse].OutputValue) : (adam.TargetNetwork.Neurons[layer - 1][synapse].OutputValue);

                                        // g = - E * d(f(Is))/d(Is) * Os
                                        double gradient = -errors[neuron] * derivative * prev;

                                        // Fail fast rather than writing NaN into the weights
                                        if (double.IsNaN(gradient))
                                        {
                                            throw new Exception("Gradient is NaN");
                                        }

                                        gradients[synapse] = gradient;
                                    }

                                    // Adam step; the moment state goes in and comes back through `parameter`
                                    double[] biasunused  = new double[0];
                                    object biasparameter = new List <double[]>()
                                    {
                                        MSummary[layer][neuron], NSummary[layer][neuron]
                                    };
                                    adam.adamDecent.Update(ref weights, gradients, ref biasunused, null, ref biasparameter);

                                    var moments = (List <double[]>)biasparameter;
                                    MSummary[layer][neuron] = moments[0];
                                    NSummary[layer][neuron] = moments[1];

                                    newWeights[layer][neuron] = weights;
                                });

                                // Bias gradients for this layer: g = - E * Os(bias neuron of this layer)
                                double[] bias          = adam.TargetNetwork.Neurons[layer].Select((Neuron p) => (p.Bias)).ToArray();
                                double[] biasGradients = new double[NeuronsCount];
                                for (int neuronIndex = 0; neuronIndex < NeuronsCount; neuronIndex++)
                                {
                                    biasGradients[neuronIndex] = -errors[neuronIndex] * adam.TargetNetwork.BiasNeurons[layer].OutputValue;
                                }

                                // Adam step for the biases of this layer
                                double[] unused    = new double[0];
                                object   parameter = new List <double[]>()
                                {
                                    biasMSummary[layer], biasNSummary[layer]
                                };
                                adam.adamDecent.Update(ref bias, biasGradients, ref unused, null, ref parameter);

                                var biasMoments = (List <double[]>)parameter;
                                biasMSummary[layer] = biasMoments[0];
                                biasNSummary[layer] = biasMoments[1];

                                newBias[layer] = bias;
                            }

                            // Write the new weights and biases back into the network
                            for (int layer = 0; layer < adam.TargetNetwork.Layers; layer++)
                            {
                                int NeuronsCount = adam.TargetNetwork.Neurons[layer].Length;

                                for (int neuron = 0; neuron < NeuronsCount; neuron++)
                                {
                                    adam.TargetNetwork.Neurons[layer][neuron].Weights = newWeights[layer][neuron];
                                    adam.TargetNetwork.Neurons[layer][neuron].Bias    = newBias[layer][neuron];
                                }
                            }

                            // Correct accumulate
                            adam.adamDecent.CorrectAccumulate();
                        }

                        // Report the average error after this pass
                        double error = adam.GetCurrentErrorMultitherading().Average();
                        yield return(error);

                        if (AutoExit && error <= MinimunError)
                        {
                            break;
                        }
                    }
                }
Exemplo n.º 3
0
        /// <summary>
        /// Builds an imitation-learning dataset from downloaded erdman replays (with mirror-image augmentation)
        /// and trains a single direction-choice network on it until a key is pressed. The network is saved to
        /// "net.dat" whenever its validation accuracy improves.
        /// </summary>
        public static void Run()
        {
            const string NetName = "net.dat";

            var random         = new Random();
            var entryContainer = new EntryContainer();

            #region Load Net

            var convInputWith = 11; // Will extract 11x11 area
            if (convInputWith % 2 == 0)
            {
                throw new ArgumentException("convInputWith must be odd");
            }

            // Load IA or initialize new network if not found - Direction choice
            INet singleNet = NetExtension.LoadOrCreateNet(NetName, () =>
            {
                var net = FluentNet.Create(convInputWith, convInputWith, 3)
                          .Conv(3, 3, 16).Stride(2)
                          .Tanh()
                          .Conv(2, 2, 16)
                          .Tanh()
                          .FullyConn(100)
                          .Relu()
                          .FullyConn(5)
                          .Softmax(5).Build();

                return(net);
            });

            #endregion

            #region Load data

            // erdman games downloaded with HltDownloader; Path.Combine keeps the relative path portable
            var hltFiles = Directory.EnumerateFiles(Path.Combine("..", "..", "..", "games", "2609"), "*.hlt").ToList();
            int total    = hltFiles.Count;
            Console.WriteLine($"Loading {total} games...");

            foreach (var file in hltFiles)
            {
                Console.WriteLine(total--);
                HltReader reader = new HltReader(file);

                var playerToCopy = reader.PlayerNames.FirstOrDefault(o => o.StartsWith("erdman", StringComparison.Ordinal));
                if (playerToCopy == null)
                {
                    Console.WriteLine("Player not found");
                    continue;
                }

                // PlayerNames index + 1 is what gets compared with map cell owners below
                var playerId = reader.PlayerNames.IndexOf(playerToCopy) + 1;
                var gameId   = file.GetHashCode(); // shared by every entry taken from this game
                var width    = reader.Width;
                var height   = reader.Height;

                // Owned-cell count of the previous frame. Each owned cell is sampled with probability
                // 1/lastmoveCount, i.e. about one cell per frame on average.
                int lastmoveCount = 1;

                for (var frame = 0; frame < reader.FrameCount - 1; frame++)
                {
                    var currentFrame = reader.GetFrame(frame);
                    var map          = currentFrame.map;
                    var moves        = currentFrame.moves;

                    int moveCount = 0;

                    for (ushort x = 0; x < width; x++)
                    {
                        for (ushort y = 0; y < height; y++)
                        {
                            if (map[x, y].Owner != playerId)
                            {
                                continue;
                            }

                            moveCount++;

                            if (random.NextDouble() < 1.0 / lastmoveCount)
                            {
                                var convVolume = map.GetVolume(convInputWith, playerId, x, y); // Input
                                var direction  = moves[y][x];                                  // Output

                                // The sample plus its three mirror images (data augmentation)
                                entryContainer.Add(new Entry(new[] { convVolume }, direction, x, y, frame, gameId));
                                entryContainer.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.LeftRight) }, (int)Helper.FlipLeftRight((Direction)direction), x, y, frame, gameId));
                                entryContainer.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.UpDown) }, (int)Helper.FlipUpDown((Direction)direction), x, y, frame, gameId));
                                entryContainer.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.Both) }, (int)Helper.FlipBothWay((Direction)direction), x, y, frame, gameId));
                            }
                        }
                    }

                    lastmoveCount = moveCount;

                    if (moveCount == 0)
                    {
                        // player has died
                        break;
                    }
                }
            }

            entryContainer.Shuffle();
            Console.WriteLine(entryContainer.Summary);

            #endregion

            #region Training

            var trainer = new AdamTrainer(singleNet)
            {
                BatchSize = 1024, LearningRate = 0.1, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8
            };
            var    trainingScheme    = new TrainingScheme(singleNet, trainer, entryContainer, "single");
            bool   save              = true;
            double lastValidationAcc = 0.0;

            do
            {
                for (int i = 0; i < 1000; i++)
                {
                    if (i > 5)
                    {
                        trainer.L2Decay = 0.001;
                    }

                    Console.WriteLine($"Epoch #{i + 1}");

                    // Divide the learning rate by 10 every 15 epochs, starting at epoch 1 (i == 0)
                    if (i % 15 == 0)
                    {
                        trainer.LearningRate = Math.Max(trainer.LearningRate / 10.0, 0.00001);
                    }

                    trainingScheme.RunEpoch();

                    #region Save Nets

                    // Save if validation accuracy has improved
                    if (save && trainingScheme.ValidationAccuracy > lastValidationAcc)
                    {
                        lastValidationAcc = trainingScheme.ValidationAccuracy;
                        singleNet.SaveNet(NetName);
                    }

                    #endregion

                    if (Console.KeyAvailable)
                    {
                        break;
                    }
                }
            } while (!Console.KeyAvailable);

            #endregion
        }
Exemplo n.º 4
0
        /// <summary>
        /// Builds three imitation-learning datasets from downloaded erdman replays (early game, strong cells,
        /// everything else), then trains one network per dataset concurrently in rounds of 50 epochs until a
        /// key is pressed. Each network is saved whenever its validation accuracy improves.
        /// </summary>
        public static void Run()
        {
            var random = new Random(RandomUtilities.Seed);

            int normalInputWidth = 19;
            int earlyInputWidth  = 19;
            int strongInputWidth = 19;

            string NetName        = "net.dat";
            string NetName_early  = "net_early.dat";
            string NetName_strong = "net_strong.dat";

            var entryContainer        = new EntryContainer();
            var entryContainer_early  = new EntryContainer();
            var entryContainer_strong = new EntryContainer();

            #region Load Net

            // All three networks share one architecture; only the file name and input width differ.
            Func <string, int, INet> loadOrCreateNet = (fileName, inputWidth) => NetExtension.LoadOrCreateNet(fileName, () =>
                FluentNet.Create(inputWidth, inputWidth, 3)
                         .Conv(5, 5, 16).Stride(5).Pad(2)
                         .Tanh()
                         .Conv(3, 3, 16).Stride(1).Pad(1)
                         .Tanh()
                         .FullyConn(100)
                         .Relu()
                         .FullyConn(5)
                         .Softmax(5).Build());

            INet singleNet        = loadOrCreateNet(NetName, normalInputWidth);
            INet singleNet_early  = loadOrCreateNet(NetName_early, earlyInputWidth);
            INet singleNet_strong = loadOrCreateNet(NetName_strong, strongInputWidth);

            #endregion

            #region Load data

            // erdman games downloaded with HltDownloader; Path.Combine keeps the relative path portable
            var hltFiles = Directory.EnumerateFiles(Path.Combine("..", "..", "..", "games", "2609"), "*.hlt").ToList();
            int total    = hltFiles.Count;
            Console.WriteLine($"Loading {total} games...");

            foreach (var file in hltFiles)
            {
                Console.WriteLine(total--);
                HltReader reader = new HltReader(file);

                var playerToCopy = reader.PlayerNames.FirstOrDefault(o => o.StartsWith("erdman", StringComparison.Ordinal));
                if (playerToCopy == null)
                {
                    Console.WriteLine("not found");
                    continue;
                }

                // PlayerNames index + 1 is what gets compared with map cell owners below
                var playerId = reader.PlayerNames.IndexOf(playerToCopy) + 1;
                var gameId   = file.GetHashCode(); // shared by every entry taken from this game
                var width    = reader.Width;
                var height   = reader.Height;

                // Owned-cell count of the previous frame; drives the sampling probability below
                int lastmoveCount = 1;

                for (var frame = 0; frame < reader.FrameCount - 1; frame++)
                {
                    // "Early game" while the player owned fewer than 25 cells in the previous frame
                    bool earlyGame = lastmoveCount < 25;

                    var currentFrame = reader.GetFrame(frame);
                    var map          = currentFrame.map;
                    var moves        = currentFrame.moves;

                    int moveCount = 0;

                    for (ushort x = 0; x < width; x++)
                    {
                        for (ushort y = 0; y < height; y++)
                        {
                            if (map[x, y].Owner != playerId)
                            {
                                continue;
                            }

                            bool strong = map[x, y].Strength > 200;
                            moveCount++;

                            // Early-game and strong cells get extra chances to be sampled. The draws
                            // short-circuit, so the number of random numbers consumed varies per cell.
                            if ((earlyGame && random.NextDouble() < 1.5 / lastmoveCount) || (strong && random.NextDouble() < 1.5 / lastmoveCount) || random.NextDouble() < 1.0 / lastmoveCount)
                            {
                                // Early game takes precedence over strong cells
                                var w         = normalInputWidth;
                                var container = entryContainer;

                                if (earlyGame)
                                {
                                    w         = earlyInputWidth;
                                    container = entryContainer_early;
                                }
                                else if (strong)
                                {
                                    w         = strongInputWidth;
                                    container = entryContainer_strong;
                                }

                                var convVolume = map.GetVolume(w, playerId, x, y);
                                var direction  = moves[y][x];

                                // The sample plus its three mirror images (data augmentation)
                                container.Add(new Entry(new[] { convVolume }, direction, x, y, frame, gameId));
                                container.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.LeftRight) }, (int)Helper.FlipLeftRight((Direction)direction), x, y, frame, gameId));
                                container.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.UpDown) }, (int)Helper.FlipUpDown((Direction)direction), x, y, frame, gameId));
                                container.Add(new Entry(new[] { convVolume.Flip(VolumeUtilities.FlipMode.Both) }, (int)Helper.FlipBothWay((Direction)direction), x, y, frame, gameId));
                            }
                        }
                    }

                    lastmoveCount = moveCount;

                    if (moveCount == 0)
                    {
                        // player has died
                        break;
                    }
                }
            }

            entryContainer.Shuffle();
            Console.WriteLine("normal: " + entryContainer.Summary);
            entryContainer_early.Shuffle();
            Console.WriteLine("early: " + entryContainer_early.Summary);
            entryContainer_strong.Shuffle();
            Console.WriteLine("strong: " + entryContainer_strong.Summary);

            #endregion

            #region Training

            var trainer = new AdamTrainer(singleNet)
            {
                BatchSize = 1024, LearningRate = 0.01, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8
            };
            var trainingScheme = new TrainingScheme(singleNet, trainer, entryContainer, "single");

            var trainer_early = new AdamTrainer(singleNet_early)
            {
                BatchSize = 1024, LearningRate = 0.01, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8
            };
            var trainingScheme_early = new TrainingScheme(singleNet_early, trainer_early, entryContainer_early, "single_early");

            var trainer_strong = new AdamTrainer(singleNet_strong)
            {
                BatchSize = 1024, LearningRate = 0.01, Beta1 = 0.9, Beta2 = 0.99, Eps = 1e-8
            };
            var trainingScheme_strong = new TrainingScheme(singleNet_strong, trainer_strong, entryContainer_strong, "single_strong");

            bool save = true;

            // Creates the "start one 50-epoch round" job for a model. Each call gets its own
            // lastValidationAcc, which therefore persists across rounds for that model only.
            Func <string, AdamTrainer, TrainingScheme, INet, string, Func <Task> > createTrainingRound = (tag, modelTrainer, scheme, net, netName) =>
            {
                double lastValidationAcc = 0.0;

                return () => Task.Factory.StartNew(() =>
                {
                    for (int i = 0; i < 50; i++)
                    {
                        if (i > 5)
                        {
                            modelTrainer.L2Decay = 0.05;
                        }

                        Console.WriteLine($"[{tag}] Epoch #{i + 1}");

                        // With i < 50 this only fires at i == 0: the learning rate is divided by 5 once per round
                        if (i % 50 == 0)
                        {
                            modelTrainer.LearningRate = Math.Max(modelTrainer.LearningRate / 5.0, 0.00001);
                        }

                        scheme.RunEpoch();

                        // Save if validation accuracy has improved
                        if (save && scheme.ValidationAccuracy > lastValidationAcc)
                        {
                            lastValidationAcc = scheme.ValidationAccuracy;
                            net.SaveNet(netName);
                        }

                        if (Console.KeyAvailable)
                        {
                            break;
                        }
                    }
                });
            };

            var trainNormal = createTrainingRound("normal", trainer, trainingScheme, singleNet, NetName);
            var trainEarly  = createTrainingRound("early", trainer_early, trainingScheme_early, singleNet_early, NetName_early);
            var trainStrong = createTrainingRound("strong", trainer_strong, trainingScheme_strong, singleNet_strong, NetName_strong);

            do
            {
                // Train the three models concurrently and wait until all of them finish the round
                Task.WaitAll(trainNormal(), trainEarly(), trainStrong());
            }
            while (!Console.KeyAvailable);

            #endregion
        }
Exemplo n.º 5
0
        /// <summary>
        /// Main training loop: plays the game with an epsilon-greedy policy and stores each transition in a
        /// bounded replay memory.
        /// </summary>
        /// <param name="cfg">Model configuration (input size, action count, epsilon schedule, replay size, ...).</param>
        /// <param name="gameState">Game environment wrapper used to perform actions and read the next frame, reward and terminal flag.</param>
        /// <param name="justPlay">
        /// When true, the model is never trained: the observation phase is effectively endless and epsilon is fixed
        /// at <c>cfg.FinalEpsilon</c>. Otherwise epsilon is resumed from the cache.
        /// </param>
        /// <remarks>The loop never returns. The minibatch training step is not implemented yet (see TODO below).</remarks>
        public void TrainModel(GameModelConfig cfg, GameState gameState, bool justPlay = false)
        {
            CacheUtils.InitCache(
                ("epsilon", cfg.InitialEpsilon),
                ("time", 0),
                ("D", new Queue <(Volume <double>, int, double, Volume <double>, bool)>())); //initial variable caching, done only once

            // NOTE(review): this model is replaced by the one deserialized from modelPath below. The call is kept in
            // case BuildModel has side effects (e.g. producing that file) — confirm and drop it if not.
            var model    = BuildModel(cfg);
            var lastTime = DateTime.Now;

            // Replay memory of (state, action, reward, next state, terminal), loaded from the file system
            var D = CacheUtils.LoadObj <Queue <(Volume <double>, int, double, Volume <double>, bool)> >("D");

            // Get the first state by doing nothing (0 => do nothing, 1 => jump)
            var do_nothing = new double[cfg.Actions];
            do_nothing[0] = 1;
            var (x_t, _, terminal) = gameState.GetState(do_nothing, cfg.InputWidth, cfg.InputHeight);

            // The network input is the first frame repeated 4 times
            var s_t = BuilderInstance.Volume.From(x_t.Repeat(4), new Shape(cfg.InputWidth, cfg.InputHeight, cfg.ImageChannels));

            // Stacked state the loop restarts from after a terminal transition
            var initial_state = s_t;

            double observe;
            double epsilon;

            model = SerializationExtensions.FromJson <double>(File.ReadAllText(modelPath));

            var trainer = new AdamTrainer(model)
            {
                LearningRate = cfg.LearningRate
            };

            if (justPlay)
            {
                observe = 999999999; //We keep observe, never train
                epsilon = cfg.FinalEpsilon;
            }
            else //We go to training mode
            {
                observe = cfg.Observation;
                epsilon = CacheUtils.LoadObj <double>("epsilon");
            }

            int t = CacheUtils.LoadObj <int>("time"); // resume from the previous time step stored in file system

            while (true) //endless running
            {
                int    action_index = 0;
                double r_t          = 0;                       //reward at t
                var    a_t          = new double[cfg.Actions]; //action at t

                //choose an action epsilon greedy
                if (t % cfg.FramePerAction == 0)         //parameter to skip frames for actions
                {
                    if (_random.NextDouble() <= epsilon) //randomly explore an action
                    {
                        _logger.LogInformation("----------Random Action----------");
                        action_index      = _random.Next(cfg.Actions);
                        a_t[action_index] = 1;
                    }
                    else //predict the output
                    {
                        model.Forward(s_t); //input a stack of 4 images, get the prediction
                        var q = model.GetPrediction();
                        action_index      = q[0]; //chosing index with maximum q value
                        a_t[action_index] = 1;    //0 => do nothing, 1 => jump
                    }
                }

                //We reduced the epsilon (exploration parameter) gradually
                if (epsilon > cfg.FinalEpsilon && t > observe)
                {
                    epsilon -= (cfg.InitialEpsilon - cfg.FinalEpsilon) / cfg.Explore;
                }

                //run the selected action and observed next state and reward
                double[] x_t1;
                (x_t1, r_t, terminal) = gameState.GetState(a_t, cfg.InputWidth, cfg.InputHeight);
                _logger.LogInformation($"fps: { 1 / (DateTime.Now - lastTime).TotalSeconds }");       //helpful for measuring frame rate
                lastTime = DateTime.Now;
                var s_t1 = BuilderInstance.Volume.From(s_t.ToArray().StackAndShift(x_t1), s_t.Shape); //append the new image to input stack and remove the first one

                //store the transition in D, keeping at most cfg.ReplayMemory entries
                D.Enqueue((s_t, action_index, r_t, s_t1, terminal));
                if (D.Count > cfg.ReplayMemory)
                {
                    D.Dequeue();
                }

                //only train if done observing
                if (t > observe)
                {
                    // TODO: sample a minibatch from D and train `model` with `trainer` (not implemented yet)
                }

                // Advance: restart from the initial stack after a terminal transition, otherwise continue
                // from the new stack. Without this the agent kept seeing the first state and t never moved.
                s_t = terminal ? initial_state : s_t1;
                t++;
            }
        }
Exemplo n.º 6
0
        /// <summary>
        /// Loads the network stored at <paramref name="model"/> (or creates a new one), then trains it
        /// with an Adam trainer on batches read from <paramref name="trainingSet"/> until a key is pressed.
        /// Test-batch statistics are logged periodically and the model is saved periodically and on exit.
        /// </summary>
        /// <param name="model">Path of the model file. Loaded when it exists, otherwise a new network is created; also the save target.</param>
        /// <param name="trainingSet">Directory whose <c>rotated</c> subdirectory supplies the training data.</param>
        /// <param name="testSet">Directory whose <c>rotated</c> subdirectory supplies the test data.</param>
        public void Run(string model, string trainingSet, string testSet)
        {
#if GPU
            Log.Info("Enabling GPU mode ...");
            BuilderInstance <double> .Volume = new ConvNetSharp.Volume.GPU.Double.VolumeBuilder();
#endif

            Log.Info("Loading model ...");

            var network = File.Exists(model) ? Network.FromFile(model) : Network.CreateNew();

            const int batchSize    = 15; // samples per training batch
            const int testInterval = 10; // evaluate on a test batch every N runs
            const int testSize     = 10; // samples per test batch
            const int saveInterval = 50; // checkpoint the model every N runs

            // Alternative configuration noted in the original source: SgdTrainer with
            // LearningRate 0.008, BatchSize = batchSize, L2Decay 0.001, Momentum 0.8.
            var trainer = new AdamTrainer <double>(network.Net)
            {
                LearningRate = 0.001,
                Beta1        = 0.9,
                Beta2        = 0.999
            };

            var trainingData = TrainingSet.FromDirectory(Path.Combine(trainingSet, "rotated"), batchSize);
            var testData     = TrainingSet.FromDirectory(Path.Combine(testSet, "rotated"), testSize);

            Log.Info("Training ...");

            WriteHeader();

            var run = 0;

            do
            {
                // Checkpoint happens before the run-th training step, i.e. after run - 1 steps.
                if (++run % saveInterval == 0)
                {
                    Save(network, model);
                }

                // train on one batch
                using (var batch = trainingData.GetBatch())
                {
                    trainer.Train(batch.InputVolume, batch.OutputVolume);
                }

                if (run % testInterval == 0)
                {
                    using (var testBatch = testData.GetBatch())
                    {
                        // and test: forward the test inputs through the network
                        var result = network.Net.Forward(testBatch.InputVolume);

                        // presumably lets the batch compute TotalError / angle stats read below — confirm
                        testBatch.SetResult(result);

                        // evaluate results
                        Log.Info($"{trainingData.Epoch}\t\t{run}\t{testBatch.TotalError:0.00}\t{testBatch.MinimumAngle:0.00}\t{testBatch.MaximumAngle:0.00}\t{testBatch.MaxError:0.00}\t{trainer.Loss:0.00000}");
                    }
                }
            }
            while (!Console.KeyAvailable);

            // Consume the key that stopped training so it does not leak into later console reads.
            // A key is known to be available here, so this returns immediately.
            Console.ReadKey(intercept: true);

            Save(network, model);
        }