/// <summary>
/// FastTree <see cref="RankingCatalog"/>.
/// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the <see cref="FastTreeRankingTrainer"/>.
/// </summary>
/// <param name="catalog">The <see cref="RankingCatalog"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features column.</param>
/// <param name="groupId">The groupId column.</param>
/// <param name="weights">The optional weights column.</param>
/// <param name="options">Algorithm advanced settings. Must not be null; its column names are
/// overwritten below with the names resolved by the estimator reconciler.</param>
/// <param name="onFit">A delegate that is called every time the
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
/// the linear model that was trained. Note that this action cannot change the result in any way;
/// it is only a way for the caller to be informed about what was learnt.</param>
/// <returns>The Score output column indicating the predicted value.</returns>
public static Scalar<float> FastTree<TVal>(this RankingCatalog.RankingTrainers catalog,
    Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights,
    FastTreeRankingTrainer.Options options,
    Action<FastTreeRankingModelParameters> onFit = null)
{
    // Fail fast on a null options object. The previous CheckValueOrNull allowed null through,
    // which would only surface later as a NullReferenceException inside the reconciler
    // delegate (below) when the column names are assigned at Fit time.
    Contracts.CheckValue(options, nameof(options));
    CheckUserValues(label, features, weights, onFit);

    var rec = new TrainerEstimatorReconciler.Ranker<TVal>(
        (env, labelName, featuresName, groupIdName, weightsName) =>
        {
            // The reconciler supplies the actual column names; propagate them into the
            // user-provided options before constructing the trainer.
            options.LabelColumnName = labelName;
            options.FeatureColumnName = featuresName;
            options.RowGroupColumnName = groupIdName;
            options.ExampleWeightColumnName = weightsName;

            var trainer = new FastTreeRankingTrainer(env, options);
            if (onFit != null)
                return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
            return trainer;
        }, label, features, groupId, weights);

    return rec.Score;
}
/// <summary>
/// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the <see cref="FastTreeRankingTrainer"/>.
/// </summary>
/// <param name="ctx">The <see cref="RankingContext"/>.</param>
/// <param name="options">Algorithm advanced settings. Must not be null.</param>
/// <returns>The configured <see cref="FastTreeRankingTrainer"/>.</returns>
public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainers ctx,
    FastTreeRankingTrainer.Options options)
{
    Contracts.CheckValue(ctx, nameof(ctx));
    // Validate options here so the caller gets a clear argument error instead of a
    // failure from deep inside the trainer constructor.
    Contracts.CheckValue(options, nameof(options));
    var env = CatalogUtils.GetEnvironment(ctx);
    return new FastTreeRankingTrainer(env, options);
}
public void TestFastTreeRankingFeaturizationInPipeline()
{
    // Build a small in-memory dataset and cache it, since the tree trainer makes
    // multiple passes over the data.
    const int sampleCount = 200;
    var samples = SamplesUtils.DatasetUtils
        .GenerateFloatLabelFloatFeatureVectorSamples(sampleCount)
        .ToList();
    var cachedData = ML.Data.Cache(ML.Data.LoadFromEnumerable(samples));

    // Configure the underlying FastTree ranking model used by the featurizer.
    var treeTrainerOptions = new FastTreeRankingTrainer.Options
    {
        NumberOfThreads = 1,
        NumberOfTrees = 10,
        NumberOfLeaves = 4,
        MinimumExampleCountPerLeaf = 10,
        FeatureColumnName = "Features",
        LabelColumnName = "Label"
    };

    // The featurizer emits three tree-derived representations of the input vector.
    var featurizerOptions = new FastTreeRankingFeaturizationEstimator.Options()
    {
        InputColumnName = "Features",
        TreesColumnName = "Trees",
        LeavesColumnName = "Leaves",
        PathsColumnName = "Paths",
        TrainerOptions = treeTrainerOptions
    };

    // Featurize, concatenate the original and tree-based features, then train a
    // linear regressor on the combined vector.
    var featurize = ML.Transforms.FeaturizeByFastTreeRanking(featurizerOptions);
    var concatenate = ML.Transforms.Concatenate(
        "CombinedFeatures", "Features", "Trees", "Leaves", "Paths");
    var regress = ML.Regression.Trainers.Sdca("Label", "CombinedFeatures");
    var pipeline = featurize.Append(concatenate).Append(regress);

    var trainedModel = pipeline.Fit(cachedData);
    var scoredData = trainedModel.Transform(cachedData);

    // The combined features should yield a reasonably accurate regression fit.
    var evaluation = ML.Regression.Evaluate(scoredData);
    Assert.True(evaluation.MeanAbsoluteError < 0.25);
    Assert.True(evaluation.MeanSquaredError < 0.1);
}
// This example requires installation of additional NuGet package
// <a href="https://www.nuget.org/packages/Microsoft.ML.FastTree/">Microsoft.ML.FastTree</a>.
//
// Demonstrates FeaturizeByFastTreeRanking: trains a small FastTree ranking model and
// uses it to map each input feature vector to three tree-derived representations
// (per-tree output values, one-hot leaf IDs, and 0-1 decision-path indicators),
// then prints the transformation of the first three data points.
public static void Example()
{
    // Create a new context for ML.NET operations. It can be used for
    // exception tracking and logging, as a catalog of available operations
    // and as the source of randomness. Setting the seed to a fixed number
    // in this example to make outputs deterministic.
    var mlContext = new MLContext(seed: 0);

    // Create a list of training data points.
    var dataPoints = GenerateRandomDataPoints(100).ToList();

    // Convert the list of data points to an IDataView object, which is
    // consumable by ML.NET API.
    var dataView = mlContext.Data.LoadFromEnumerable(dataPoints);

    // ML.NET doesn't cache data set by default. Therefore, if one reads a
    // data set from a file and accesses it many times, it can be slow due
    // to expensive featurization and disk operations. When the considered
    // data can fit into memory, a solution is to cache the data in memory.
    // Caching is especially helpful when working with iterative algorithms
    // which needs many data passes.
    dataView = mlContext.Data.Cache(dataView);

    // Define input and output columns of tree-based featurizer.
    string labelColumnName = nameof(DataPoint.Label);
    string featureColumnName = nameof(DataPoint.Features);
    string treesColumnName = nameof(TransformedDataPoint.Trees);
    string leavesColumnName = nameof(TransformedDataPoint.Leaves);
    string pathsColumnName = nameof(TransformedDataPoint.Paths);

    // Define the configuration of the trainer used to train a tree-based
    // model.
    var trainerOptions = new FastTreeRankingTrainer.Options
    {
        // Reduce the number of trees to 3.
        NumberOfTrees = 3,
        // Number of leaves per tree.
        NumberOfLeaves = 6,
        // Feature column name.
        FeatureColumnName = featureColumnName,
        // Label column name.
        LabelColumnName = labelColumnName
    };

    // Define the tree-based featurizer's configuration.
    var options = new FastTreeRankingFeaturizationEstimator.Options
    {
        InputColumnName = featureColumnName,
        TreesColumnName = treesColumnName,
        LeavesColumnName = leavesColumnName,
        PathsColumnName = pathsColumnName,
        TrainerOptions = trainerOptions
    };

    // Define the featurizer.
    var pipeline = mlContext.Transforms.FeaturizeByFastTreeRanking(
        options);

    // Train the model.
    var model = pipeline.Fit(dataView);

    // Apply the trained transformer to the considered data set.
    var transformed = model.Transform(dataView);

    // Convert IDataView object to a list. Each element in the resulted list
    // corresponds to a row in the IDataView.
    var transformedDataPoints = mlContext.Data.CreateEnumerable<
        TransformedDataPoint>(transformed, false).ToList();

    // Print out the transformation of the first 3 data points.
    for (int i = 0; i < 3; ++i)
    {
        var dataPoint = dataPoints[i];
        var transformedDataPoint = transformedDataPoints[i];
        Console.WriteLine("The original feature vector [" + String.Join(",",
            dataPoint.Features) + "] is transformed to three different " +
            "tree-based feature vectors:");

        Console.WriteLine("  Trees' output values: [" + String.Join(",",
            transformedDataPoint.Trees) + "].");

        Console.WriteLine("  Leave IDs' 0-1 representation: [" + String
            .Join(",", transformedDataPoint.Leaves) + "].");

        Console.WriteLine("  Paths IDs' 0-1 representation: [" + String
            .Join(",", transformedDataPoint.Paths) + "].");
    }

    // Expected output:
    //  The original feature vector [1.117325,1.068023,0.8581612] is
    //  transformed to three different tree-based feature vectors:
    //    Trees' output values: [0.4095458,0.2061437,0.2364294].
    //    Leave IDs' 0-1 representation: [0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1].
    //    Paths IDs' 0-1 representation: [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1].
    //  The original feature vector [0.6588848,1.006027,0.5421779] is
    //  transformed to three different tree-based feature vectors:
    //    Trees' output values: [0.2543825,-0.06570309,-0.1456212].
    //    Leave IDs' 0-1 representation: [0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,0].
    //    Paths IDs' 0-1 representation: [1,1,1,1,1,1,1,1,1,1,1,1,1,1,0].
    //  The original feature vector [0.6737045,0.6919063,0.8673147] is
    //  transformed to three different tree-based feature vectors:
    //    Trees' output values: [0.2543825,-0.06570309,0.01300209].
    //    Leave IDs' 0-1 representation: [0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,1,0].
    //    Paths IDs' 0-1 representation: [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1].
}
// This example requires installation of additional NuGet package
// <a href="https://www.nuget.org/packages/Microsoft.ML.FastTree/">Microsoft.ML.FastTree</a>.
//
// Trains a FastTree ranking model on synthetic grouped data, prints five sample
// predictions, and reports DCG/NDCG metrics on a held-out test set.
public static void Example()
{
    // Seed the context so the run is deterministic.
    var mlContext = new MLContext(seed: 0);

    // Build an in-memory training set and wrap it in an IDataView.
    var trainingData = mlContext.Data.LoadFromEnumerable(
        GenerateRandomDataPoints(1000));

    // Advanced trainer settings.
    var trainerOptions = new FastTreeRankingTrainer.Options
    {
        // Use NdcgAt3 for early stopping.
        EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt3,
        // Create a simpler model by penalizing usage of new features.
        FeatureFirstUsePenalty = 0.1,
        // Reduce the number of trees to 50.
        NumberOfTrees = 50,
        // Specify the row group column name.
        RowGroupColumnName = "GroupId"
    };

    // Build and fit the ranking pipeline.
    var model = mlContext.Ranking.Trainers.FastTree(trainerOptions)
        .Fit(trainingData);

    // Score a test set generated with a different seed so it differs from the
    // training data.
    var testPoints = GenerateRandomDataPoints(500, seed: 123);
    var scoredTestData = model.Transform(
        mlContext.Data.LoadFromEnumerable(testPoints));

    // Materialize the first five scored rows and print them.
    var firstFiveRows = mlContext.Data.TakeRows(scoredTestData, 5);
    var sampledPredictions = mlContext.Data
        .CreateEnumerable<Prediction>(firstFiveRows, reuseRowObject: false)
        .ToList();
    foreach (var prediction in sampledPredictions)
        Console.WriteLine($"Label: {prediction.Label}, Score: {prediction.Score}");

    // Expected output:
    //   Label: 5, Score: 8.807633
    //   Label: 1, Score: -10.71331
    //   Label: 3, Score: -8.134147
    //   Label: 3, Score: -6.545538
    //   Label: 1, Score: -10.27982

    // Evaluate ranking quality over the full scored test set.
    var rankingMetrics = mlContext.Ranking.Evaluate(scoredTestData);
    PrintMetrics(rankingMetrics);

    // Expected output:
    //   DCG: @1:40.57, @2:61.21, @3:74.11
    //   NDCG: @1:0.96, @2:0.95, @3:0.97
}