Ejemplo n.º 1
0
        /// <summary>
        /// Builds and trains a fresh ReLU network on randomly generated graphs, running the work via <c>Task.Run</c>.
        /// </summary>
        /// <param name="liveTracking">
        /// When <c>true</c>, a <c>TrainingMonitor</c> created with this instance's <c>Name</c> is attached
        /// in addition to the mean-square-error monitor.
        /// </param>
        /// <returns>
        /// A task whose result bundles the name, the MSE monitor's collected data, the trained network,
        /// the generated graph paths and the samples derived from them.
        /// </returns>
        public Task<TrainingResult> Train(bool liveTracking) => Task.Run(() =>
        {
            int features = NetworkSampleInput.InputSize(NumberOfVertices);
            Random rng = new Random(Seed);

            // `He` is constructed from Seed directly; `rng` is shared by the trainer and the graph
            // generator, so the order in which those two are created/used is kept as-is.
            Network model = NetworkBuilder.Build(
                new NetworkDefinition(FunctionName.ReLU, features, NumberOfVertices, HiddenLayers),
                new He(Seed));
            Optimizer sgd = new SGDMomentum(model, LearningRate, Momentum);
            Trainer trainer = new Trainer(sgd, rng);

            // Each generated graph is mapped through Dijkstra.Find before samples are derived from the results.
            GraphPath[] paths = GraphGenerator
                .Generate(rng, NumberOfGraphs, NumberOfVertices)
                .Select(Dijkstra.Find)
                .ToArray();
            NetworkSample[] samples = NetworkSampleGenerator.Generate(paths).ToArray();

            MeanSquareErrorMonitor mse = new MeanSquareErrorMonitor();
            trainer.Monitors.Add(mse);
            if (liveTracking)
            {
                trainer.Monitors.Add(new TrainingMonitor(Name));
            }

            trainer.Train(samples, Epoches, BatchSize);

            // The final error computation and the summary line are done while holding ThreadLock.
            lock (ThreadLock)
            {
                double finalError = MeanSquareErrorMonitor.CalculateError(model, samples);
                Console.WriteLine($"[{Name}] Seed: {Seed,12}  Momentum: {Momentum,-6}  Learning Rate: {LearningRate,-6}  MSE: {finalError}");
            }

            return new TrainingResult(Name, mse.CollectedData as IEnumerable<double>, model, paths, samples);
        });
Ejemplo n.º 2
0
        /// <summary>
        /// Trains <c>Network</c> towards targets built from <paramref name="batch"/>: for each item,
        /// its <c>Reward</c> plus <c>configuration.Gamma</c> times <c>Predict(item.State).Value</c>,
        /// paired with the item's <c>Input</c>.
        /// </summary>
        /// <param name="batch">History items to build training targets from.</param>
        /// <param name="configuration">Supplies learning rate, momentum, gamma, epoch count and batch size.</param>
        /// <param name="random">Random source handed to the <c>Trainer</c>.</param>
        /// <remarks>
        /// A new <c>SGDMomentum</c> optimizer is created on every call. All targets are materialised with
        /// <c>ToArray</c> before <c>trainer.Train</c> runs, so every <c>Predict</c> call happens before the
        /// training step (presumably against the not-yet-updated network — confirm <c>Predict</c> reads <c>Network</c>).
        /// </remarks>
        public void Fit(IEnumerable<MarkovHistoryItem> batch, AgentNetworkTrainingConfiguration configuration, Random random)
        {
            SGDMomentum sgd = new SGDMomentum(Network, configuration.LearningRate, configuration.Momentum);
            Trainer trainer = new Trainer(sgd, random);

            Func<MarkovHistoryItem, Projection> toTarget = item =>
                new Projection(item.Input, new double[] { item.Reward + configuration.Gamma * Predict(item.State).Value });

            Projection[] targets = batch.Select(toTarget).ToArray();

            trainer.Train(targets, configuration.EpochesPerIteration, configuration.BatchSize);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Creates a new <see cref="TicTacToeValueNetwork"/> and trains it on <c>TrainingData</c>
        /// using an <c>SGDMomentum</c> optimizer.
        /// </summary>
        /// <param name="epoches">Epoch count forwarded to <c>Trainer.Train</c>.</param>
        /// <param name="monitors">Monitors attached to the trainer before training starts.</param>
        /// <returns>The trained value network.</returns>
        public TicTacToeValueNetwork Train(int epoches, params TrainingMonitor[] monitors)
        {
            const int MiniBatchSize = 16;

            // The network and the trainer each receive their own Random, both seeded with Seed.
            TicTacToeValueNetwork valueNetwork = new TicTacToeValueNetwork(HiddenLayerSizes, ActivationFunction, new Random(Seed));
            Trainer trainer = new Trainer(new SGDMomentum(valueNetwork.Network, LearingRate, Momentum), new Random(Seed));

            trainer.Monitors.AddRange(monitors);
            trainer.Train(TrainingData, epoches, MiniBatchSize);

            return valueNetwork;
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Background-worker entry point: builds a network for the selected <c>ActivationFunction</c>,
        /// trains it on data from <c>GenerateTrainingData</c>, and periodically renders it via <c>RefreshBitmap</c>.
        /// </summary>
        /// <param name="sender">The raising <see cref="BackgroundWorker"/>; polled for cancellation between epochs.</param>
        /// <param name="e">Event args; <see cref="DoWorkEventArgs.Cancel"/> is set when training is cancelled.</param>
        /// <remarks>
        /// Cancellation is only observed if the worker was created with <c>WorkerSupportsCancellation</c> and
        /// <c>CancelAsync</c> was called; otherwise <c>CancellationPending</c> stays <c>false</c> and all epochs run.
        /// </remarks>
        private void backgroundWorker1_DoWork(object sender, DoWorkEventArgs e)
        {
            const int BitmapSize      = 200;
            const int EpochCount      = 15000;
            const int RefreshInterval = 100;

            BackgroundWorker worker = sender as BackgroundWorker;

            bitmap = new Bitmap(BitmapSize, BitmapSize);
            Random    random = new Random(0);
            Optimizer optimizer;
            Network   network;

            if (ActivationFunction == Function.ReLU)
            {
                network   = NetworkBuilder.Build(new NetworkDefinition(FunctionName.ReLU, 2, 3, 72, 72, 72, 36, 36, 36, 18, 18), new He(0));
                optimizer = new SGDMomentum(network, 0.001, 0.008);
            }
            else
            {
                network   = NetworkBuilder.Build(new NetworkDefinition(FunctionName.Sigmoidal, 2, 3, 32, 8), new He(0));
                optimizer = new SGDMomentum(network, 0.1, 0.8);
            }

            var trainer    = new Trainer(optimizer, new Random(0));
            var mseMonitor = new MeanSquareErrorMonitor();

            trainer.Monitors.Add(mseMonitor);
            int width;
            int height;

            // Dimensions are read under a lock on the bitmap; presumably other code (e.g. painting)
            // locks on the same object — TODO confirm.
            lock (bitmap)
            {
                width  = bitmap.Width;
                height = bitmap.Height;
            }

            // One normalised (x, y) coordinate pair per pixel, row by row.
            var imagePointList = new List<double[]>(width * height);

            for (int y = 0; y < height; y++)
            {
                for (int x = 0; x < width; x++)
                {
                    imagePointList.Add(new double[] { (double)x / width, (double)y / height });
                }
            }

            trainingData = GenerateTrainingData(random);

            for (int epoch = 0; epoch < EpochCount; epoch++)
            {
                // Stop promptly when the UI requests cancellation instead of grinding through every epoch.
                if (worker != null && worker.CancellationPending)
                {
                    e.Cancel = true;
                    return;
                }

                if (epoch % RefreshInterval == 0)
                {
                    RefreshBitmap(network, width, height, imagePointList);
                }

                trainer.Train(trainingData, 1, 1);
                Console.WriteLine(mseMonitor.CollectedData.Last());
            }

            // The in-loop refresh runs *before* each training step, so the last RefreshInterval epochs
            // would otherwise never be drawn; render the fully trained network once more.
            RefreshBitmap(network, width, height, imagePointList);
        }