예제 #1
0
파일: Training.cs 프로젝트: fdafadf/main
        /// <summary>
        /// Builds a network from this instance's settings and trains it on randomly generated graphs.
        /// The whole job runs on the thread pool via <see cref="Task.Run{TResult}(Func{TResult})"/>.
        /// </summary>
        /// <param name="liveTracking">
        /// When <c>true</c>, a <c>TrainingMonitor</c> created with this instance's <c>Name</c> is attached
        /// in addition to the mean-squared-error monitor.
        /// </param>
        /// <returns>
        /// A task whose result bundles the name, the data collected by the MSE monitor, the trained network,
        /// and the graphs and samples it was trained on.
        /// </returns>
        public Task <TrainingResult> Train(bool liveTracking) => Task.Run(() =>
        {
            int inputWidth = NetworkSampleInput.InputSize(NumberOfVertices);
            // A single generator seeded with Seed is shared by the trainer and the graph generator.
            Random rng = new Random(Seed);

            var definition = new NetworkDefinition(FunctionName.ReLU, inputWidth, NumberOfVertices, HiddenLayers);
            Network net = NetworkBuilder.Build(definition, new He(Seed));
            Optimizer sgd = new SGDMomentum(net, LearningRate, Momentum);
            Trainer trainer = new Trainer(sgd, rng);

            // Each generated graph is mapped through Dijkstra.Find to a GraphPath.
            // The paths are then expanded into network samples.
            GraphPath[] graphs = GraphGenerator
                .Generate(rng, NumberOfGraphs, NumberOfVertices)
                .Select(Dijkstra.Find)
                .ToArray();
            NetworkSample[] samples = NetworkSampleGenerator.Generate(graphs).ToArray();

            var mse = new MeanSquareErrorMonitor();
            trainer.Monitors.Add(mse);
            if (liveTracking)
            {
                trainer.Monitors.Add(new TrainingMonitor(Name));
            }

            trainer.Train(samples, Epoches, BatchSize);

            // The final error is computed and printed while holding ThreadLock.
            // Presumably this keeps concurrently running trainings from interleaving their console output.
            lock (ThreadLock)
            {
                double finalError = MeanSquareErrorMonitor.CalculateError(net, samples);
                Console.WriteLine($"[{Name}] Seed: {Seed,12}  Momentum: {Momentum,-6}  Learning Rate: {LearningRate,-6}  MSE: {finalError}");
            }

            // NOTE(review): `as` yields null if CollectedData is not an IEnumerable<double>.
            // Confirm the monitor's collection type.
            return new TrainingResult(Name, mse.CollectedData as IEnumerable <double>, net, graphs, samples);
        });
예제 #2
0
파일: Form1.cs 프로젝트: fdafadf/main
        /// <summary>
        /// Refreshes the view after <c>selectedGraphIndex</c> changes. It performs these steps:
        /// <list type="bullet">
        /// <item>sets the title to "index/count";</item>
        /// <item>rebuilds the graph painter and the painter for the selected path's <c>Path</c>
        /// (drawn with <c>dijkstraPathPen</c>);</item>
        /// <item>traces the route the network predicts starting from vertex 0, drawn with <c>networkPathPen</c>;</item>
        /// <item>repaints.</item>
        /// </list>
        /// </summary>
        protected void OnSelectedGraphIndexChanged()
        {
            Text = $"{selectedGraphIndex + 1}/{Graphs.Length}";
            GraphPath selected = Graphs[selectedGraphIndex];
            Graph g = selected.Graph;

            graphPainter = new GraphPainter(selected) { Scale = Math.Min(ClientSize.Width, ClientSize.Height) };
            dijkstraPathPainter = new GraphPathPainter(graphPainter, selected.Path, dijkstraPathPen);

            var input = new NetworkSampleInput(g, 0);
            var visited = new bool[g.Vertices.Length];
            var route = new List <int> { 0 };
            int vertex = 0;

            // Walk the graph by repeatedly asking the network for the next vertex.
            // Each step marks the current vertex as visited. The walk stops once it returns to a visited
            // vertex or reaches the last vertex, so it takes at most one step per vertex.
            // NOTE(review): IndexOfMax(0) is assumed to return the index of the highest score.
            // Confirm what the 0 argument means.
            while (true)
            {
                visited[vertex] = true;
                input.SetCurrentVertex(vertex);
                double[] scores = Network.Evaluate(input.Values);
                vertex = scores.IndexOfMax(0);
                route.Add(vertex);

                if (visited[vertex] || vertex == g.Vertices.Length - 1)
                {
                    break;
                }
            }

            networkPathPainter = new GraphPathPainter(graphPainter, route.ToArray(), networkPathPen);
            Refresh();
        }
예제 #3
0
        /// <summary>
        /// Builds a <c>NetworkSample</c> for <paramref name="graphPath"/> at <paramref name="currentVertex"/>.
        /// The sample pairs the network input for that vertex with a one-hot target vector.
        /// </summary>
        /// <param name="graphPath">Graph and path the sample belongs to.</param>
        /// <param name="currentVertex">Vertex the network input is built for.</param>
        /// <param name="label">
        /// Index of the single target slot set to 1. The target has one slot per vertex of the graph, so an
        /// out-of-range value throws <see cref="IndexOutOfRangeException"/>.
        /// </param>
        /// <returns>The sample wrapping the path, the vertex, the input values and the target vector.</returns>
        public static NetworkSample Create(GraphPath graphPath, int currentVertex, int label)
        {
            var graph = graphPath.Graph;
            var input = new NetworkSampleInput(graph, currentVertex);

            // One-hot target: zeros everywhere except the label slot.
            var expected = new double[graph.Vertices.Length];
            expected[label] = 1;

            return new NetworkSample(graphPath, currentVertex, input.Values, expected);
        }