/// <summary>
/// Builds an <see cref="OlsTrainer"/> from a fully specified <see cref="OlsTrainer.Options"/> instance.
/// The trainer fits a linear regression model that predicts a target value.
/// </summary>
/// <param name="catalog">The <see cref="RegressionCatalog"/>.</param>
/// <param name="options">Algorithm advanced options. See <see cref="OlsTrainer.Options"/>.</param>
/// <returns>A new <see cref="OlsTrainer"/> configured with <paramref name="options"/>.</returns>
/// <remarks>
/// Both <paramref name="catalog"/> and <paramref name="options"/> are validated via
/// <c>Contracts.CheckValue</c> (catalog first) before the environment is obtained.
/// </remarks>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[OrdinaryLeastSquares](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs)]
/// ]]>
/// </format>
/// </example>
public static OlsTrainer Ols(
    this RegressionCatalog.RegressionTrainers catalog,
    OlsTrainer.Options options)
{
    Contracts.CheckValue(catalog, nameof(catalog));
    Contracts.CheckValue(options, nameof(options));

    // The host environment travels with the catalog; the trainer needs it for
    // its own checks and logging.
    return new OlsTrainer(CatalogUtils.GetEnvironment(catalog), options);
}
/// <summary>
/// Builds an <see cref="OlsTrainer"/> that fits a linear regression model, configured only by
/// column names.
/// </summary>
/// <param name="catalog">The <see cref="RegressionCatalog"/>.</param>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="featureColumnName">The name of the feature column.</param>
/// <param name="exampleWeightColumnName">The name of the example weight column (optional; <c>null</c> means no weight column is set).</param>
/// <returns>A new <see cref="OlsTrainer"/>.</returns>
/// <remarks>
/// Only the three column names are assigned on the underlying <see cref="OlsTrainer.Options"/>;
/// every other setting keeps the value that <see cref="OlsTrainer.Options"/> is constructed with.
/// Use the <see cref="OlsTrainer.Options"/> overload to control the remaining settings.
/// </remarks>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[OrdinaryLeastSquares](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquares.cs)]
/// ]]>
/// </format>
/// </example>
public static OlsTrainer Ols(this RegressionCatalog.RegressionTrainers catalog,
    string labelColumnName = DefaultColumnNames.Label,
    string featureColumnName = DefaultColumnNames.Features,
    string exampleWeightColumnName = null)
{
    Contracts.CheckValue(catalog, nameof(catalog));
    var environment = CatalogUtils.GetEnvironment(catalog);

    var trainerOptions = new OlsTrainer.Options
    {
        LabelColumnName = labelColumnName,
        FeatureColumnName = featureColumnName,
        ExampleWeightColumnName = exampleWeightColumnName,
    };

    return new OlsTrainer(environment, trainerOptions);
}
public static void Example()
{
    // MLContext is the entry point for ML.NET operations. It handles exception
    // tracking and logging, exposes the catalogs of available operations, and
    // is the source of randomness. The seed is fixed so that the output shown
    // below is reproducible.
    var mlContext = new MLContext(seed: 0);

    // Generate 1000 synthetic training points. Wrap them in an IDataView,
    // the data abstraction that the ML.NET API consumes.
    var trainingPoints = GenerateRandomDataPoints(1000);
    var trainingData = mlContext.Data.LoadFromEnumerable(trainingPoints);

    // Advanced trainer settings.
    var trainerOptions = new OlsTrainer.Options
    {
        LabelColumnName = nameof(DataPoint.Label),
        FeatureColumnName = nameof(DataPoint.Features),
        // Larger values lead to smaller (closer to zero) model parameters.
        L2Regularization = 0.1f,
        // Skip computing the standard error and other statistics of the
        // model parameters.
        CalculateStatistics = false
    };

    // Build the trainer and fit it to the training data.
    var trainer = mlContext.Regression.Trainers.Ols(trainerOptions);
    var trainedModel = trainer.Fit(trainingData);

    // Score 5 fresh points. The different seed (123) keeps them distinct
    // from the training set.
    var testData = mlContext.Data.LoadFromEnumerable(
        GenerateRandomDataPoints(5, seed: 123));
    var scoredTestData = trainedModel.Transform(testData);

    // Materialize the scored rows so each prediction can be printed next to
    // its actual label.
    var predictions = mlContext.Data
        .CreateEnumerable<Prediction>(scoredTestData, reuseRowObject: false)
        .ToList();

    foreach (var row in predictions)
    {
        Console.WriteLine($"Label: {row.Label:F3}, Prediction: {row.Score:F3}");
    }

    // Expected output:
    //   Label: 0.985, Prediction: 0.960
    //   Label: 0.155, Prediction: 0.075
    //   Label: 0.515, Prediction: 0.456
    //   Label: 0.566, Prediction: 0.499
    //   Label: 0.096, Prediction: 0.080

    // Report aggregate regression metrics over the scored test set.
    PrintMetrics(mlContext.Regression.Evaluate(scoredTestData));

    // Expected output:
    //   Mean Absolute Error: 0.05
    //   Mean Squared Error: 0.00
    //   Root Mean Squared Error: 0.06
    //   RSquared: 0.97 (closer to 1 is better. The worst case is 0)
}