/// <summary>
/// Runs a single image through a previously saved model and returns the value
/// computed by <c>GetLargestPercent</c> for the first execution result.
/// </summary>
/// <param name="modelPath">Path of the saved model, passed to <c>LoadTestingNetwork</c>.</param>
/// <param name="testImage">Image converted into a test data set by <c>BuildTestSet</c>.</param>
/// <returns>
/// <c>GetLargestPercent(output[0])</c> — presumably the index of the highest-scoring class; confirm
/// against <c>GetLargestPercent</c>.
/// </returns>
public static int TestItems(string modelPath, Bitmap testImage)
{
    using (var lap = BrightWireProvider.CreateLinearAlgebra(false))
    {
        var graph = new GraphFactory(lap);
        DataSet testDataset = BuildTestSet(graph, testImage);

        // Only inference is performed here, so no error metric / NetworkConfig is needed
        // (the previous version built both and never used them).
        var engine = LoadTestingNetwork(modelPath, graph);
        var executionEngine = graph.CreateEngine(engine.Graph);
        var output = executionEngine.Execute(testDataset.TestData);

        // Only the first result is inspected; the test set is built from a single image.
        return GetLargestPercent(output[0]);
    }
}
/// <summary>
/// Trains <paramref name="engine"/> for <c>config.TRAINING_ITERATIONS</c> iterations, scoring against
/// <c>dataset.TestData</c> with <c>config.ERROR_METRIC</c>.
/// </summary>
/// <remarks>
/// Every time the training callback fires (presumably when the test score improves — confirm against
/// BrightWire's <c>Train</c>), the model's graph is remembered. If <paramref name="outputModelPath"/>
/// is not blank, the model is also serialized to that path, overwriting any earlier file.
/// </remarks>
/// <returns>The graph of the last model reported by the callback, or <c>null</c> if it never fired.</returns>
public static ExecutionGraph TrainModel(IGraphTrainingEngine engine, NetworkConfig config, DataSet dataset, string outputModelPath)
{
    ExecutionGraph latestGraph = null;
    bool persistModel = !String.IsNullOrWhiteSpace(outputModelPath);

    engine.Train(config.TRAINING_ITERATIONS, dataset.TestData, config.ERROR_METRIC, improvedModel =>
    {
        latestGraph = improvedModel.Graph;
        if (!persistModel)
        {
            return;
        }

        using (var stream = new FileStream(outputModelPath, FileMode.Create, FileAccess.Write))
        {
            Serializer.Serialize(stream, improvedModel);
        }
    });

    return latestGraph;
}
/// <summary>
/// Builds (or resumes) the convolutional network, trains it on the data under
/// <paramref name="dataFolderPath"/>, and reports its average error on the test data.
/// </summary>
/// <param name="dataFolderPath">Folder passed to <c>CreateDataset</c>.</param>
/// <param name="outputModelPath">Optional path for loading/saving the model; blank disables persistence.</param>
/// <returns>The mean of <c>CalculateError</c> (one-hot encoding metric) over all test execution results.</returns>
public static float TrainCNN(string dataFolderPath, string outputModelPath)
{
    using (var lap = BrightWireProvider.CreateLinearAlgebra(false))
    {
        var graph = new GraphFactory(lap);
        var dataset = CreateDataset(graph, dataFolderPath);
        var metric = graph.ErrorMetric.OneHotEncoding;
        var config = new NetworkConfig { ERROR_METRIC = metric };

        var trainingEngine = BuildNetwork(config, graph, dataset, outputModelPath);
        var trainedGraph = TrainModel(trainingEngine, config, dataset, outputModelPath);

        // If training never reported a model, evaluate the engine's current graph instead.
        var evaluator = graph.CreateEngine(trainedGraph ?? trainingEngine.Graph);
        var results = evaluator.Execute(dataset.TestData);
        return results.Average(result => result.CalculateError(metric));
    }
}
/// <summary>
/// Configures the graph's property set (Adam + Gaussian weight initialisation) and returns a
/// training engine. The engine is either resumed from <paramref name="outputModelPath"/>, when that
/// file exists, or wired with the standard CNN built by <c>CreateStandardNetwork</c>.
/// </summary>
/// <param name="config">Hyper-parameters (learning rate, batch size, initialisation settings).</param>
/// <param name="graph">Factory whose current property set is modified in place.</param>
/// <param name="dataset">Provides <c>TrainData</c> for the engine.</param>
/// <param name="outputModelPath">Optional saved-model path; ignored when blank or missing.</param>
/// <returns>A training engine with a learning-rate halving scheduled.</returns>
public static IGraphTrainingEngine BuildNetwork(NetworkConfig config, GraphFactory graph, DataSet dataset, string outputModelPath = null)
{
    // First argument to ScheduleLearningRate — presumably the epoch at which the halved rate applies.
    const int learningRateDecayEpoch = 15;

    // Set before creating any engine so both code paths below pick up these settings.
    graph.CurrentPropertySet
        .Use(graph.GradientDescent.Adam)
        .Use(graph.GaussianWeightInitialisation(config.ZERO_BIAS, config.STANDARD_DEVIATION, config.GAUSSIAN_VARIANCE_CALIBRATION));

    IGraphTrainingEngine engine;
    if (!String.IsNullOrWhiteSpace(outputModelPath) && File.Exists(outputModelPath))
    {
        // Resume from disk; no fresh engine is created (previously one was built and discarded).
        engine = LoadTrainingNetwork(outputModelPath, graph, config, dataset);
    }
    else
    {
        engine = graph.CreateTrainingEngine(dataset.TrainData, config.LEARNING_RATE, config.BATCH_SIZE);
        CreateStandardNetwork(engine, graph, config, dataset);
    }

    engine.LearningContext.ScheduleLearningRate(learningRateDecayEpoch, config.LEARNING_RATE / 2);
    return engine;
}
/// <summary>
/// Deserializes a saved <c>GraphModel</c> from <paramref name="path"/> and creates a training
/// engine that continues from its graph.
/// </summary>
/// <param name="path">File previously written with <c>Serializer.Serialize</c>.</param>
/// <param name="graph">Factory used to create the training engine.</param>
/// <param name="config">Supplies the learning rate and batch size.</param>
/// <param name="dataset">Supplies <c>TrainData</c>.</param>
/// <returns>The training engine built from the deserialized graph.</returns>
public static IGraphTrainingEngine LoadTrainingNetwork(string path, GraphFactory graph, NetworkConfig config, DataSet dataset)
{
    using (var stream = new FileStream(path, FileMode.Open, FileAccess.Read))
    {
        var savedModel = Serializer.Deserialize<GraphModel>(stream);
        return graph.CreateTrainingEngine(dataset.TrainData, savedModel.Graph, config.LEARNING_RATE, config.BATCH_SIZE);
    }
}
/// <summary>
/// Wires the default CNN onto <paramref name="engine"/>: two convolution + LeakyReLU + max-pool
/// stages, then a hidden feed-forward layer with dropout and a softmax output layer, trained with
/// <c>config.ERROR_METRIC</c>.
/// </summary>
/// <param name="engine">Training engine the layers are connected to.</param>
/// <param name="graph">Factory used to build layers and activations.</param>
/// <param name="config">Supplies <c>HIDDEN_LAYER_SIZE</c> and <c>ERROR_METRIC</c>.</param>
/// <param name="dataset">Supplies the output layer width via <c>TrainData.OutputSize</c>.</param>
/// <returns>The same <paramref name="graph"/> instance that was passed in.</returns>
public static GraphFactory CreateStandardNetwork(IGraphTrainingEngine engine, GraphFactory graph, NetworkConfig config, DataSet dataset)
{
    // Settings shared by both convolution/pooling stages.
    const int kernelSize = 5;
    const int convPadding = 2;
    const int convStride = 1;
    const int poolSize = 2;
    const int poolStride = 2;
    const float dropOutRate = 0.5f;

    graph.Connect(engine)
        // Stage 1: 16 filters. shouldBackpropagate is false — presumably because no earlier
        // trainable layer needs an error signal from here; confirm against BrightWire docs.
        .AddConvolutional(filterCount: 16, padding: convPadding, filterWidth: kernelSize, filterHeight: kernelSize, stride: convStride, shouldBackpropagate: false)
        .Add(graph.LeakyReluActivation())
        .AddMaxPooling(filterWidth: poolSize, filterHeight: poolSize, stride: poolStride)
        // Stage 2: 32 filters.
        .AddConvolutional(filterCount: 32, padding: convPadding, filterWidth: kernelSize, filterHeight: kernelSize, stride: convStride)
        .Add(graph.LeakyReluActivation())
        .AddMaxPooling(filterWidth: poolSize, filterHeight: poolSize, stride: poolStride)
        // Presumably reshapes the convolutional output for the feed-forward layers below.
        .Transpose()
        // Classifier head.
        .AddFeedForward(config.HIDDEN_LAYER_SIZE)
        .Add(graph.LeakyReluActivation())
        .AddDropOut(dropOutPercentage: dropOutRate)
        .AddFeedForward(dataset.TrainData.OutputSize)
        .Add(graph.SoftMaxActivation())
        .AddBackpropagation(config.ERROR_METRIC);

    return graph;
}