/// <summary>
        /// Parses the nodes to determine the validity of the graph and
        /// to determine the inputs and outputs of the graph. The resulting
        /// <c>EntryPointGraph</c> is stored in <c>_graph</c>.
        /// </summary>
        /// <remarks>
        /// This method is intended to be called once: the guard below fails
        /// the environment check when <c>_graph</c> has already been assigned.
        /// </remarks>
        public void Compile()
        {
            // A non-null _graph means a previous call already built the graph.
            _env.Check(_graph == null, $"Multiple calls to {nameof(Compile)}() detected.");

            // Fetch the nodes first, then build the graph from them.
            var graphNodes = GetNodes();
            _graph = new EntryPointGraph(_env, _catalog, graphNodes);
        }
Exemple #2
0
        /// <summary>
        /// Creates a runner by building a new <c>EntryPointGraph</c> from a JSON array of nodes.
        /// </summary>
        /// <param name="env">Environment used to register the host for this runner; checked with <c>Contracts.CheckValue</c>.</param>
        /// <param name="nodes">JSON node definitions passed to the <c>EntryPointGraph</c> constructor; checked with the host's <c>CheckValue</c>.</param>
        public GraphRunner(IHostEnvironment env, JArray nodes)
        {
            // env must be checked through the static Contracts helper because no host exists yet.
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);
            // From here on, validation goes through the registered host.
            _host.CheckValue(nodes, nameof(nodes));

            // The graph is built against the registered host rather than the raw environment.
            _graph = new EntryPointGraph(_host, nodes);
        }
Exemple #3
0
 /// <summary>
 /// Creates a runner over an already constructed <c>EntryPointGraph</c>.
 /// </summary>
 /// <param name="env">Environment used to register the host for this runner; checked with <c>Contracts.CheckValue</c>.</param>
 /// <param name="graph">The graph this runner will operate on; checked with the host's <c>CheckValue</c> before it is stored.</param>
 public GraphRunner(IHostEnvironment env, EntryPointGraph graph)
 {
     // env must be checked through the static Contracts helper because no host exists yet.
     Contracts.CheckValue(env, nameof(env));
     _host = env.Register(RegistrationName);

     // Fail fast on a missing graph instead of storing null and hitting a
     // NullReferenceException later, when the runner first touches _graph.
     _host.CheckValue(graph, nameof(graph));
     _graph = graph;
 }
Exemple #4
0
        public void PipelineSweeperSerialization()
        {
            // Round-trip scenario: build an entry point graph from JSON, serialize
            // every node back to JSON with ToJson(), run a GraphRunner over the
            // re-serialized nodes, and check the number of result rows extracted.
            const int sampleRowCount = 1000;
            const string schema =
                "sep=, col=Features:R4:0,2,4,10-12 col=workclass:TX:1 col=education:TX:3 col=marital_status:TX:5 col=occupation:TX:6 " +
                "col=relationship:TX:7 col=ethnicity:TX:8 col=sex:TX:9 col=native_country:TX:13 col=label_IsOver50K_:R4:14 header=+";
            int expectedRowCount = 10;

            // Locate the train and test files.
            var trainPath = GetDataPath("adult.train");
            var testPath = GetDataPath("adult.test");

            var trainFile = new SimpleFileHandle(Env, trainPath, false, false);

            // Both imports stay inside the 0618 (obsolete member) suppression.
#pragma warning disable 0618
            var trainImportArgs = new ImportTextData.Input
            {
                InputFile = trainFile,
                CustomSchema = schema
            };
            var trainSample = ImportTextData.ImportText(Env, trainImportArgs).Data.Take(sampleRowCount);

            var testFile = new SimpleFileHandle(Env, testPath, false, false);
            var testImportArgs = new ImportTextData.Input
            {
                InputFile = testFile,
                CustomSchema = schema
            };
            var testSample = ImportTextData.ImportText(Env, testImportArgs).Data.Take(sampleRowCount);
#pragma warning restore 0618

            // Entry point graph definition: a single Models.PipelineSweeper node.
            string graphDefinition = @"
                {
                  'Nodes': [
                    {
                      'Name': 'Models.PipelineSweeper',
                      'Inputs': {
                        'TrainingData': '$TrainingData',
                        'TestingData': '$TestingData',
                        'StateArguments': {
                            'Name': 'AutoMlState',
                            'Settings': {
                                'Metric': 'Auc',
                                'Engine': {
                                    'Name': 'UniformRandom'
                                },
                                'TerminatorArgs': {
                                    'Name': 'IterationLimited',
                                    'Settings': {
                                        'FinalHistoryLength': 10
                                    }
                                },
                                'TrainerKind': 'SignatureBinaryClassifierTrainer'
                            }
                        },
                        'BatchSize': 5
                      },
                      'Outputs': {
                        'State': '$StateOut',
                        'Results': '$ResultsOut'
                      }
                    },
                  ]
                }";

            JObject parsedGraph = JObject.Parse(graphDefinition);
            var catalog = Env.ComponentCatalog;
            var declaredNodes = parsedGraph[FieldNames.Nodes] as JArray;
            var graph = new EntryPointGraph(Env, catalog, declaredNodes);

            // Test if ToJson() works properly: the runner is fed the re-serialized
            // nodes rather than the originally parsed ones.
            var reserializedNodes = new JArray(graph.AllNodes.Select(n => n.ToJson()));
            var runner = new GraphRunner(Env, catalog, reserializedNodes);
            runner.SetInput("TrainingData", trainSample);
            runner.SetInput("TestingData", testSample);
            runner.RunAll();

            var sweepResults = runner.GetOutput<IDataView>("ResultsOut");
            Assert.NotNull(sweepResults);

            var resultRows = PipelinePattern.ExtractResults(
                Env,
                sweepResults,
                "Graph",
                "MetricValue",
                "PipelineId",
                "TrainingMetricValue",
                "FirstInput",
                "PredictorModel");
            Assert.True(resultRows.Length == expectedRowCount);
        }
 /// <summary>
 /// Discards the current graph reference and empties the collected JSON nodes.
 /// </summary>
 public void Reset()
 {
     // Drop the graph first; _jsonNodes is cleared in place rather than replaced.
     _graph = null;
     _jsonNodes.Clear();
 }