/// <summary>
        /// Creates a node and any other connected nodes from a serialised execution graph
        /// </summary>
        /// <remarks>
        /// Nodes are keyed by their <c>Id</c>. The input node is registered first, so if another
        /// serialised node shares an id with an already-registered node the later one is discarded.
        /// Every node is notified via <c>OnDeserialise</c> (with the full node table) before any
        /// wires are attached.
        /// </remarks>
        /// <param name="factory">Factory used to instantiate each serialised node</param>
        /// <param name="graph">Serialised graph</param>
        /// <returns>The node created from <c>graph.InputNode</c>, with the graph's wires attached</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="factory"/> or <paramref name="graph"/> is null</exception>
        /// <exception cref="KeyNotFoundException">A wire references a node id that is not part of the graph</exception>
        public static INode CreateFrom(this GraphFactory factory, Models.ExecutionGraph graph)
        {
            if (factory == null)
            {
                throw new System.ArgumentNullException("factory");
            }
            if (graph == null)
            {
                throw new System.ArgumentNullException("graph");
            }

            // create the input node; registering it first means it wins any id collision
            var nodeTable = new Dictionary<string, INode>();
            var ret       = factory.Create(graph.InputNode);
            nodeTable.Add(ret.Id, ret);

            // create the other nodes
            // NOTE: a serialiser (e.g. protobuf-net) may materialise an empty collection as null
            if (graph.OtherNodes != null)
            {
                foreach (var node in graph.OtherNodes)
                {
                    var n = factory.Create(node);
                    if (!nodeTable.ContainsKey(n.Id))
                    {
                        nodeTable.Add(n.Id, n);
                    }
                }
            }

            // let each node know it has been deserialised and give it access to the entire graph
            foreach (var item in nodeTable)
            {
                item.Value.OnDeserialise(nodeTable);
            }

            // create the wires between nodes
            if (graph.Wires != null)
            {
                foreach (var wire in graph.Wires)
                {
                    INode from, to;
                    if (!nodeTable.TryGetValue(wire.FromId, out from))
                    {
                        throw new KeyNotFoundException("Wire source node not found in graph: " + wire.FromId);
                    }
                    if (!nodeTable.TryGetValue(wire.ToId, out to))
                    {
                        throw new KeyNotFoundException("Wire target node not found in graph: " + wire.ToId);
                    }
                    from.Output.Add(new WireToNode(to, wire.InputChannel));
                }
            }
            return ret;
        }
        public void TestRecurrent()
        {
            var data        = BinaryIntegers.Addition(100, false).Split(0);
            var graph       = new GraphFactory(_lap);
            var errorMetric = graph.ErrorMetric.BinaryClassification;

            graph.CurrentPropertySet
                .Use(graph.GradientDescent.Adam)
                .Use(graph.GaussianWeightInitialisation(false, 0.1f, GaussianVarianceCalibration.SquareRoot2N));

            // create the engine
            var trainingData = graph.CreateDataSource(data.Training);
            var testData     = trainingData.CloneWith(data.Test);
            var engine       = graph.CreateTrainingEngine(trainingData, learningRate: 0.01f, batchSize: 16);

            // build the network: simple recurrent layer -> feed forward -> relu, trained with backpropagation through time
            // NOTE(review): the length of `memory` presumably sets the recurrent layer's hidden size - confirm
            const int HiddenLayerSize = 32, TrainingIterations = 5;
            var memory = new float[HiddenLayerSize];
            graph.Connect(engine)
                .AddSimpleRecurrent(graph.ReluActivation(), memory)
                .AddFeedForward(engine.DataSource.OutputSize)
                .Add(graph.ReluActivation())
                .AddBackpropagationThroughTime(errorMetric);

            // train the network for TrainingIterations (five) iterations, keeping the graph from each improvement
            BrightWire.Models.ExecutionGraph bestGraph = null;
            engine.Train(TrainingIterations, testData, errorMetric, bn => bestGraph = bn.Graph);

            // export the best graph (or the final one if no improvement was recorded) and run it on unseen integers
            var executionEngine = graph.CreateEngine(bestGraph ?? engine.Graph);
            var testData2       = graph.CreateDataSource(BinaryIntegers.Addition(8, true));
            executionEngine.Execute(testData2);
            // TODO(review): this is only a smoke test - the execution results are not asserted against expected sums
        }
// Example #3
 /// <summary>
 /// Trains <paramref name="engine"/> and returns the best execution graph reported during training.
 /// </summary>
 /// <remarks>
 /// Each time the engine reports an improvement the improved model becomes the candidate result and,
 /// when <paramref name="outputModelPath"/> is set, is written to that path with <c>Serializer.Serialize</c>.
 /// The file is opened with <see cref="FileMode.Create"/>, so it only ever holds the latest improvement.
 /// If no improvement is ever reported, the engine's current graph is returned instead of null.
 /// </remarks>
 /// <param name="engine">Training engine to run</param>
 /// <param name="config">Supplies the iteration count and error metric</param>
 /// <param name="dataset">Supplies the test data used to evaluate each iteration</param>
 /// <param name="outputModelPath">File to save each improved model to; null/blank disables saving</param>
 /// <returns>The best graph seen during training, or <c>engine.Graph</c> if none was reported</returns>
 /// <exception cref="ArgumentNullException"><paramref name="engine"/>, <paramref name="config"/> or <paramref name="dataset"/> is null</exception>
 public static ExecutionGraph TrainModel(IGraphTrainingEngine engine, NetworkConfig config, DataSet dataset, string outputModelPath)
 {
     if (engine == null)
     {
         throw new ArgumentNullException("engine");
     }
     if (config == null)
     {
         throw new ArgumentNullException("config");
     }
     if (dataset == null)
     {
         throw new ArgumentNullException("dataset");
     }

     // the path cannot change while training runs, so decide once whether to persist improvements
     var saveModel = !string.IsNullOrWhiteSpace(outputModelPath);

     BrightWire.Models.ExecutionGraph bestGraph = null;
     engine.Train(config.TRAINING_ITERATIONS, dataset.TestData, config.ERROR_METRIC, model =>
     {
         bestGraph = model.Graph;
         if (saveModel)
         {
             // FileMode.Create truncates any previously saved model
             using (var file = new FileStream(outputModelPath, FileMode.Create, FileAccess.Write))
             {
                 Serializer.Serialize(file, model);
             }
         }
     });

     // match TestRecurrent: fall back to the engine's current graph if no improvement was reported
     return bestGraph ?? engine.Graph;
 }