/// <summary>
/// Parses the nodes to determine the validity of the graph and
/// to determine the inputs and outputs of the graph.
/// The resulting <see cref="EntryPointGraph"/> is stored in <c>_graph</c>.
/// </summary>
/// <remarks>
/// Guarded by a check that <c>_graph</c> is still null, so calling this again
/// while a graph is already stored fails that check.
/// </remarks>
public void Compile()
{
    const string repeatedCallMessage = "Multiple calls to " + nameof(Compile) + "() detected.";

    // The guard runs before the nodes are fetched, so a repeated call does no parsing work.
    _env.Check(_graph == null, repeatedCallMessage);

    var parsedNodes = GetNodes();
    _graph = new EntryPointGraph(_env, _catalog, parsedNodes);
}
/// <summary>
/// Creates a runner whose graph is built from a JSON array of nodes.
/// </summary>
/// <param name="env">Environment used to register the runner's host; checked for null.</param>
/// <param name="nodes">JSON node array handed to <see cref="EntryPointGraph"/>; checked for null.</param>
public GraphRunner(IHostEnvironment env, JArray nodes)
{
    Contracts.CheckValue(env, nameof(env));

    // Work against the registered host locally, then publish the fields together.
    var registeredHost = env.Register(RegistrationName);
    registeredHost.CheckValue(nodes, nameof(nodes));

    _host = registeredHost;
    _graph = new EntryPointGraph(registeredHost, nodes);
}
/// <summary>
/// Creates a runner over an already constructed <see cref="EntryPointGraph"/>.
/// </summary>
/// <param name="env">Environment used to register the runner's host; checked for null.</param>
/// <param name="graph">The graph this runner will operate on; checked for null.</param>
public GraphRunner(IHostEnvironment env, EntryPointGraph graph)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(RegistrationName);

    // Reject a null graph at construction time instead of letting it surface
    // later as a null dereference when the runner is first used.
    _host.CheckValue(graph, nameof(graph));
    _graph = graph;
}
/// <summary>
/// Builds a Models.PipelineSweeper entry-point graph from JSON, serializes every node back
/// through ToJson(), runs a <see cref="GraphRunner"/> over the re-serialized nodes on samples
/// of the adult train/test data, and asserts on the number of result rows.
/// </summary>
public void PipelineSweeperSerialization()
{
    // Get datasets
    var pathData = GetDataPath("adult.train");
    var pathDataTest = GetDataPath("adult.test");
    const int numOfSampleRows = 1000;
    // Expected number of result rows; presumably tied to 'FinalHistoryLength' in the
    // graph JSON below (both are 10) - keep them in sync if either changes.
    const int numIterations = 10;
    const string schema = "sep=, col=Features:R4:0,2,4,10-12 col=workclass:TX:1 col=education:TX:3 col=marital_status:TX:5 col=occupation:TX:6 " +
        "col=relationship:TX:7 col=ethnicity:TX:8 col=sex:TX:9 col=native_country:TX:13 col=label_IsOver50K_:R4:14 header=+";
    var inputFileTrain = new SimpleFileHandle(Env, pathData, false, false);
#pragma warning disable 0618
    var datasetTrain = ImportTextData.ImportText(Env,
        new ImportTextData.Input { InputFile = inputFileTrain, CustomSchema = schema }).Data.Take(numOfSampleRows);
    var inputFileTest = new SimpleFileHandle(Env, pathDataTest, false, false);
    var datasetTest = ImportTextData.ImportText(Env,
        new ImportTextData.Input { InputFile = inputFileTest, CustomSchema = schema }).Data.Take(numOfSampleRows);
#pragma warning restore 0618

    // Define entrypoint graph
    string inputGraph = @" { 'Nodes': [ { 'Name': 'Models.PipelineSweeper', 'Inputs': { 'TrainingData': '$TrainingData', 'TestingData': '$TestingData', 'StateArguments': { 'Name': 'AutoMlState', 'Settings': { 'Metric': 'Auc', 'Engine': { 'Name': 'UniformRandom' }, 'TerminatorArgs': { 'Name': 'IterationLimited', 'Settings': { 'FinalHistoryLength': 10 } }, 'TrainerKind': 'SignatureBinaryClassifierTrainer' } }, 'BatchSize': 5 }, 'Outputs': { 'State': '$StateOut', 'Results': '$ResultsOut' } }, ] }";

    JObject graphJson = JObject.Parse(inputGraph);
    var catalog = Env.ComponentCatalog;
    var graph = new EntryPointGraph(Env, catalog, graphJson[FieldNames.Nodes] as JArray);

    // Test if ToJson() works properly.
    var nodes = new JArray(graph.AllNodes.Select(node => node.ToJson()));
    var runner = new GraphRunner(Env, catalog, nodes);
    runner.SetInput("TrainingData", datasetTrain);
    runner.SetInput("TestingData", datasetTest);
    runner.RunAll();

    var results = runner.GetOutput<IDataView>("ResultsOut");
    Assert.NotNull(results);
    var rows = PipelinePattern.ExtractResults(Env, results,
        "Graph", "MetricValue", "PipelineId", "TrainingMetricValue", "FirstInput", "PredictorModel");

    // Assert.Equal reports expected vs. actual counts on failure, unlike Assert.True(a == b).
    Assert.Equal(numIterations, rows.Length);
}
/// <summary>
/// Clears this instance's graph state: sets <c>_graph</c> to null and
/// empties <c>_jsonNodes</c>.
/// </summary>
public void Reset()
{
    // Drop the graph reference first, then discard the collected nodes.
    _graph = null;
    _jsonNodes.Clear();
}