/// <summary>
        /// Trains a graph for a fixed number of iterations, evaluating it against test data once before
        /// training starts and then after every <paramref name="testCadence"/> iterations
        /// </summary>
        /// <param name="engine">The graph training engine</param>
        /// <param name="numIterations">The number of iterations to train for</param>
        /// <param name="testData">The test data source to use</param>
        /// <param name="errorMetric">The error metric to evaluate the test data against</param>
        /// <param name="onImprovement">Optional callback for when the test data score has improved against the error metric; receives a model holding the engine's current graph and, when the engine's data source is adaptive, its data source model</param>
        /// <param name="testCadence">Determines how many iterations elapse between test data evaluations; must be at least 1</param>
        /// <exception cref="ArgumentNullException"><paramref name="engine"/> is null</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="testCadence"/> is less than 1</exception>
        public static void Train(this IGraphTrainingEngine engine, int numIterations, IDataSource testData, IErrorMetric errorMetric, Action <GraphModel> onImprovement = null, int testCadence = 1)
        {
            if (engine == null)
            {
                throw new ArgumentNullException(nameof(engine));
            }
            // A cadence below 1 can never be reached by the iteration counter, which would
            // silently disable every evaluation after the initial one.
            if (testCadence < 1)
            {
                throw new ArgumentOutOfRangeException(nameof(testCadence), testCadence, "The test cadence must be at least 1");
            }

            var executionContext = new ExecutionContext(engine.LinearAlgebraProvider);

            // Evaluate once before any training; the result is intentionally ignored here
            engine.Test(testData, errorMetric, 128, percentage => Console.Write("\rTesting... ({0:P})    ", percentage));

            for (var i = 0; i < numIterations; i++)
            {
                engine.Train(executionContext, percentage => Console.Write("\rTraining... ({0:P})    ", percentage));

                // Only evaluate on every testCadence-th completed iteration
                if ((i + 1) % testCadence != 0)
                {
                    continue;
                }

                // The test always runs on cadence; the callback fires only when the test reports an improvement
                var improved = engine.Test(testData, errorMetric, 128, percentage => Console.Write("\rTesting... ({0:P})    ", percentage));
                if (improved && onImprovement != null)
                {
                    var bestModel = new GraphModel
                    {
                        Graph = engine.Graph
                    };
                    if (engine.DataSource is IAdaptiveDataSource adaptiveDataSource)
                    {
                        bestModel.DataSource = adaptiveDataSource.GetModel();
                    }
                    onImprovement(bestModel);
                }
            }
        }
Esempio n. 2
0
 /// <summary>
 /// Creates a wire builder that connects new nodes to the engine's input node selected by
 /// <paramref name="inputIndex"/>, sized by the engine data source's input size. When the engine's
 /// data source is a volume data source, its width, height and depth are also captured.
 /// </summary>
 /// <param name="factory">Graph factory</param>
 /// <param name="engine">Graph engine whose input node is the starting point of the wire</param>
 /// <param name="inputIndex">Index of the engine input to connect to</param>
 public WireBuilder(GraphFactory factory, IGraphTrainingEngine engine, int inputIndex = 0) :
     this(factory, engine.DataSource.InputSize, engine.GetInput(inputIndex))
 {
     // Volume (e.g. image-like) sources carry their spatial dimensions; record them for later layers
     if (engine.DataSource is IVolumeDataSource volumeDataSource)
     {
         _width  = volumeDataSource.Width;
         _height = volumeDataSource.Height;
         _depth  = volumeDataSource.Depth;
     }
 }
Esempio n. 3
0
        /// <summary>
        /// Loads a serialized graph model from disk and creates a training engine for its graph
        /// </summary>
        /// <param name="path">Path of the serialized <see cref="GraphModel"/> file to open</param>
        /// <param name="graph">Graph factory used to create the training engine</param>
        /// <param name="config">Network configuration supplying the learning rate and batch size</param>
        /// <param name="dataset">Data set whose training data the engine is created against</param>
        /// <returns>A training engine built from the deserialized graph</returns>
        public static IGraphTrainingEngine LoadTrainingNetwork(string path, GraphFactory graph, NetworkConfig config,
                                                               DataSet dataset)
        {
            GraphModel model;

            // Keep the file open only for deserialization; creating the engine does not use the stream
            using (var file = new FileStream(path, FileMode.Open, FileAccess.Read))
            {
                model = Serializer.Deserialize<GraphModel>(file);
            }

            return graph.CreateTrainingEngine(dataset.TrainData, model.Graph, config.LEARNING_RATE, config.BATCH_SIZE);
        }
Esempio n. 4
0
 /// <summary>
 /// Trains the engine for the configured number of iterations, remembering the graph reported by each
 /// improvement and, when an output path is given, overwriting that file with the improved model
 /// </summary>
 /// <param name="engine">Training engine to run</param>
 /// <param name="config">Network configuration supplying the iteration count and error metric</param>
 /// <param name="dataset">Data set whose test data is used for evaluation</param>
 /// <param name="outputModelPath">File to write each improved model to; nothing is written when null, empty or whitespace</param>
 /// <returns>The graph from the last improvement callback, or null if no improvement was reported</returns>
 public static ExecutionGraph TrainModel(IGraphTrainingEngine engine, NetworkConfig config, DataSet dataset, string outputModelPath)
 {
     bool persistToDisk = !string.IsNullOrWhiteSpace(outputModelPath);
     ExecutionGraph latestBest = null;

     Action<GraphModel> recordImprovement = improved =>
     {
         latestBest = improved.Graph;
         if (!persistToDisk)
         {
             return;
         }

         // FileMode.Create truncates any model saved by an earlier improvement
         using (var stream = new FileStream(outputModelPath, FileMode.Create, FileAccess.Write))
         {
             Serializer.Serialize(stream, improved);
         }
     };

     engine.Train(config.TRAINING_ITERATIONS, dataset.TestData, config.ERROR_METRIC, recordImprovement);
     return latestBest;
 }
Esempio n. 5
0
 /// <summary>
 /// Wires a convolutional classifier onto the engine's input: two convolution / leaky-ReLU / max-pooling
 /// stages, then a feed-forward hidden layer with dropout and a softmax output layer trained via backpropagation
 /// </summary>
 /// <param name="engine">Training engine whose input the network is connected to</param>
 /// <param name="graph">Graph factory used to create the nodes; returned unchanged as the result</param>
 /// <param name="config">Network configuration supplying the hidden layer size and error metric</param>
 /// <param name="dataset">Data set whose training output size sets the width of the output layer</param>
 /// <returns>The same <paramref name="graph"/> instance that was passed in</returns>
 public static GraphFactory CreateStandardNetwork(IGraphTrainingEngine engine, GraphFactory graph, NetworkConfig config,
                                                  DataSet dataset)
 {
     // Feature extraction. The first convolution is created with shouldBackpropagate: false,
     // presumably because there is nothing trainable upstream of it.
     var features = graph.Connect(engine)
         .AddConvolutional(filterCount: 16, padding: 2, filterWidth: 5, filterHeight: 5, stride: 1, shouldBackpropagate: false)
         .Add(graph.LeakyReluActivation())
         .AddMaxPooling(filterWidth: 2, filterHeight: 2, stride: 2)
         .AddConvolutional(filterCount: 32, padding: 2, filterWidth: 5, filterHeight: 5, stride: 1)
         .Add(graph.LeakyReluActivation())
         .AddMaxPooling(filterWidth: 2, filterHeight: 2, stride: 2);

     // Classifier head: hidden layer with 50% dropout, then one softmax output per training class
     features
         .Transpose()
         .AddFeedForward(config.HIDDEN_LAYER_SIZE)
         .Add(graph.LeakyReluActivation())
         .AddDropOut(dropOutPercentage: 0.5f)
         .AddFeedForward(dataset.TrainData.OutputSize)
         .Add(graph.SoftMaxActivation())
         .AddBackpropagation(config.ERROR_METRIC);

     return graph;
 }
 /// <summary>
 /// Builds a new wire from the engine's input node
 /// </summary>
 /// <param name="engine">Graph engine to build with</param>
 /// <param name="inputIndex">Input index to connect to</param>
 /// <returns>A wire builder starting at the engine input selected by <paramref name="inputIndex"/></returns>
 public WireBuilder Connect(IGraphTrainingEngine engine, int inputIndex = 0)
 {
     // Forward the requested input index so inputs other than the first can be connected
     return new WireBuilder(this, engine, inputIndex);
 }