Example #1
0
        /// <summary>
        /// Runs a short AutoML regression experiment over the Metro Interstate traffic-volume CSV
        /// and prints the best trainer together with its validation metrics.
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        static void Main(string[] args)
        {
            MLContext mlContext = new MLContext();

            // The CSV has a header row and comma-separated fields, mapped onto TrafficData.
            IDataView trainDataView = mlContext.Data.LoadFromTextFile<TrafficData>(
                GetAbsolutePath("../../../Data/Metro_Interstate_Traffic_Volume.csv"),
                hasHeader: true,
                separatorChar: ',');

            // The token source lets a key press stop the experiment early.
            // It is intentionally not disposed here: the key-press helper may still call
            // Cancel() on it from another thread after Main has finished its work.
            var cts = new CancellationTokenSource();

            // Configure the experiment: 10-second budget, cancellable via cts,
            // optimize mean squared error, no cache directory (explicitly null).
            var experimentSettings = new RegressionExperimentSettings
            {
                MaxExperimentTimeInSeconds = 10,
                CancellationToken          = cts.Token,
                OptimizingMetric           = RegressionMetric.MeanSquaredError,
                CacheDirectory             = null
            };

            // Cancel experiment after the user presses any key.
            CancelExperimentAfterAnyKeyPress(cts);

            // Create and execute the experiment; the column named "Label" is the regression target.
            RegressionExperiment experiment = mlContext.Auto().CreateRegressionExperiment(experimentSettings);
            var handler = new RegressionExperimentProgressHandler();
            ExperimentResult<RegressionMetrics> experimentResult =
                experiment.Execute(trainDataView, labelColumnName: "Label", progressHandler: handler);

            // Report the validation metrics of the best run found.
            RegressionMetrics metrics = experimentResult.BestRun.ValidationMetrics;

            Console.WriteLine($"Best Algorithm: {experimentResult.BestRun.TrainerName}");
            Console.WriteLine($"R-Squared: {metrics.RSquared:0.##}");
            Console.WriteLine($"Root Mean Squared Error: {metrics.RootMeanSquaredError:0.##}");

            // Keep the console open until a key is pressed.
            Console.ReadKey();
        }
Example #2
0
        /// <summary>
        /// Infers the schema of the training file and adjusts the inferred column purposes.
        /// Runs a time-boxed AutoML regression experiment that predicts the "next" column.
        /// Prints predictions of the best model over the training data, then refits the best
        /// pipeline on the training data and saves it to MODEL_FILEPATH.
        /// </summary>
        public void Start()
        {
            // Infer columns and load train data.
            var columnInferenceResult = mlContext.Auto().InferColumns(
                path: TRAIN_DATA_FILEPATH,
                labelColumnName: "next",
                groupColumns: false);

            TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInferenceResult.TextLoaderOptions);
            trainData = textLoader.Load(TRAIN_DATA_FILEPATH);

            // Modify the inferred column information.
            columnInformation = columnInferenceResult.ColumnInformation;

            // Treat productId and year as categorical rather than numeric features.
            MoveColumn("productId", columnInformation.CategoricalColumnNames);
            MoveColumn("year", columnInformation.CategoricalColumnNames);

            // Exclude units from the features entirely.
            MoveColumn("units", columnInformation.IgnoredColumnNames);

            var experimentSettings = new RegressionExperimentSettings()
            {
                MaxExperimentTimeInSeconds = 10,
                OptimizingMetric           = RegressionMetric.RootMeanSquaredError,
                CacheDirectory             = new DirectoryInfo(CACHE_DIRECTORY),
                CancellationToken          = cancelationTokenSource.Token
            };

            // Exclude trainers from the experiment.
            experimentSettings.Trainers.Remove(RegressionTrainer.Ols);

            RegressionExperiment experiment = mlContext.Auto().CreateRegressionExperiment(experimentSettings);
            ExperimentResult<RegressionMetrics> experimentResult = experiment.Execute(
                trainData: trainData,
                columnInformation: columnInformation,
                progressHandler: new RegressionProgressHandler(),
                preFeaturizer: null);

            ITransformer model = experimentResult.BestRun.Model;
            IEstimator<ITransformer> estimator = experimentResult.BestRun.Estimator;

            // Make batch predictions. These are computed on the training data itself,
            // so they are not an unbiased estimate of model quality.
            IDataView predictionsDataView = model.Transform(trainData);

            PrintPredictions(predictionsDataView);
            PrintPredictionsEnumerable(predictionsDataView);

            // Refit the best pipeline on all of trainData before saving.
            // NOTE(review): presumably BestRun.Model was trained on an internal split — confirm.
            ITransformer refitModel = estimator.Fit(trainData);
            mlContext.Model.Save(refitModel, trainData.Schema, MODEL_FILEPATH);
            Console.WriteLine("Done");

            // Reassigns columnName to exactly one purpose list of columnInformation.
            // It is first removed from every purpose list, not only NumericColumnNames, and then
            // added to target once. This keeps the name from being duplicated in target, or listed
            // under two purposes when inference classified it differently than expected. AutoML's
            // column-information validation rejects such duplicates.
            void MoveColumn(string columnName, System.Collections.Generic.ICollection<string> target)
            {
                columnInformation.NumericColumnNames.Remove(columnName);
                columnInformation.CategoricalColumnNames.Remove(columnName);
                columnInformation.TextColumnNames.Remove(columnName);
                columnInformation.IgnoredColumnNames.Remove(columnName);
                target.Add(columnName);
            }
        }