/// <summary>
/// Trains a graph for a fixed number of iterations, evaluating it against test data once before
/// training starts and then after every <paramref name="testCadence"/> iterations
/// </summary>
/// <param name="engine">The graph training engine</param>
/// <param name="numIterations">The number of iterations to train for</param>
/// <param name="testData">The test data source to use</param>
/// <param name="errorMetric">The error metric to evaluate the test data against</param>
/// <param name="onImprovement">Optional callback for when the test data score has improved against the error metric</param>
/// <param name="testCadence">Determines how many epochs elapse before the test data is evaluated; must be at least 1</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="engine"/> is null</exception>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="testCadence"/> is less than 1</exception>
public static void Train(this IGraphTrainingEngine engine, int numIterations, IDataSource testData, IErrorMetric errorMetric, Action<GraphModel> onImprovement = null, int testCadence = 1)
{
    if (engine == null)
        throw new ArgumentNullException(nameof(engine));

    // A cadence below 1 can never be reached by the epoch counter below, which would silently
    // disable test evaluation and the onImprovement callback for the whole run
    if (testCadence < 1)
        throw new ArgumentOutOfRangeException(nameof(testCadence), testCadence, "Test cadence must be at least 1");

    var executionContext = new ExecutionContext(engine.LinearAlgebraProvider);

    // Initial evaluation before any training takes place
    engine.Test(testData, errorMetric, 128, percentage => Console.Write("\rTesting... ({0:P}) ", percentage));

    var epochsSinceTest = 0;
    for (var i = 0; i < numIterations; i++)
    {
        engine.Train(executionContext, percentage => Console.Write("\rTraining... ({0:P}) ", percentage));

        if (++epochsSinceTest < testCadence)
            continue;

        // Test is always run on cadence; the callback fires only when Test returns true
        if (engine.Test(testData, errorMetric, 128, percentage => Console.Write("\rTesting... ({0:P}) ", percentage)) && onImprovement != null)
        {
            var bestModel = new GraphModel { Graph = engine.Graph };

            // Adaptive data sources carry their own model, which is captured alongside the graph
            if (engine.DataSource is IAdaptiveDataSource adaptiveDataSource)
                bestModel.DataSource = adaptiveDataSource.GetModel();

            onImprovement(bestModel);
        }
        epochsSinceTest = 0;
    }
}
/// <summary>
/// Creates a wire builder attached to one of the training engine's input nodes. When the engine's
/// data source is an <c>IVolumeDataSource</c>, its width, height and depth are copied onto the builder.
/// </summary>
/// <param name="factory">Graph factory</param>
/// <param name="engine">Graph training engine supplying the input node and data source</param>
/// <param name="inputIndex">Index of the engine input to attach to</param>
public WireBuilder(GraphFactory factory, IGraphTrainingEngine engine, int inputIndex = 0)
    : this(factory, engine.DataSource.InputSize, engine.GetInput(inputIndex))
{
    var volume = engine.DataSource as IVolumeDataSource;
    if (volume == null)
        return;

    _width = volume.Width;
    _height = volume.Height;
    _depth = volume.Depth;
}
/// <summary>
/// Reads a serialised <c>GraphModel</c> from disk and creates a training engine for its graph
/// using the dataset's training data and the configured learning rate and batch size.
/// </summary>
/// <param name="path">Path of an existing model file (opened with FileMode.Open, so it must exist)</param>
/// <param name="graph">Graph factory used to create the training engine</param>
/// <param name="config">Supplies LEARNING_RATE and BATCH_SIZE</param>
/// <param name="dataset">Supplies the training data source</param>
/// <returns>The newly created training engine</returns>
public static IGraphTrainingEngine LoadTrainingNetwork(string path, GraphFactory graph, NetworkConfig config, DataSet dataset)
{
    using (var stream = new FileStream(path, FileMode.Open, FileAccess.Read))
    {
        var saved = Serializer.Deserialize<GraphModel>(stream);
        return graph.CreateTrainingEngine(dataset.TrainData, saved.Graph, config.LEARNING_RATE, config.BATCH_SIZE);
    }
}
/// <summary>
/// Trains the engine for config.TRAINING_ITERATIONS iterations against the dataset's test data.
/// It remembers the graph from each reported improvement and, if a path is given, writes that
/// improved model to disk.
/// </summary>
/// <param name="engine">Graph training engine</param>
/// <param name="config">Supplies TRAINING_ITERATIONS and ERROR_METRIC</param>
/// <param name="dataset">Supplies the test data used for evaluation</param>
/// <param name="outputModelPath">File that receives each improved model (overwritten every time); null or whitespace disables saving</param>
/// <returns>The graph from the most recent improvement, or null if no improvement was reported</returns>
public static ExecutionGraph TrainModel(IGraphTrainingEngine engine, NetworkConfig config, DataSet dataset, string outputModelPath)
{
    BrightWire.Models.ExecutionGraph latestBest = null;
    var persistModels = !string.IsNullOrWhiteSpace(outputModelPath);

    engine.Train(config.TRAINING_ITERATIONS, dataset.TestData, config.ERROR_METRIC, improved =>
    {
        latestBest = improved.Graph;
        if (!persistModels)
            return;

        using (var output = new FileStream(outputModelPath, FileMode.Create, FileAccess.Write))
        {
            Serializer.Serialize(output, improved);
        }
    });

    return latestBest;
}
/// <summary>
/// Attaches a convolutional classifier to the engine's input.
/// The network is two convolution + leaky-ReLU + max-pooling stages, then a transpose, a hidden
/// feed-forward layer with dropout, and a softmax output sized to the training data's output size.
/// </summary>
/// <param name="engine">Graph training engine whose input the network is connected to</param>
/// <param name="graph">Graph factory used to create the nodes</param>
/// <param name="config">Supplies HIDDEN_LAYER_SIZE and ERROR_METRIC</param>
/// <param name="dataset">Supplies the output size of the final layer</param>
/// <returns>The same graph factory that was passed in</returns>
public static GraphFactory CreateStandardNetwork(IGraphTrainingEngine engine, GraphFactory graph, NetworkConfig config, DataSet dataset)
{
    // Feature extraction; the first convolution is created with shouldBackpropagate: false
    var features = graph.Connect(engine)
        .AddConvolutional(filterCount: 16, padding: 2, filterWidth: 5, filterHeight: 5, stride: 1, shouldBackpropagate: false)
        .Add(graph.LeakyReluActivation())
        .AddMaxPooling(filterWidth: 2, filterHeight: 2, stride: 2)
        .AddConvolutional(filterCount: 32, padding: 2, filterWidth: 5, filterHeight: 5, stride: 1)
        .Add(graph.LeakyReluActivation())
        .AddMaxPooling(filterWidth: 2, filterHeight: 2, stride: 2);

    // Classification head and training objective
    features
        .Transpose()
        .AddFeedForward(config.HIDDEN_LAYER_SIZE)
        .Add(graph.LeakyReluActivation())
        .AddDropOut(dropOutPercentage: 0.5f)
        .AddFeedForward(dataset.TrainData.OutputSize)
        .Add(graph.SoftMaxActivation())
        .AddBackpropagation(config.ERROR_METRIC);

    return graph;
}
/// <summary>
/// Builds a new wire from one of the engine's input nodes
/// </summary>
/// <param name="engine">Graph engine to build with</param>
/// <param name="inputIndex">Input index to connect to</param>
/// <returns>A wire builder attached to the engine input at <paramref name="inputIndex"/></returns>
public WireBuilder Connect(IGraphTrainingEngine engine, int inputIndex = 0)
{
    // inputIndex must be forwarded explicitly; otherwise the WireBuilder default (0) is used
    // regardless of which input the caller asked for
    return new WireBuilder(this, engine, inputIndex);
}