示例#1
0
文件: Training.cs 项目: fdafadf/main
        /// <summary>
        /// Runs one complete, seeded training session on the thread pool.
        /// </summary>
        /// <remarks>
        /// Builds a ReLU network whose input size is derived from <c>NumberOfVertices</c>, generates
        /// <c>NumberOfGraphs</c> random graphs, solves each one with <c>Dijkstra.Find</c>, turns the solutions
        /// into training samples and trains on them with an <c>SGDMomentum</c> optimizer. The mean-square
        /// error of the trained network over those samples is printed to the console while holding <c>ThreadLock</c>.
        /// </remarks>
        /// <param name="liveTracking">When <c>true</c>, a <c>TrainingMonitor</c> for <c>Name</c> is attached in addition to the MSE monitor.</param>
        /// <returns>
        /// A task yielding a <c>TrainingResult</c> holding the name, the data collected by the
        /// <c>MeanSquareErrorMonitor</c>, the trained network, the solved graphs and the samples built from them.
        /// </returns>
        public Task <TrainingResult> Train(bool liveTracking)
        {
            return Task.Run(() =>
            {
                int sampleInputSize = NetworkSampleInput.InputSize(NumberOfVertices);
                var seededRandom    = new Random(Seed);

                var definition = new NetworkDefinition(FunctionName.ReLU, sampleInputSize, NumberOfVertices, HiddenLayers);
                Network trainedNetwork = NetworkBuilder.Build(definition, new He(Seed));

                Optimizer sgd = new SGDMomentum(trainedNetwork, LearningRate, Momentum);
                var trainer   = new Trainer(sgd, seededRandom);

                // The same seeded Random instance is handed to the trainer and used for graph generation,
                // so the statement order here is significant for reproducibility.
                GraphPath[] solvedGraphs = GraphGenerator
                    .Generate(seededRandom, NumberOfGraphs, NumberOfVertices)
                    .Select(Dijkstra.Find)
                    .ToArray();
                NetworkSample[] samples = NetworkSampleGenerator.Generate(solvedGraphs).ToArray();

                var mseMonitor = new MeanSquareErrorMonitor();
                trainer.Monitors.Add(mseMonitor);

                if (liveTracking)
                {
                    trainer.Monitors.Add(new TrainingMonitor(Name));
                }

                trainer.Train(samples, Epoches, BatchSize);

                // Evaluate and report while holding ThreadLock; presumably shared with other concurrently
                // running trainings so their reports do not interleave — confirm against callers.
                lock (ThreadLock)
                {
                    double error = MeanSquareErrorMonitor.CalculateError(trainedNetwork, samples);
                    Console.WriteLine($"[{Name}] Seed: {Seed,12}  Momentum: {Momentum,-6}  Learning Rate: {LearningRate,-6}  MSE: {error}");
                }

                IEnumerable <double> errorHistory = mseMonitor.CollectedData as IEnumerable <double>;
                return new TrainingResult(Name, errorHistory, trainedNetwork, solvedGraphs, samples);
            });
        }
示例#2
0
        /// <summary>
        /// <see cref="BackgroundWorker.DoWork"/> handler: builds a network for the selected <c>ActivationFunction</c>,
        /// generates <c>trainingData</c>, and trains on it one epoch at a time (batch size 1). Every
        /// <c>RefreshInterval</c> epochs the <c>bitmap</c> field is redrawn via <c>RefreshBitmap</c>, and the latest
        /// value collected by the MSE monitor is printed after each epoch. The bitmap is redrawn once more when
        /// training completes, so the final network state is always shown.
        /// </summary>
        /// <param name="sender">
        /// The raising <see cref="BackgroundWorker"/>. When it is one, its <see cref="BackgroundWorker.CancellationPending"/>
        /// flag is checked before every epoch.
        /// </param>
        /// <param name="e">Event data; <see cref="DoWorkEventArgs.Cancel"/> is set to <c>true</c> if training is cancelled.</param>
        private void backgroundWorker1_DoWork(object sender, DoWorkEventArgs e)
        {
            const int BitmapSize      = 200;
            const int EpochCount      = 15000;
            const int RefreshInterval = 100;

            // Null when invoked directly rather than by a BackgroundWorker; cancellation is then not checked.
            var worker = sender as BackgroundWorker;

            bitmap = new Bitmap(BitmapSize, BitmapSize);
            Random    random = new Random(0);
            Optimizer optimizer;
            Network   network;

            if (ActivationFunction == Function.ReLU)
            {
                network   = NetworkBuilder.Build(new NetworkDefinition(FunctionName.ReLU, 2, 3, 72, 72, 72, 36, 36, 36, 18, 18), new He(0));
                optimizer = new SGDMomentum(network, 0.001, 0.008);
            }
            else
            {
                network   = NetworkBuilder.Build(new NetworkDefinition(FunctionName.Sigmoidal, 2, 3, 32, 8), new He(0));
                optimizer = new SGDMomentum(network, 0.1, 0.8);
            }

            var trainer    = new Trainer(optimizer, new Random(0));
            var mseMonitor = new MeanSquareErrorMonitor();

            trainer.Monitors.Add(mseMonitor);
            int width;
            int height;

            // Read the dimensions under the bitmap's lock; presumably other code (e.g. RefreshBitmap or painting)
            // also locks on it while using the bitmap — confirm.
            lock (bitmap)
            {
                width  = bitmap.Width;
                height = bitmap.Height;
            }

            // Normalised (x / width, y / height) coordinate of every pixel in row-major order,
            // handed to RefreshBitmap together with the network.
            var imagePointList = new List <double[]>(width * height);

            for (int y = 0; y < height; y++)
            {
                for (int x = 0; x < width; x++)
                {
                    imagePointList.Add(new double[] { (double)x / width, (double)y / height });
                }
            }

            trainingData = GenerateTrainingData(random);

            for (int epoch = 0; epoch < EpochCount; epoch++)
            {
                // Honour BackgroundWorker.CancelAsync; without this check the loop cannot be stopped early.
                if (worker != null && worker.CancellationPending)
                {
                    e.Cancel = true;
                    return;
                }

                if (epoch % RefreshInterval == 0)
                {
                    RefreshBitmap(network, width, height, imagePointList);
                }

                trainer.Train(trainingData, 1, 1);
                Console.WriteLine(mseMonitor.CollectedData.Last());
            }

            // The in-loop refresh happens before that epoch's training, so without this final redraw
            // the effect of the last RefreshInterval epochs would never be displayed.
            RefreshBitmap(network, width, height, imagePointList);
        }