// Performs evaluation with the truncation level set up to 10 search results within a query.
// This is a temporary workaround for this issue: https://github.com/dotnet/machinelearning/issues/2728.
public static void EvaluateMetrics(MLContext mlContext, IDataView predictions, int truncationLevel)
{
    if (truncationLevel < 1 || truncationLevel > 10)
    {
        throw new InvalidOperationException("Currently metrics are only supported for 1 to 10 truncation levels.");
    }

    // Uses reflection to set the truncation level before calling evaluate,
    // since the public Evaluate API does not expose it yet.
    var mlAssembly = typeof(TextLoader).Assembly;
    var rankEvalType = mlAssembly.DefinedTypes.First(t => t.Name.Contains("RankingEvaluator"));

    // FIX: reuse the already-resolved nested Arguments type instead of performing
    // the same GetNestedType("Arguments") reflection lookup a second time.
    var evalArgsType = rankEvalType.GetNestedType("Arguments");
    var evalArgs = Activator.CreateInstance(evalArgsType);
    var dcgLevel = evalArgsType.GetField("DcgTruncationLevel");
    dcgLevel.SetValue(evalArgs, truncationLevel);

    // Invoke the internal RankingEvaluator(mlContext, args) constructor, then Evaluate
    // with the conventional Label/GroupId/Score column names.
    var ctor = rankEvalType.GetConstructors().First();
    var evaluator = ctor.Invoke(new object[] { mlContext, evalArgs });
    var evaluateMethod = rankEvalType.GetMethod("Evaluate");
    RankingMetrics metrics = (RankingMetrics)evaluateMethod.Invoke(evaluator, new object[] { predictions, "Label", "GroupId", "Score" });

    Console.WriteLine($"DCG: {string.Join(", ", metrics.DiscountedCumulativeGains.Select((d, i) => $"@{i + 1}:{d:F4}").ToArray())}");
    Console.WriteLine($"NDCG: {string.Join(", ", metrics.NormalizedDiscountedCumulativeGains.Select((d, i) => $"@{i + 1}:{d:F4}").ToArray())}\n");
}
// Computes the element-wise difference (a - b) between two sets of ranking metrics.
private static RankingMetrics RankingDelta(RankingMetrics a, RankingMetrics b)
{
    var dcgDeltas = ComputeArrayDeltas(a.Dcg, b.Dcg);
    var ndcgDeltas = ComputeArrayDeltas(a.Ndcg, b.Ndcg);
    return new RankingMetrics(dcg: dcgDeltas, ndcg: ndcgDeltas);
}
// Prints NDCG@k and DCG@k for a ranking model, where k is the optimization
// metric truncation level; missing metrics are shown as NaN.
public static void PrintRankingMetrics(string name, RankingMetrics metrics, uint optimizationMetricTruncationLevel)
{
    // The truncation level is 1-based while the metric arrays are 0-based.
    var index = (int)optimizationMetricTruncationLevel - 1;
    Console.WriteLine($"************************************************************");
    Console.WriteLine($"* Metrics for {name} ranking model ");
    Console.WriteLine($"*-----------------------------------------------------------");
    Console.WriteLine($" Normalized Discounted Cumulative Gain (NDCG@{optimizationMetricTruncationLevel}) = {metrics?.NormalizedDiscountedCumulativeGains?[index] ?? double.NaN:0.####}, a value from 0 and 1, where closer to 1.0 is better");
    Console.WriteLine($" Discounted Cumulative Gain (DCG@{optimizationMetricTruncationLevel}) = {metrics?.DiscountedCumulativeGains?[index] ?? double.NaN:0.####}");
}
// Evaluates ranking predictions and prints the DCG/NDCG values reported
// at each truncation level, formatted as "@level:value".
public static void EvaluateMetrics(MLContext mlContext, IDataView predictions)
{
    RankingMetrics metrics = mlContext.Ranking.Evaluate(predictions);

    var dcgParts = metrics.DiscountedCumulativeGains.Select((gain, level) => $"@{level + 1}:{gain:F4}").ToArray();
    var ndcgParts = metrics.NormalizedDiscountedCumulativeGains.Select((gain, level) => $"@{level + 1}:{gain:F4}").ToArray();

    Console.WriteLine($"DCG: {string.Join(", ", dcgParts)}");
    Console.WriteLine($"NDCG: {string.Join(", ", ndcgParts)}\n");
}
// To evaluate the accuracy of the model's predicted rankings, prints out the Discounted
// Cumulative Gain and Normalized Discounted Cumulative Gain for search queries.
public static void EvaluateMetrics(MLContext mlContext, IDataView predictions)
{
    // Evaluate the metrics for the data using NDCG; by default, metrics for the
    // up to 3 search results in the query are reported (e.g. NDCG@3).
    RankingMetrics metrics = mlContext.Ranking.Evaluate(predictions);

    // Renders each per-level gain as "@level:value" with four decimal places.
    string FormatGains(IEnumerable<double> gains) =>
        string.Join(", ", gains.Select((gain, level) => $"@{level + 1}:{gain:F4}").ToArray());

    Console.WriteLine($"DCG: {FormatGains(metrics.DiscountedCumulativeGains)}");
    Console.WriteLine($"NDCG: {FormatGains(metrics.NormalizedDiscountedCumulativeGains)}\n");
}
// Pretty-print RankerMetrics objects.
public static void PrintMetrics(RankingMetrics metrics)
{
    // FIX: the original concatenated the literal text ":F2" onto each value
    // ((i + 1) + ":" + d + ":F2"), instead of applying F2 as a numeric format.
    // Format each gain to two decimal places as "level:value".
    Console.WriteLine("DCG: " + string.Join(", ",
        metrics.DiscountedCumulativeGains.Select((d, i) => $"{i + 1}:{d:F2}").ToArray()));
    Console.WriteLine("NDCG: " + string.Join(", ",
        metrics.NormalizedDiscountedCumulativeGains.Select((d, i) => $"{i + 1}:{d:F2}").ToArray()));
}
/// <summary>
/// Check that a <see cref="RankingMetrics"/> object is valid.
/// </summary>
/// <param name="metrics">The metrics object.</param>
public static void AssertMetrics(RankingMetrics metrics)
{
    // DCG must never be negative.
    foreach (var gain in metrics.Dcg)
    {
        Assert.True(gain >= 0);
    }

    // NDCG values here are expected to fall within [0, 100].
    foreach (var normalizedGain in metrics.Ndcg)
    {
        Assert.InRange(normalizedGain, 0, 100);
    }
}
/// <summary>
/// Check that a <see cref="RankingMetrics"/> object is valid.
/// </summary>
/// <param name="metrics">The metrics object.</param>
public static void AssertMetrics(RankingMetrics metrics)
{
    // Discounted cumulative gain must never be negative.
    foreach (var gain in metrics.DiscountedCumulativeGains)
    {
        Assert.True(gain >= 0);
    }

    // Normalized DCG values here are expected to fall within [0, 100].
    foreach (var normalizedGain in metrics.NormalizedDiscountedCumulativeGains)
    {
        Assert.InRange(normalizedGain, 0, 100);
    }
}
// Scores the given ranking metrics using a metrics agent for the requested metric.
private static double GetScore(RankingMetrics metrics, RankingMetric metric)
{
    var agent = new RankingMetricsAgent(null, metric);
    return agent.GetScore(metrics);
}
// Prints the per-truncation-level NDCG and DCG values of a ranking model.
private static void PrintMetrics(RankingMetrics metrics)
{
    // FIX: interpolating a collection directly prints its type name
    // (e.g. "System.Collections.Generic.List`1[System.Double]") rather than
    // its contents; join the element values instead.
    Console.WriteLine($"NormalizedDiscountedCumulativeGains: {string.Join(", ", metrics.NormalizedDiscountedCumulativeGains)}");
    Console.WriteLine($"DiscountedCumulativeGains: {string.Join(", ", metrics.DiscountedCumulativeGains)}");
}
// Runs an end-to-end AutoML ranking sample: load data, search for the best model,
// evaluate it on test data, save it, and score a couple of sample rows.
public static void Run()
{
    MLContext mlContext = new MLContext();

    // STEP 1: Load data
    IDataView trainDataView = mlContext.Data.LoadFromTextFile<SearchData>(TrainDataPath, hasHeader: true, separatorChar: ',');
    IDataView testDataView = mlContext.Data.LoadFromTextFile<SearchData>(TestDataPath, hasHeader: true, separatorChar: ',');

    // STEP 2: Run AutoML experiment
    // FIX: the message previously said "recommendation experiment", but this runs a ranking experiment.
    Console.WriteLine($"Running AutoML ranking experiment for {ExperimentTime} seconds...");
    ExperimentResult<RankingMetrics> experimentResult = mlContext.Auto()
        .CreateRankingExperiment(new RankingExperimentSettings() { MaxExperimentTimeInSeconds = ExperimentTime })
        .Execute(trainDataView, testDataView,
            new ColumnInformation() { LabelColumnName = LabelColumnName, GroupIdColumnName = GroupColumnName });

    // STEP 3: Print metric from best model
    RunDetail<RankingMetrics> bestRun = experimentResult.BestRun;
    Console.WriteLine($"Total models produced: {experimentResult.RunDetails.Count()}");
    Console.WriteLine($"Best model's trainer: {bestRun.TrainerName}");
    Console.WriteLine($"Metrics of best model from validation data --");
    PrintMetrics(bestRun.ValidationMetrics);

    // STEP 4: Evaluate test data (comments renumbered; the original jumped from STEP 3 to STEP 5)
    IDataView testDataViewWithBestScore = bestRun.Model.Transform(testDataView);
    RankingMetrics testMetrics = mlContext.Ranking.Evaluate(testDataViewWithBestScore, labelColumnName: LabelColumnName);
    Console.WriteLine($"Metrics of best model on test data --");
    PrintMetrics(testMetrics);

    // STEP 5: Save the best model for later deployment and inferencing
    mlContext.Model.Save(bestRun.Model, trainDataView.Schema, ModelPath);

    // STEP 6: Create prediction engine from the best trained model
    var predictionEngine = mlContext.Model.CreatePredictionEngine<SearchData, SearchDataPrediction>(bestRun.Model);

    // STEP 7: Initialize a new test, and get the prediction
    var testPage = new SearchData { GroupId = "1", Features = 9, Label = 1 };
    var prediction = predictionEngine.Predict(testPage);
    // FIX: the interpolated string literal was broken across a raw line break, which does not
    // compile in a non-verbatim interpolated string; rejoined onto a single line.
    Console.WriteLine($"Predicted rating for: {prediction.Prediction}");

    // New Page
    testPage = new SearchData { GroupId = "2", Features = 2, Label = 9 };
    prediction = predictionEngine.Predict(testPage);
    Console.WriteLine($"Predicted: {prediction.Prediction}");

    Console.WriteLine("Press any key to continue...");
    Console.ReadKey();
}
// Determines whether the metrics represent a perfect model for the given metric
// at the specified DCG truncation level.
private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric, uint dcgTruncationLevel)
{
    var agent = new RankingMetricsAgent(null, metric, dcgTruncationLevel);
    return IsPerfectModel(agent, metrics);
}
// Scores the given ranking metrics using a metrics agent configured with
// the requested metric and group id column name.
private static double GetScore(RankingMetrics metrics, RankingMetric metric, string groupIdColumnName)
{
    var agent = new RankingMetricsAgent(null, metric, groupIdColumnName);
    return agent.GetScore(metrics);
}
// Scores the given ranking metrics using a metrics agent configured with
// the requested metric and DCG truncation level.
private static double GetScore(RankingMetrics metrics, RankingMetric metric, uint dcgTruncationLevel)
{
    var agent = new RankingMetricsAgent(null, metric, dcgTruncationLevel);
    return agent.GetScore(metrics);
}
/// <summary>
/// Folds a set of per-fold metrics into a single averaged <typeparamref name="TMetrics"/>
/// instance. Scalar metrics are averaged via <c>GetAverageOfNonNaNScores</c>; members that
/// cannot be meaningfully averaged (confusion matrix, per-class log-loss) are taken from
/// the fold whose metrics were closest to the average.
/// Supports binary/multiclass classification, regression, and ranking metrics.
/// </summary>
/// <param name="metrics">The metrics produced by each cross-validation fold.</param>
/// <param name="metricsClosestToAvg">The single fold's metrics closest to the average score.</param>
/// <returns>A new metrics object holding the averaged values.</returns>
/// <exception cref="NotImplementedException">Thrown for an unsupported TMetrics type.</exception>
private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetrics metricsClosestToAvg)
{
    // Dispatch on the concrete metrics type; each branch downcasts, averages the
    // scalar members, and rebuilds a metrics object through its constructor.
    if (typeof(TMetrics) == typeof(BinaryClassificationMetrics))
    {
        var newMetrics = metrics.Select(x => x as BinaryClassificationMetrics);
        Contracts.Assert(newMetrics != null);
        var result = new BinaryClassificationMetrics(
            auc: GetAverageOfNonNaNScores(newMetrics.Select(x => x.AreaUnderRocCurve)),
            accuracy: GetAverageOfNonNaNScores(newMetrics.Select(x => x.Accuracy)),
            positivePrecision: GetAverageOfNonNaNScores(newMetrics.Select(x => x.PositivePrecision)),
            positiveRecall: GetAverageOfNonNaNScores(newMetrics.Select(x => x.PositiveRecall)),
            negativePrecision: GetAverageOfNonNaNScores(newMetrics.Select(x => x.NegativePrecision)),
            negativeRecall: GetAverageOfNonNaNScores(newMetrics.Select(x => x.NegativeRecall)),
            f1Score: GetAverageOfNonNaNScores(newMetrics.Select(x => x.F1Score)),
            auprc: GetAverageOfNonNaNScores(newMetrics.Select(x => x.AreaUnderPrecisionRecallCurve)),
            // Return ConfusionMatrix from the fold closest to average score
            confusionMatrix: (metricsClosestToAvg as BinaryClassificationMetrics).ConfusionMatrix);
        return (result as TMetrics);
    }
    if (typeof(TMetrics) == typeof(MulticlassClassificationMetrics))
    {
        var newMetrics = metrics.Select(x => x as MulticlassClassificationMetrics);
        Contracts.Assert(newMetrics != null);
        var result = new MulticlassClassificationMetrics(
            accuracyMicro: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MicroAccuracy)),
            accuracyMacro: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MacroAccuracy)),
            logLoss: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLoss)),
            logLossReduction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLossReduction)),
            // TopKPredictionCount is configuration, not a score; take it from the first fold.
            topKPredictionCount: newMetrics.ElementAt(0).TopKPredictionCount,
            topKAccuracy: GetAverageOfNonNaNScores(newMetrics.Select(x => x.TopKAccuracy)),
            // Return PerClassLogLoss and ConfusionMatrix from the fold closest to average score
            perClassLogLoss: (metricsClosestToAvg as MulticlassClassificationMetrics).PerClassLogLoss.ToArray(),
            confusionMatrix: (metricsClosestToAvg as MulticlassClassificationMetrics).ConfusionMatrix);
        return (result as TMetrics);
    }
    if (typeof(TMetrics) == typeof(RegressionMetrics))
    {
        var newMetrics = metrics.Select(x => x as RegressionMetrics);
        Contracts.Assert(newMetrics != null);
        var result = new RegressionMetrics(
            l1: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MeanAbsoluteError)),
            l2: GetAverageOfNonNaNScores(newMetrics.Select(x => x.MeanSquaredError)),
            rms: GetAverageOfNonNaNScores(newMetrics.Select(x => x.RootMeanSquaredError)),
            lossFunction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LossFunction)),
            rSquared: GetAverageOfNonNaNScores(newMetrics.Select(x => x.RSquared)));
        return (result as TMetrics);
    }
    if (typeof(TMetrics) == typeof(RankingMetrics))
    {
        var newMetrics = metrics.Select(x => x as RankingMetrics);
        Contracts.Assert(newMetrics != null);
        // Ranking metrics are per-truncation-level arrays, so average element-wise
        // across the folds' nested enumerables.
        var result = new RankingMetrics(
            dcg: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.DiscountedCumulativeGains)),
            ndcg: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.NormalizedDiscountedCumulativeGains)));
        return (result as TMetrics);
    }
    throw new NotImplementedException($"Metric {typeof(TMetrics)} not implemented");
}
// Determines whether the metrics represent a perfect model for the given metric.
private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric)
{
    var agent = new RankingMetricsAgent(null, metric);
    return IsPerfectModel(agent, metrics);
}
// Prints one row of the AutoML iteration table: NDCG@1, NDCG@3, NDCG@10, DCG@10, and runtime.
internal static void PrintIterationMetrics(int iteration, string trainerName, RankingMetrics metrics, double? runtimeInSeconds)
{
    // FIX: runtimeInSeconds.Value threw InvalidOperationException when the runtime was
    // unknown; fall back to NaN, consistent with the other nullable fields in this row.
    // NOTE(review): indexing [2] and [9] assumes metrics report at least 10 truncation
    // levels — confirm against the evaluator's configured DCG truncation level.
    CreateRow($"{iteration,-4} {trainerName,-15} {metrics?.NormalizedDiscountedCumulativeGains[0] ?? double.NaN,9:F4} {metrics?.NormalizedDiscountedCumulativeGains[2] ?? double.NaN,9:F4} {metrics?.NormalizedDiscountedCumulativeGains[9] ?? double.NaN,9:F4} {metrics?.DiscountedCumulativeGains[9] ?? double.NaN,9:F4} {runtimeInSeconds ?? double.NaN,9:F1}", Width);
}
// Determines whether the metrics represent a perfect model for the given metric,
// using the specified group id column name.
private static bool IsPerfectModel(RankingMetrics metrics, RankingMetric metric, string groupIdColumnName)
{
    var agent = new RankingMetricsAgent(null, metric, groupIdColumnName);
    return IsPerfectModel(agent, metrics);
}