/// <summary>
/// Computes the current error of every sample in the training dataset,
/// updating the target network with the multithreaded update for each sample.
/// </summary>
/// <param name="adam">Trainer whose target network and training dataset are evaluated</param>
/// <returns>
/// One value per sample, in dataset enumeration order: half the sum of the squared
/// values returned by GetMeanSquareError for that sample
/// </returns>
public static double[] GetCurrentErrorMultitherading(this AdamTrainer adam)
{
    double[] sampleErrors = new double[adam.TrainingDataset.Size];
    int index = 0;

    foreach (IOMetaDataSetItem<double[]> sample in adam.TrainingDataset)
    {
        // Update the network with this sample
        adam.TargetNetwork.UpdatePositiveMultitherading(sample.DataIn);

        // Accumulate the squared deviations in order, starting from zero
        double[] deviations = adam.GetMeanSquareError(adam.TargetNetwork.OutputValues, sample.DataOut);
        double squaredSum = 0;
        foreach (double deviation in deviations)
        {
            squaredSum += deviation * deviation;
        }

        sampleErrors[index] = 0.5 * squaredSum;
        index++;
    }

    return sampleErrors;
}
/// <summary>
/// Training the network with multithreading
/// </summary>
/// <param name="adam">Trainer providing the target network, the training dataset and the Adam optimizer</param>
/// <param name="Iteration">Number of iterations (passes over the training dataset)</param>
/// <param name="MinimunError">Minimum error of training</param>
/// <param name="AutoExit">Stop as soon as the average error is less than or equal to the minimum error</param>
/// <returns>A lazily evaluated sequence with the average training error after each iteration</returns>
public static IEnumerable<double> TrainMultitherading(this AdamTrainer adam, int Iteration, double MinimunError, bool AutoExit)
{
    // Global gradient summaries for the Adam optimizer
    // double[]             -- synapses of one neuron
    // List<double[]>       -- neurons of one layer
    // List<List<double[]>> -- layers of the network
    List<List<double[]>> MSummary = new List<List<double[]>>();
    List<List<double[]>> NSummary = new List<List<double[]>>();
    adam.InitializeNetworkArray(ref MSummary);
    adam.InitializeNetworkArray(ref NSummary);

    // Global gradient summaries for the bias values
    // double[]       -- bias values of one layer
    // List<double[]> -- layers of the network
    List<double[]> biasMSummary = new List<double[]>();
    List<double[]> biasNSummary = new List<double[]>();
    adam.InitializeBiasArray(ref biasMSummary);
    adam.InitializeBiasArray(ref biasNSummary);

    adam.adamDecent.ClearCorrectAccumulate();

    // Serializes the merge of the per-worker error accumulators (see Parallel.For below)
    object mergeGate = new object();

    for (int iteration = 0; iteration < Iteration; iteration++)
    {
        foreach (IOMetaDataSetItem<double[]> item in adam.TrainingDataset)
        {
            // A new weight array
            List<List<double[]>> newWeights = new List<List<double[]>>();
            adam.InitializeNetworkArray(ref newWeights);

            // A new bias array
            List<double[]> newBias = new List<double[]>();
            adam.InitializeBiasArray(ref newBias);

            // Update the network with the sample
            adam.TargetNetwork.UpdatePositiveMultitherading(item.DataIn);

            // Traverse all layers from the last one back to the first one
            double[] nexterrors = new double[0];
            for (int layer = adam.TargetNetwork.Layers - 1; layer >= 0; layer--)
            {
                // Number of synapses per neuron in this layer
                int SynapseCount = (layer == 0)
                    ? (adam.TargetNetwork.Inputs)
                    : (adam.TargetNetwork.Neurons[layer - 1].Length);

                // Number of neurons in this layer
                int NeuronsCount = adam.TargetNetwork.Neurons[layer].Length;

                // Errors of this layer: the network error for the final layer,
                // otherwise the errors propagated back from the layer processed before
                double[] errors;
                if (layer == adam.TargetNetwork.Layers - 1)
                {
                    errors = adam.GetMeanSquareError(adam.TargetNetwork.OutputValues, item.DataOut);
                }
                else
                {
                    errors = nexterrors;
                }

                // Errors propagated to the next layer to process (a new array is already zeroed)
                double[] propagated = new double[SynapseCount];

                // Traverse all neurons of this layer in parallel.
                // Every neuron contributes to every slot of the propagated errors, so each
                // worker sums into its own private array and the partial sums are merged
                // under a lock; adding to the shared array directly from all workers is a
                // data race that loses updates.
                // NOTE(review): adamDecent.Update is still invoked concurrently, as before -
                // confirm that it is safe to call from several threads.
                Parallel.For(
                    0,
                    NeuronsCount,
                    () => new double[SynapseCount],
                    (neuron, loopState, localErrors) =>
                    {
                        Neuron n = adam.TargetNetwork.Neurons[layer][neuron];

                        // Traverse all synapses of this neuron and calculate the gradients
                        double[] weights = n.Weights;
                        double[] gradients = new double[SynapseCount];
                        double[] msummary = MSummary[layer][neuron];
                        double[] nsummary = NSummary[layer][neuron];
                        for (int synapse = 0; synapse < SynapseCount; synapse++)
                        {
                            double prev = (layer == 0)
                                ? (adam.TargetNetwork.InputNeurons[synapse].OutputValue)
                                : (adam.TargetNetwork.Neurons[layer - 1][synapse].OutputValue);

                            // Calculate gradient => g = - E * d(f(Is))/d(Is) * Os
                            double gradient = -errors[neuron] * n.TransferFunction.Derivatives(n.MiddleValue) * prev;

                            // Check the gradient to avoid NaN values
                            if (double.IsNaN(gradient))
                            {
                                throw new Exception("Gradient is NaN");
                            }

                            // Save the gradient
                            gradients[synapse] = gradient;

                            // Accumulate the next error => next error = Σ(Ws * e)
                            // (read before Update below, so the pre-update weight is used)
                            localErrors[synapse] += adam.TargetNetwork.Neurons[layer][neuron].Weights[synapse] * errors[neuron];
                        }

                        // Gradient descent
                        double[] biasunused = new double[0];
                        object biasparameter = new List<double[]>() { msummary, nsummary };
                        adam.adamDecent.Update(ref weights, gradients, ref biasunused, null, ref biasparameter);

                        // Update the summary
                        MSummary[layer][neuron] = (biasparameter as List<double[]>)[0];
                        NSummary[layer][neuron] = (biasparameter as List<double[]>)[1];

                        // Save the new weights
                        newWeights[layer][neuron] = weights;

                        return localErrors;
                    },
                    localErrors =>
                    {
                        lock (mergeGate)
                        {
                            for (int s = 0; s < SynapseCount; s++)
                            {
                                propagated[s] += localErrors[s];
                            }
                        }
                    });

                nexterrors = propagated;

                // Now calculate the bias of this layer, reusing the errors of its neurons
                double[] bias = adam.TargetNetwork.Neurons[layer].Select((Neuron p) => (p.Bias)).ToArray();
                double[] biasGradients = new double[NeuronsCount];
                double[] biasmSummary = biasMSummary[layer];
                double[] biasnSummary = biasNSummary[layer];
                for (int neuronIndex = 0; neuronIndex < NeuronsCount; neuronIndex++)
                {
                    // Calculate gradient => g = - E * Os
                    biasGradients[neuronIndex] = -errors[neuronIndex] * adam.TargetNetwork.BiasNeurons[layer].OutputValue;
                }

                // Gradient descent
                double[] unused = new double[0];
                object parameter = new List<double[]>() { biasmSummary, biasnSummary };
                adam.adamDecent.Update(ref bias, biasGradients, ref unused, null, ref parameter);

                // Update the summary
                biasMSummary[layer] = (parameter as List<double[]>)[0];
                biasNSummary[layer] = (parameter as List<double[]>)[1];

                // Save the new bias
                newBias[layer] = bias;
            }

            // Write the new weights and bias values back into the network
            for (int layer = 0; layer < adam.TargetNetwork.Layers; layer++)
            {
                int NeuronsCount = adam.TargetNetwork.Neurons[layer].Length;
                for (int neuron = 0; neuron < NeuronsCount; neuron++)
                {
                    adam.TargetNetwork.Neurons[layer][neuron].Weights = newWeights[layer][neuron];
                    adam.TargetNetwork.Neurons[layer][neuron].Bias = newBias[layer][neuron];
                }
            }

            // Correct accumulate
            adam.adamDecent.CorrectAccumulate();
        }

        // Return the average error after this iteration
        double error = adam.GetCurrentErrorMultitherading().Average();
        yield return error;

        if (error <= MinimunError && AutoExit)
        {
            break;
        }
    }
}
/// <summary>
/// Loads (or creates) the network stored in "net.dat", builds a training set from the
/// recorded *.hlt games of the player whose name starts with "erdman", and trains the
/// network with Adam until a key is pressed. The network is saved whenever the
/// validation accuracy improves.
/// </summary>
public static void Run()
{
    const string NetName = "net.dat";
    var random = new Random();
    var entryContainer = new EntryContainer();

    #region Load Net

    // Side length of the square area extracted around a cell (11x11); must be odd
    var convInputWith = 11;
    if (convInputWith % 2 == 0)
    {
        throw new ArgumentException("convInputWith must be odd");
    }

    // Load the saved network, or initialize a new one if none is found - direction choice
    INet singleNet = NetExtension.LoadOrCreateNet(NetName, () =>
    {
        var created = FluentNet.Create(convInputWith, convInputWith, 3)
            .Conv(3, 3, 16).Stride(2)
            .Tanh()
            .Conv(2, 2, 16)
            .Tanh()
            .FullyConn(100)
            .Relu()
            .FullyConn(5)
            .Softmax(5).Build();
        return created;
    });

    #endregion

    #region Load data

    // erdman games downloaded with HltDownloader
    var hltFiles = Directory.EnumerateFiles(@"..\..\..\games\2609\", "*.hlt").ToList();
    int remaining = hltFiles.Count;
    Console.WriteLine($"Loading {remaining} games...");

    foreach (var file in hltFiles)
    {
        // Countdown of the games still to load
        Console.WriteLine(remaining);
        remaining--;

        HltReader reader = new HltReader(file);
        var playerToCopy = reader.PlayerNames.FirstOrDefault(o => o.StartsWith("erdman"));
        if (playerToCopy == null)
        {
            Console.WriteLine("Player not found");
            continue;
        }

        // Player ids are one-based positions in the name list
        var playerId = reader.PlayerNames.IndexOf(playerToCopy) + 1;
        var width = reader.Width;
        var height = reader.Height;
        int sourceId = file.GetHashCode();

        // Number of cells the player owned in the previous frame; scales the sampling rate
        int lastmoveCount = 1;
        for (var frame = 0; frame < reader.FrameCount - 1; frame++)
        {
            var currentFrame = reader.GetFrame(frame);
            var map = currentFrame.map;
            var moves = currentFrame.moves;
            var helper = new Helper(map, (ushort)playerId);

            int moveCount = 0;
            for (ushort x = 0; x < width; x++)
            {
                for (ushort y = 0; y < height; y++)
                {
                    if (map[x, y].Owner == playerId)
                    {
                        moveCount++;

                        // Keep on average about one sample per frame
                        if (random.NextDouble() < 1.0 / lastmoveCount)
                        {
                            var convVolume = map.GetVolume(convInputWith, playerId, x, y); // Input
                            var direction = moves[y][x];                                   // Output

                            entryContainer.Add(new Entry(new[] { convVolume }, direction, x, y, frame, sourceId));

                            // Data augmentation: mirrored volumes with the direction mirrored accordingly
                            entryContainer.Add(new Entry(
                                new[] { convVolume.Flip(VolumeUtilities.FlipMode.LeftRight) },
                                (int)Helper.FlipLeftRight((Direction)direction),
                                x, y, frame, sourceId));
                            entryContainer.Add(new Entry(
                                new[] { convVolume.Flip(VolumeUtilities.FlipMode.UpDown) },
                                (int)Helper.FlipUpDown((Direction)direction),
                                x, y, frame, sourceId));
                            entryContainer.Add(new Entry(
                                new[] { convVolume.Flip(VolumeUtilities.FlipMode.Both) },
                                (int)Helper.FlipBothWay((Direction)direction),
                                x, y, frame, sourceId));
                        }
                    }
                }
            }

            lastmoveCount = moveCount;
            if (moveCount == 0)
            {
                // No owned cell in this frame: player has died
                break;
            }
        }
    }

    entryContainer.Shuffle();
    Console.WriteLine(entryContainer.Summary);

    #endregion

    #region Training

    var trainer = new AdamTrainer(singleNet)
    {
        BatchSize = 1024,
        LearningRate = 0.1,
        Beta1 = 0.9,
        Beta2 = 0.99,
        Eps = 1e-8
    };
    var trainingScheme = new TrainingScheme(singleNet, trainer, entryContainer, "single");

    bool save = true;
    double lastValidationAcc = 0.0;

    // Rounds of up to 1000 epochs, repeated until a key is pressed
    do
    {
        for (int epoch = 0; epoch < 1000; epoch++)
        {
            if (epoch > 5)
            {
                trainer.L2Decay = 0.001;
            }

            Console.WriteLine($"Epoch #{epoch + 1}");

            // Every 15 epochs (including the first) divide the learning rate by 10, floor 1e-5
            if (epoch % 15 == 0)
            {
                trainer.LearningRate = Math.Max(trainer.LearningRate / 10.0, 0.00001);
            }

            trainingScheme.RunEpoch();

            #region Save Nets

            // Save only if the validation accuracy has improved
            if (save && trainingScheme.ValidationAccuracy > lastValidationAcc)
            {
                lastValidationAcc = trainingScheme.ValidationAccuracy;
                singleNet.SaveNet(NetName);
            }

            #endregion

            if (Console.KeyAvailable)
            {
                break;
            }
        }
    }
    while (!Console.KeyAvailable);

    #endregion
}
/// <summary>
/// Trains three networks ("net.dat", "net_early.dat", "net_strong.dat") on samples taken
/// from the recorded *.hlt games of the player whose name starts with "erdman".
/// Samples from frames following a frame with fewer than 25 owned cells go to the early
/// set, samples of cells with strength above 200 go to the strong set, and everything
/// else goes to the normal set. The three trainings run as parallel tasks, in rounds of
/// 50 epochs, until a key is pressed; a network is saved whenever its validation
/// accuracy improves.
/// </summary>
public static void Run()
{
    var random = new Random(RandomUtilities.Seed);

    int normalInputWidth = 19;
    int earlyInputWidth = 19;
    int strongInputWidth = 19;

    string NetName = "net.dat";
    string NetName_early = "net_early.dat";
    string NetName_strong = "net_strong.dat";

    var entryContainer = new EntryContainer();
    var entryContainer_early = new EntryContainer();
    var entryContainer_strong = new EntryContainer();

    #region Load Net

    INet singleNet = NetExtension.LoadOrCreateNet(NetName, () =>
    {
        var created = FluentNet.Create(normalInputWidth, normalInputWidth, 3)
            .Conv(5, 5, 16).Stride(5).Pad(2)
            .Tanh()
            .Conv(3, 3, 16).Stride(1).Pad(1)
            .Tanh()
            .FullyConn(100)
            .Relu()
            .FullyConn(5)
            .Softmax(5).Build();
        return created;
    });

    INet singleNet_early = NetExtension.LoadOrCreateNet(NetName_early, () =>
    {
        var created = FluentNet.Create(earlyInputWidth, earlyInputWidth, 3)
            .Conv(5, 5, 16).Stride(5).Pad(2)
            .Tanh()
            .Conv(3, 3, 16).Stride(1).Pad(1)
            .Tanh()
            .FullyConn(100)
            .Relu()
            .FullyConn(5)
            .Softmax(5).Build();
        return created;
    });

    INet singleNet_strong = NetExtension.LoadOrCreateNet(NetName_strong, () =>
    {
        var created = FluentNet.Create(strongInputWidth, strongInputWidth, 3)
            .Conv(5, 5, 16).Stride(5).Pad(2)
            .Tanh()
            .Conv(3, 3, 16).Stride(1).Pad(1)
            .Tanh()
            .FullyConn(100)
            .Relu()
            .FullyConn(5)
            .Softmax(5).Build();
        return created;
    });

    #endregion

    #region Load data

    // erdman games downloaded with HltDownloader
    var hltFiles = Directory.EnumerateFiles(@"..\..\..\games\2609\", "*.hlt").ToList();
    int remaining = hltFiles.Count;
    Console.WriteLine($"Loading {remaining} games...");

    foreach (var file in hltFiles)
    {
        // Countdown of the games still to load
        Console.WriteLine(remaining);
        remaining--;

        HltReader reader = new HltReader(file);
        var playerToCopy = reader.PlayerNames.FirstOrDefault(o => o.StartsWith("erdman"));
        if (playerToCopy == null)
        {
            Console.WriteLine("not found");
            continue;
        }

        // Player ids are one-based positions in the name list
        var playerId = reader.PlayerNames.IndexOf(playerToCopy) + 1;
        var width = reader.Width;
        var height = reader.Height;
        int sourceId = file.GetHashCode();

        // Number of cells the player owned in the previous frame; scales the sampling rate
        int lastmoveCount = 1;
        for (var frame = 0; frame < reader.FrameCount - 1; frame++)
        {
            bool earlyGame = lastmoveCount < 25;

            var currentFrame = reader.GetFrame(frame);
            var map = currentFrame.map;
            var moves = currentFrame.moves;
            var helper = new Helper(map, (ushort)playerId);

            int moveCount = 0;
            for (ushort x = 0; x < width; x++)
            {
                for (ushort y = 0; y < height; y++)
                {
                    if (map[x, y].Owner == playerId)
                    {
                        bool strong = map[x, y].Strength > 200;
                        moveCount++;

                        // Sampling decision; the draws happen in this exact order and stop at
                        // the first success: early-game boost, strong-cell boost, base rate
                        bool sampled = earlyGame && random.NextDouble() < 1.5 / lastmoveCount;
                        if (!sampled)
                        {
                            sampled = strong && random.NextDouble() < 1.5 / lastmoveCount;
                        }
                        if (!sampled)
                        {
                            sampled = random.NextDouble() < 1.0 / lastmoveCount;
                        }

                        if (sampled)
                        {
                            // Early game takes precedence over strong cells
                            int w = earlyGame
                                ? earlyInputWidth
                                : (strong ? strongInputWidth : normalInputWidth);
                            EntryContainer container = earlyGame
                                ? entryContainer_early
                                : (strong ? entryContainer_strong : entryContainer);

                            var convVolume = map.GetVolume(w, playerId, x, y);
                            var direction = moves[y][x];

                            container.Add(new Entry(new[] { convVolume }, direction, x, y, frame, sourceId));

                            // Data augmentation: mirrored volumes with the direction mirrored accordingly
                            container.Add(new Entry(
                                new[] { convVolume.Flip(VolumeUtilities.FlipMode.LeftRight) },
                                (int)Helper.FlipLeftRight((Direction)direction),
                                x, y, frame, sourceId));
                            container.Add(new Entry(
                                new[] { convVolume.Flip(VolumeUtilities.FlipMode.UpDown) },
                                (int)Helper.FlipUpDown((Direction)direction),
                                x, y, frame, sourceId));
                            container.Add(new Entry(
                                new[] { convVolume.Flip(VolumeUtilities.FlipMode.Both) },
                                (int)Helper.FlipBothWay((Direction)direction),
                                x, y, frame, sourceId));
                        }
                    }
                }
            }

            lastmoveCount = moveCount;
            if (moveCount == 0)
            {
                // No owned cell in this frame: player has died
                break;
            }
        }
    }

    entryContainer.Shuffle();
    Console.WriteLine("normal: " + entryContainer.Summary);
    entryContainer_early.Shuffle();
    Console.WriteLine("early: " + entryContainer_early.Summary);
    entryContainer_strong.Shuffle();
    Console.WriteLine("strong " + entryContainer_strong.Summary);

    #endregion

    #region Training

    var trainer = new AdamTrainer(singleNet)
    {
        BatchSize = 1024,
        LearningRate = 0.01,
        Beta1 = 0.9,
        Beta2 = 0.99,
        Eps = 1e-8
    };
    var trainingScheme = new TrainingScheme(singleNet, trainer, entryContainer, "single");

    var trainer_early = new AdamTrainer(singleNet_early)
    {
        BatchSize = 1024,
        LearningRate = 0.01,
        Beta1 = 0.9,
        Beta2 = 0.99,
        Eps = 1e-8
    };
    var trainingScheme_early = new TrainingScheme(singleNet_early, trainer_early, entryContainer_early, "single_early");

    var trainer_strong = new AdamTrainer(singleNet_strong)
    {
        BatchSize = 1024,
        LearningRate = 0.01,
        Beta1 = 0.9,
        Beta2 = 0.99,
        Eps = 1e-8
    };
    var trainingScheme_strong = new TrainingScheme(singleNet_strong, trainer_strong, entryContainer_strong, "single_strong");

    bool save = true;
    double lastValidationAcc = 0.0;
    double lastValidationAcc_early = 0.0;
    double lastValidationAcc_strong = 0.0;
    double lastTrainAcc = 0.0;
    double lastTrainAcc_early = 0.0;
    double lastTrainAcc_strong = 0.0;

    // Rounds of 50 epochs per network, the three networks training side by side,
    // repeated until a key is pressed
    do
    {
        Task normal = Task.Factory.StartNew(() =>
        {
            for (int epoch = 0; epoch < 50; epoch++)
            {
                if (epoch > 5)
                {
                    trainer.L2Decay = 0.05;
                }

                Console.WriteLine($"[normal] Epoch #{epoch + 1}");

                // Within a 50-epoch round this only triggers on the first epoch
                if (epoch % 50 == 0)
                {
                    trainer.LearningRate = Math.Max(trainer.LearningRate / 5.0, 0.00001);
                }

                trainingScheme.RunEpoch();

                #region Save Nets

                if (save && trainingScheme.ValidationAccuracy > lastValidationAcc)
                {
                    lastValidationAcc = trainingScheme.ValidationAccuracy;
                    lastTrainAcc = trainingScheme.TrainAccuracy;
                    singleNet.SaveNet(NetName);
                }

                #endregion

                if (Console.KeyAvailable)
                {
                    break;
                }
            }
        });

        Task early = Task.Factory.StartNew(() =>
        {
            for (int epoch = 0; epoch < 50; epoch++)
            {
                if (epoch > 5)
                {
                    trainer_early.L2Decay = 0.05;
                }

                Console.WriteLine($"[early] Epoch #{epoch + 1}");

                if (epoch % 50 == 0)
                {
                    trainer_early.LearningRate = Math.Max(trainer_early.LearningRate / 5.0, 0.00001);
                }

                trainingScheme_early.RunEpoch();

                #region Save Nets

                if (save && trainingScheme_early.ValidationAccuracy > lastValidationAcc_early)
                {
                    lastValidationAcc_early = trainingScheme_early.ValidationAccuracy;
                    lastTrainAcc_early = trainingScheme_early.TrainAccuracy;
                    singleNet_early.SaveNet(NetName_early);
                }

                #endregion

                if (Console.KeyAvailable)
                {
                    break;
                }
            }
        });

        Task strong = Task.Factory.StartNew(() =>
        {
            for (int epoch = 0; epoch < 50; epoch++)
            {
                if (epoch > 5)
                {
                    trainer_strong.L2Decay = 0.05;
                }

                Console.WriteLine($"[strong] Epoch #{epoch + 1}");

                if (epoch % 50 == 0)
                {
                    trainer_strong.LearningRate = Math.Max(trainer_strong.LearningRate / 5.0, 0.00001);
                }

                trainingScheme_strong.RunEpoch();

                #region Save Nets

                if (save && trainingScheme_strong.ValidationAccuracy > lastValidationAcc_strong)
                {
                    lastValidationAcc_strong = trainingScheme_strong.ValidationAccuracy;
                    lastTrainAcc_strong = trainingScheme_strong.TrainAccuracy;
                    singleNet_strong.SaveNet(NetName_strong);
                }

                #endregion

                if (Console.KeyAvailable)
                {
                    break;
                }
            }
        });

        Task.WaitAll(normal, strong, early);
    }
    while (!Console.KeyAvailable);

    #endregion
}
/// <summary>
/// Main training module: runs the endless game loop, choosing actions epsilon-greedily
/// and storing every observed transition in the replay memory.
/// The weight-update step itself is not implemented yet (see the end of the loop).
/// </summary>
/// <param name="cfg">Model configurations</param>
/// <param name="gameState">Game State module with access to game environment and dino</param>
/// <param name="justPlay">
/// When true, the observation threshold is set so high that the training branch is never
/// reached and epsilon is fixed to <c>cfg.FinalEpsilon</c>; when false, the threshold is
/// <c>cfg.Observation</c> and epsilon is resumed from the cache.
/// </param>
public void TrainModel(GameModelConfig cfg, GameState gameState, bool justPlay = false)
{
    // Initial variable caching, done only once
    CacheUtils.InitCache(
        ("epsilon", cfg.InitialEpsilon),
        ("time", 0),
        ("D", new Queue<(Volume<double>, int, double, Volume<double>, bool)>()));

    var model = BuildModel(cfg);
    var lastTime = DateTime.Now;

    // Replay memory holding the previous observations, loaded from the cache
    var D = CacheUtils.LoadObj<Queue<(Volume<double>, int, double, Volume<double>, bool)>>("D");

    // Get the first state by doing nothing (0 => do nothing, 1 => jump)
    var do_nothing = new double[cfg.Actions];
    do_nothing[0] = 1;

    // State returned by the environment after performing the action
    var (x_t, r_0, terminal) = gameState.GetState(do_nothing, cfg.InputWidth, cfg.InputHeight);

    // Initial input stack: the first frame repeated 4 times
    var s_t = BuilderInstance.Volume.From(x_t.Repeat(4), new Shape(cfg.InputWidth, cfg.InputHeight, cfg.ImageChannels));
    var initial_state = x_t;

    double observe;
    double epsilon;

    // The stored model replaces the freshly built one
    model = SerializationExtensions.FromJson<double>(File.ReadAllText(modelPath));
    var trainer = new AdamTrainer(model) { LearningRate = cfg.LearningRate };

    if (justPlay)
    {
        observe = 999999999; // We keep observing, never train
        epsilon = cfg.FinalEpsilon;
    }
    else
    {
        // We go to training mode
        observe = cfg.Observation;
        epsilon = CacheUtils.LoadObj<double>("epsilon");
    }

    // Resume from the previous time step stored in the cache
    int t = CacheUtils.LoadObj<int>("time");

    while (true) // endless running
    {
        // loss and Q_sa are placeholders for the training step that is not implemented yet
        double loss = 0;
        double Q_sa = 0;
        int action_index = 0;
        double r_t = 0;                     // reward at t
        var a_t = new double[cfg.Actions];  // action at t

        // Choose an action epsilon-greedily; frames in between are skipped (no action set)
        if (t % cfg.FramePerAction == 0)
        {
            if (_random.NextDouble() <= epsilon)
            {
                // Randomly explore an action
                _logger.LogInformation("----------Random Action----------");
                action_index = _random.Next(cfg.Actions);
                a_t[action_index] = 1;
            }
            else
            {
                // Predict the output from the stack of 4 images
                model.Forward(s_t);
                var q = model.GetPrediction();
                action_index = q[0];        // choosing the index with the maximum q value
                a_t[action_index] = 1;      // 0 => do nothing, 1 => jump
            }
        }

        // Reduce epsilon (exploration parameter) gradually once observation is over
        if (epsilon > cfg.FinalEpsilon && t > observe)
        {
            epsilon -= (cfg.InitialEpsilon - cfg.FinalEpsilon) / cfg.Explore;
        }

        // Run the selected action and observe the next state and reward
        double[] x_t1;
        (x_t1, r_t, terminal) = gameState.GetState(a_t, cfg.InputWidth, cfg.InputHeight);

        // Helpful for measuring the frame rate
        _logger.LogInformation($"fps: { 1 / (DateTime.Now - lastTime).TotalSeconds }");
        lastTime = DateTime.Now;

        // Append the new image to the input stack and remove the first one
        var s_t1 = BuilderInstance.Volume.From(s_t.ToArray().StackAndShift(x_t1), s_t.Shape);

        // Store the transition in D, dropping the oldest one when the memory is full
        D.Enqueue((s_t, action_index, r_t, s_t1, terminal));
        if (D.Count > cfg.ReplayMemory)
        {
            D.Dequeue();
        }

        // Only train if done observing
        if (t > observe)
        {
            //var minibatch
        }

        // Advance to the next time step. Without this the loop keeps acting on the
        // initial frame stack and t never passes the observation threshold, so epsilon
        // never decays and the training branch above is never reached.
        // NOTE(review): the stack is not reset to the initial state after a terminal
        // frame - confirm whether that is wanted.
        s_t = s_t1;
        t++;
    }
}
/// <summary>
/// Loads the network from <paramref name="model"/> (or creates a new one when the file
/// does not exist) and trains it with Adam, one batch per iteration, until a key is
/// pressed. The network is saved every 50 iterations and once more on exit; a test
/// batch is evaluated and logged every 10 iterations.
/// </summary>
/// <param name="model">Path of the network file to load and to save to</param>
/// <param name="trainingSet">Directory whose "rotated" sub-directory holds the training data</param>
/// <param name="testSet">Directory whose "rotated" sub-directory holds the test data</param>
public void Run(string model, string trainingSet, string testSet)
{
#if GPU
    Log.Info("Enabling GPU mode ...");
    BuilderInstance<double>.Volume = new ConvNetSharp.Volume.GPU.Double.VolumeBuilder();
#endif

    Log.Info("Loading model ...");
    var network = File.Exists(model) ? Network.FromFile(model) : Network.CreateNew();

    const int batchSize = 15;
    const int testInterval = 10;
    const int testSize = 10;
    const int saveInterval = 50;

    // An SgdTrainer (LearningRate 0.008, BatchSize batchSize, L2Decay 0.001, Momentum 0.8)
    // was used here previously.
    var trainer = new AdamTrainer<double>(network.Net)
    {
        LearningRate = 0.001,
        Beta1 = 0.9,
        Beta2 = 0.999
    };

    var trainingData = TrainingSet.FromDirectory(Path.Combine(trainingSet, "rotated"), batchSize);
    var testData = TrainingSet.FromDirectory(Path.Combine(testSet, "rotated"), testSize);

    Log.Info("Training ...");
    WriteHeader();

    var run = 0;
    do
    {
        run++;

        // Periodic checkpoint, taken before this iteration's training step
        if (run % saveInterval == 0)
        {
            Save(network, model);
        }

        // Train on one batch
        using (var batch = trainingData.GetBatch())
        {
            trainer.Train(batch.InputVolume, batch.OutputVolume);
        }

        if (run % testInterval != 0)
        {
            continue;
        }

        // Test on one batch and log the results
        using (var testBatch = testData.GetBatch())
        {
            var result = network.Net.Forward(testBatch.InputVolume);
            testBatch.SetResult(result);

            Log.Info($"{trainingData.Epoch}\t\t{run}\t{testBatch.TotalError:0.00}\t{testBatch.MinimumAngle:0.00}\t{testBatch.MaximumAngle:0.00}\t{testBatch.MaxError:0.00}\t{trainer.Loss:0.00000}");
        }
    }
    while (!Console.KeyAvailable);

    Save(network, model);
}