/// <summary>
/// Builds a network, trains it on generated graphs whose paths come from
/// <c>Dijkstra.Find</c>, and reports the final mean square error on the console.
/// The whole run is queued through <c>Task.Run</c>.
/// </summary>
/// <param name="liveTracking">
/// When true, a <c>TrainingMonitor</c> created with <c>Name</c> is attached to the
/// trainer in addition to the mean-square-error monitor.
/// </param>
/// <returns>
/// A task producing a <c>TrainingResult</c> built from <c>Name</c>, the data collected by the
/// mean-square-error monitor, the trained network, the training graphs and the training samples.
/// </returns>
public Task<TrainingResult> Train(bool liveTracking)
{
    return Task.Run(() =>
    {
        int inputWidth = NetworkSampleInput.InputSize(NumberOfVertices);

        // One Random instance, seeded from Seed, is shared by the trainer and the graph generator.
        Random rng = new Random(Seed);

        NetworkDefinition definition = new NetworkDefinition(
            FunctionName.ReLU,
            inputWidth,
            NumberOfVertices,
            HiddenLayers);
        Network net = NetworkBuilder.Build(definition, new He(Seed));

        Optimizer sgd = new SGDMomentum(net, LearningRate, Momentum);
        Trainer teacher = new Trainer(sgd, rng);

        // Generate the graphs, run Dijkstra.Find on each, then derive the samples from the results.
        GraphPath[] solvedGraphs = GraphGenerator
            .Generate(rng, NumberOfGraphs, NumberOfVertices)
            .Select(Dijkstra.Find)
            .ToArray();
        NetworkSample[] samples = NetworkSampleGenerator.Generate(solvedGraphs).ToArray();

        MeanSquareErrorMonitor errorMonitor = new MeanSquareErrorMonitor();
        teacher.Monitors.Add(errorMonitor);

        if (liveTracking)
        {
            teacher.Monitors.Add(new TrainingMonitor(Name));
        }

        teacher.Train(samples, Epoches, BatchSize);

        // The final error is computed and printed while holding ThreadLock.
        lock (ThreadLock)
        {
            double error = MeanSquareErrorMonitor.CalculateError(net, samples);
            Console.WriteLine($"[{Name}] Seed: {Seed,12} Momentum: {Momentum,-6} Learning Rate: {LearningRate,-6} MSE: {error}");
        }

        // NOTE(review): `as` yields null if CollectedData is not an IEnumerable<double>;
        // confirm that TrainingResult tolerates a null history.
        IEnumerable<double> errorHistory = errorMonitor.CollectedData as IEnumerable<double>;

        return new TrainingResult(Name, errorHistory, net, solvedGraphs, samples);
    });
}
/// <summary>
/// Updates the caption and rebuilds the painters for the graph at
/// <c>selectedGraphIndex</c>, then redraws the control.
/// </summary>
/// <remarks>
/// Two paths are prepared for painting: the one stored in the selected <c>GraphPath</c>,
/// and one obtained by repeatedly evaluating <c>Network</c> starting from vertex 0.
/// </remarks>
protected void OnSelectedGraphIndexChanged()
{
    Text = $"{selectedGraphIndex + 1}/{Graphs.Length}";

    GraphPath selected = Graphs[selectedGraphIndex];
    Graph graph = selected.Graph;

    graphPainter = new GraphPainter(selected);
    // Scale by the smaller of the two client-area dimensions.
    graphPainter.Scale = Math.Min(ClientSize.Width, ClientSize.Height);
    dijkstraPathPainter = new GraphPathPainter(graphPainter, selected.Path, dijkstraPathPen);

    NetworkSampleInput networkInput = new NetworkSampleInput(graph, 0);

    int vertexCount = graph.Vertices.Length;
    int lastVertex = vertexCount - 1;

    bool[] visited = new bool[vertexCount];
    visited[0] = true;

    int vertex = 0;
    List<int> route = new List<int>();
    route.Add(vertex);

    // Follow the network's predictions one step at a time. The walk stops as soon as it
    // lands on a vertex that was already visited or on the last vertex; that final
    // vertex is still appended to the route.
    while (true)
    {
        visited[vertex] = true;
        networkInput.SetCurrentVertex(vertex);

        double[] scores = Network.Evaluate(networkInput.Values);
        // NOTE(review): the 0 passed to IndexOfMax is presumably a start index — confirm.
        vertex = scores.IndexOfMax(0);
        route.Add(vertex);

        if (visited[vertex] || vertex == lastVertex)
        {
            break;
        }
    }

    networkPathPainter = new GraphPathPainter(graphPainter, route.ToArray(), networkPathPen);
    Refresh();
}
/// <summary>
/// Creates a <c>NetworkSample</c> for one vertex of the given graph path.
/// </summary>
/// <param name="graphPath">Graph path whose graph supplies the input values and the output size.</param>
/// <param name="currentVertex">Vertex passed to <c>NetworkSampleInput</c> and stored in the sample.</param>
/// <param name="label">Index of the single output element that is set to 1; all others stay 0.</param>
/// <returns>
/// A sample holding the input values and an output array with one element per graph vertex.
/// </returns>
public static NetworkSample Create(GraphPath graphPath, int currentVertex, int label)
{
    var graph = graphPath.Graph;
    var sampleInput = new NetworkSampleInput(graph, currentVertex);

    // Expected output: zero everywhere except at the labelled index.
    var expected = new double[graph.Vertices.Length];
    expected[label] = 1;

    return new NetworkSample(graphPath, currentVertex, sampleInput.Values, expected);
}