/// <summary>
/// Creates a node and any other connected nodes from a serialised execution graph
/// </summary>
/// <param name="factory">Factory used to instantiate each serialised node</param>
/// <param name="graph">Serialised graph</param>
/// <returns>The node created from <c>graph.InputNode</c>, with the graph's wires attached to the created nodes</returns>
/// <exception cref="KeyNotFoundException">A wire references a node id that does not exist in the graph</exception>
public static INode CreateFrom(this GraphFactory factory, Models.ExecutionGraph graph)
{
    // create the input node
    var nodeTable = new Dictionary<string, INode>();
    var inputNode = factory.Create(graph.InputNode);
    nodeTable.Add(inputNode.Id, inputNode);

    // create the other nodes
    // NOTE(review): guarded because a serialiser may leave an empty collection as null - confirm for ExecutionGraph
    if (graph.OtherNodes != null)
    {
        foreach (var node in graph.OtherNodes)
        {
            var created = factory.Create(node);

            // the first node seen with a given id wins; later duplicates are discarded
            if (!nodeTable.ContainsKey(created.Id))
                nodeTable.Add(created.Id, created);
        }
    }

    // let each node know it has been deserialised and access to the entire graph
    foreach (var item in nodeTable)
        item.Value.OnDeserialise(nodeTable);

    // create the wires between nodes
    if (graph.Wires != null)
    {
        foreach (var wire in graph.Wires)
        {
            INode from, to;
            if (!nodeTable.TryGetValue(wire.FromId, out from))
                throw new KeyNotFoundException("Wire references unknown source node id '" + wire.FromId + "'");
            if (!nodeTable.TryGetValue(wire.ToId, out to))
                throw new KeyNotFoundException("Wire references unknown target node id '" + wire.ToId + "'");

            from.Output.Add(new WireToNode(to, wire.InputChannel));
        }
    }

    return inputNode;
}
public void TestRecurrent()
{
    var data = BinaryIntegers.Addition(100, false).Split(0);
    var graph = new GraphFactory(_lap);
    var errorMetric = graph.ErrorMetric.BinaryClassification;
    graph.CurrentPropertySet
        .Use(graph.GradientDescent.Adam)
        .Use(graph.GaussianWeightInitialisation(false, 0.1f, GaussianVarianceCalibration.SquareRoot2N));

    // create the engine
    var trainingData = graph.CreateDataSource(data.Training);
    var testData = trainingData.CloneWith(data.Test);
    var engine = graph.CreateTrainingEngine(trainingData, learningRate: 0.01f, batchSize: 16);

    // build the network
    const int HiddenLayerSize = 32, TrainingIterations = 5;

    // initial recurrent memory: HiddenLayerSize zeros
    var memory = new float[HiddenLayerSize];

    // the builder chain is invoked only for its effect on the engine; its return value was never used
    graph.Connect(engine)
        .AddSimpleRecurrent(graph.ReluActivation(), memory)
        .AddFeedForward(engine.DataSource.OutputSize)
        .Add(graph.ReluActivation())
        .AddBackpropagationThroughTime(errorMetric);

    // train the network for TrainingIterations iterations, keeping the graph passed to each training callback
    BrightWire.Models.ExecutionGraph bestGraph = null;
    engine.Train(TrainingIterations, testData, errorMetric, bn => bestGraph = bn.Graph);

    // export the graph (falling back to the engine's current graph if no callback fired)
    // and run it against some unseen integers
    var executionEngine = graph.CreateEngine(bestGraph ?? engine.Graph);
    var testData2 = graph.CreateDataSource(BinaryIntegers.Addition(8, true));
    var results = executionEngine.Execute(testData2);

    // NOTE(review): results is never asserted on, so this test only checks that training and
    // execution complete without throwing. TODO: add an assertion on results.
}
/// <summary>
/// Trains <paramref name="engine"/> against the test data of <paramref name="dataset"/> and returns
/// the execution graph supplied to the most recent training callback.
/// </summary>
/// <param name="engine">Training engine to run</param>
/// <param name="config">Supplies the iteration count and the error metric</param>
/// <param name="dataset">Supplies the test data passed to the engine</param>
/// <param name="outputModelPath">
/// When not null/empty/whitespace, every model passed to the callback is serialised to this path,
/// overwriting any existing file.
/// </param>
/// <returns>The graph from the last callback invocation, or null if the callback was never invoked</returns>
public static ExecutionGraph TrainModel(IGraphTrainingEngine engine, NetworkConfig config, DataSet dataset, string outputModelPath)
{
    bool persistEachModel = !string.IsNullOrWhiteSpace(outputModelPath);
    BrightWire.Models.ExecutionGraph latestGraph = null;

    engine.Train(config.TRAINING_ITERATIONS, dataset.TestData, config.ERROR_METRIC, improvedModel =>
    {
        latestGraph = improvedModel.Graph;
        if (!persistEachModel)
            return;

        using (var stream = new FileStream(outputModelPath, FileMode.Create, FileAccess.Write))
        {
            Serializer.Serialize(stream, improvedModel);
        }
    });

    return latestGraph;
}