public static (ColumnInferenceResults columnInference, MulticlassExperimentSettings experimentSettings) SetupExperiment(
    MLContext mlContext, ExperimentModifier st, DataFilePaths paths, bool forPrs)
{
    var columnInference = InferColumns(mlContext, paths.TrainPath, st.LabelColumnName);
    var columnInformation = columnInference.ColumnInformation;
    st.ColumnSetup(columnInformation, forPrs);

    var experimentSettings = new MulticlassExperimentSettings();
    st.TrainerSetup(experimentSettings.Trainers);
    experimentSettings.MaxExperimentTimeInSeconds = st.ExperimentTime;

    var cts = new System.Threading.CancellationTokenSource();
    experimentSettings.CancellationToken = cts.Token;

    // Use a cache directory so models produced by AutoML are written to disk after
    // each run instead of being kept in memory while AutoML is training.
    // (Keeping every model in memory during an experiment on a large dataset
    // could cause your system to run out of memory.)
    experimentSettings.CacheDirectory = new DirectoryInfo(Path.GetTempPath());
    experimentSettings.OptimizingMetric = MulticlassClassificationMetric.MicroAccuracy;

    return (columnInference, experimentSettings);
}
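// A minimal sketch of how SetupExperiment's results might be wired into the Train helper shown
// later in this file. Illustrative only: the method name RunTrainingSketch, the local variable
// names, and constructing MulticlassExperimentProgressHandler with a parameterless constructor
// are assumptions; SetupExperiment, Train, and the path/settings types come from the code here.
private static ExperimentResult<MulticlassClassificationMetrics> RunTrainingSketch(
    MLContext mlContext, ExperimentModifier st, DataFilePaths paths, bool forPrs)
{
    // Infer columns and configure the AutoML experiment settings.
    var (columnInference, experimentSettings) = SetupExperiment(mlContext, st, paths, forPrs);

    // Build a text loader from the inferred schema and run the experiment on the train/validate data.
    var textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderOptions);
    return Train(mlContext, st.LabelColumnName, experimentSettings,
        new MulticlassExperimentProgressHandler(), paths, textLoader);
}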
public static ITransformer Retrain(MLContext mlContext, ExperimentResult<MulticlassClassificationMetrics> experimentResult,
    ColumnInferenceResults columnInference, DataFilePaths paths, bool fixedBug = false)
{
    ConsoleHelper.ConsoleWriteHeader("=============== Re-fitting best pipeline ===============");
    var textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderOptions);
    var combinedDataView = textLoader.Load(new MultiFileSource(paths.TrainPath, paths.ValidatePath, paths.TestPath));
    var bestRun = experimentResult.BestRun;

    if (fixedBug)
    {
        // TODO: retry the call below. It previously failed, but the problem may have been
        // in EvaluateTrainedModelAndPrintMetrics rather than in Retrain itself:
        //refitModel = MulticlassExperiment.Retrain(experimentResult,
        //    "final model",
        //    new MultiFileSource(paths.TrainPath, paths.ValidatePath, paths.FittedPath),
        //    paths.TestPath,
        //    paths.FinalPath, textLoader, mlContext);
    }

    var refitModel = bestRun.Estimator.Fit(combinedDataView);
    EvaluateTrainedModelAndPrintMetrics(mlContext, refitModel, "production model", textLoader.Load(paths.TestPath));

    // Save the re-fit model to a .zip file.
    SaveModel(mlContext, refitModel, paths.FinalModelPath, textLoader.Load(paths.TestPath));

    Trace.WriteLine($"The model is saved to {paths.FinalModelPath}");
    return refitModel;
}
public static async Task PrepareAndSaveDatasetsForPrsAsync(DataFilePaths prFiles, DatasetModifier datasetModifier)
{
    var ds = DatasetHelperInner.Instance;
    var lines = await ds.AddOrRemoveColumnsPriorToTrainingAsync(prFiles.InputPath, datasetModifier, includeFileColumns: true);
    lines = ds.OnlyPrs(lines);
    await ds.BreakIntoTrainValidateTestDatasetsAsync(lines, prFiles.TrainPath, prFiles.ValidatePath, prFiles.TestPath);
}
public void Train(DataFilePaths files, bool forPrs)
{
    var stopWatch = Stopwatch.StartNew();
    var st = new ExperimentModifier(files, forPrs);
    Train(st);
    stopWatch.Stop();
    Trace.WriteLine($"Done creating model in {stopWatch.ElapsedMilliseconds}ms");
}
public static ExperimentResult<MulticlassClassificationMetrics> Train(
    MLContext mlContext, string labelColumnName, MulticlassExperimentSettings experimentSettings,
    MulticlassExperimentProgressHandler progressHandler, DataFilePaths paths, TextLoader textLoader)
{
    var trainData = textLoader.Load(paths.TrainPath);
    var validateData = textLoader.Load(paths.ValidatePath);
    var experimentResult = RunAutoMLExperiment(mlContext, labelColumnName, experimentSettings, progressHandler, trainData);
    EvaluateTrainedModelAndPrintMetrics(mlContext, experimentResult.BestRun.Model, experimentResult.BestRun.TrainerName, validateData);
    SaveModel(mlContext, experimentResult.BestRun.Model, paths.ModelPath, trainData);
    return experimentResult;
}
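// RunAutoMLExperiment is called above but not shown here. The following is a hedged sketch of
// what such a helper could look like using the Microsoft.ML.AutoML experiment API; the exact
// signature, the logging, and the assumption that MulticlassExperimentProgressHandler implements
// IProgress<RunDetail<MulticlassClassificationMetrics>> are illustrative, not the repo's code.
private static ExperimentResult<MulticlassClassificationMetrics> RunAutoMLExperimentSketch(
    MLContext mlContext, string labelColumnName, MulticlassExperimentSettings experimentSettings,
    MulticlassExperimentProgressHandler progressHandler, IDataView trainData)
{
    ConsoleHelper.ConsoleWriteHeader("=============== Running AutoML experiment ===============");

    // Execute the AutoML multiclass experiment; progressHandler reports each completed run.
    var experimentResult = mlContext.Auto()
        .CreateMulticlassClassificationExperiment(experimentSettings)
        .Execute(trainData, labelColumnName, progressHandler: progressHandler);

    Trace.WriteLine($"Best run: {experimentResult.BestRun.TrainerName}, " +
        $"MicroAccuracy: {experimentResult.BestRun.ValidationMetrics.MicroAccuracy:0.####}");
    return experimentResult;
}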
public ExperimentModifier(DataFilePaths paths, bool forPrs)
{
    // set all to defaults:
    ColumnSetup = (columnInformation, forPrs) =>
    {
        // Customize column information returned by the InferColumns API.
        columnInformation.CategoricalColumnNames.Clear();
        columnInformation.NumericColumnNames.Clear();
        columnInformation.IgnoredColumnNames.Clear();
        columnInformation.TextColumnNames.Clear();

        // NOTE: depending on how the data changes over time, this might need to get updated too.
        columnInformation.TextColumnNames.Add("Title");
        columnInformation.TextColumnNames.Add("Description");
        columnInformation.CategoricalColumnNames.Add("IssueAuthor");
        columnInformation.IgnoredColumnNames.Add("IsPR");
        columnInformation.CategoricalColumnNames.Add("NumMentions");
        columnInformation.IgnoredColumnNames.Add("UserMentions");

        if (forPrs)
        {
            columnInformation.NumericColumnNames.Add("FileCount");
            columnInformation.CategoricalColumnNames.Add("Files");
            columnInformation.CategoricalColumnNames.Add("FolderNames");
            columnInformation.CategoricalColumnNames.Add("Folders");
            columnInformation.IgnoredColumnNames.Add("FileExtensions");
            columnInformation.IgnoredColumnNames.Add("Filenames");
        }
    };

    TrainerSetup = (trainers) =>
    {
        trainers.Clear();
        if (forPrs)
        {
            trainers.Add(MulticlassClassificationTrainer.SdcaMaximumEntropy);
            trainers.Add(MulticlassClassificationTrainer.FastTreeOva);
        }
        else
        {
            trainers.Add(MulticlassClassificationTrainer.SdcaMaximumEntropy);
            //trainers.Add(MulticlassClassificationTrainer.LinearSupportVectorMachinesOva);
            //trainers.Add(MulticlassClassificationTrainer.LightGbm);
        }
    };

    ExperimentTime = 300;
    LabelColumnName = "Area";
    ForPrs = forPrs;
    Paths = paths;
}
public ExperimentModifier(
    bool forPrs,
    uint experimentTime,
    string labelColumnName,
    DataFilePaths paths,
    Action<ColumnInformation, bool> columnSetup,
    Action<ICollection<MulticlassClassificationTrainer>> trainerSetup)
{
    ForPrs = forPrs;
    ExperimentTime = experimentTime;
    LabelColumnName = labelColumnName;
    Paths = paths;
    ColumnSetup = columnSetup;
    TrainerSetup = trainerSetup;
}
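// Illustrative only: a hedged example of using the explicit constructor above to override the
// defaults, e.g. a shorter issue-labeling experiment that only tries LightGBM. The factory name
// CreateQuickIssueExperiment, the 60-second budget, and the trimmed column setup are assumptions;
// the column names and the clear-then-add pattern mirror the default constructor above.
public static ExperimentModifier CreateQuickIssueExperiment(DataFilePaths paths) =>
    new ExperimentModifier(
        forPrs: false,
        experimentTime: 60,
        labelColumnName: "Area",
        paths: paths,
        columnSetup: (columnInformation, forPrs) =>
        {
            // Drop the inferred columns and keep only the text features plus the ignored flag.
            columnInformation.CategoricalColumnNames.Clear();
            columnInformation.NumericColumnNames.Clear();
            columnInformation.IgnoredColumnNames.Clear();
            columnInformation.TextColumnNames.Clear();
            columnInformation.TextColumnNames.Add("Title");
            columnInformation.TextColumnNames.Add("Description");
            columnInformation.IgnoredColumnNames.Add("IsPR");
        },
        trainerSetup: trainers =>
        {
            trainers.Clear();
            trainers.Add(MulticlassClassificationTrainer.LightGbm);
        });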
public static void TestPrediction(MLContext mlContext, DataFilePaths files, bool forPrs, double threshold = 0.6)
{
    var trainedModel = mlContext.Model.Load(files.FittedModelPath, out _);
    IEnumerable<(string knownLabel, GitHubIssuePrediction predictedResult, string issueNumber)> predictions = null;

    string Legend1 = $"(includes issues correctly left unlabeled because prediction confidence was below the threshold, here {threshold * 100.0f:#,0.00}%)";
    const string Legend2 = "(includes items that could have been labeled if the threshold were lower.)";
    const string Legend3 = "(those incorrectly labeled)";

    if (forPrs)
    {
        var testData = GetPullRequests(mlContext, files.TestPath);
        Trace.WriteLine($"{Environment.NewLine}Number of PRs tested: {testData.Length}");

        var prEngine = mlContext.Model.CreatePredictionEngine<GitHubPullRequest, GitHubIssuePrediction>(trainedModel);
        predictions = testData
            .Select(x => (
                knownLabel: x.Area,
                predictedResult: prEngine.Predict(x),
                issueNumber: x.ID.ToString()
            ));
    }
    else
    {
        var testData = GetIssues(mlContext, files.TestPath);
        Trace.WriteLine($"{Environment.NewLine}\tNumber of issues tested: {testData.Length}");

        var issueEngine = mlContext.Model.CreatePredictionEngine<GitHubIssue, GitHubIssuePrediction>(trainedModel);
        predictions = testData
            .Select(x => (
                knownLabel: x.Area,
                predictedResult: issueEngine.Predict(x),
                issueNumber: x.ID.ToString()
            ));
    }

    var analysis = predictions.Select(x => (
        knownLabel: x.knownLabel,
        predictedArea: x.predictedResult.Area,
        maxScore: x.predictedResult.Score.Max(),
        confidentInPrediction: x.predictedResult.Score.Max() >= threshold,
        issueNumber: x.issueNumber
    ));

    var countSuccess = analysis
        .Where(x =>
            (x.confidentInPrediction && x.predictedArea.Equals(x.knownLabel, StringComparison.Ordinal)) ||
            (!x.confidentInPrediction && !x.predictedArea.Equals(x.knownLabel, StringComparison.Ordinal)))
        .Count();

    var missedOpportunity = analysis
        .Where(x => !x.confidentInPrediction && x.knownLabel.Equals(x.predictedArea, StringComparison.Ordinal))
        .Count();

    var mistakes = analysis
        .Where(x => x.confidentInPrediction && !x.knownLabel.Equals(x.predictedArea, StringComparison.Ordinal))
        .Select(x => new
        {
            Pair = $"\tPredicted: {x.predictedArea}, Actual: {x.knownLabel}",
            IssueNumbers = x.issueNumber,
            MaxConfidencePercentage = x.maxScore * 100.0f
        })
        .GroupBy(x => x.Pair)
        .Select(x => new
        {
            Count = x.Count(),
            PredictedVsActual = x.Key,
            Items = x,
        })
        .OrderByDescending(x => x.Count);

    int remaining = predictions.Count() - countSuccess - missedOpportunity;

    Trace.WriteLine($"{Environment.NewLine}\thandled correctly: {countSuccess}{Environment.NewLine}\t{Legend1}{Environment.NewLine}");
    Trace.WriteLine($"{Environment.NewLine}\tmissed: {missedOpportunity}{Environment.NewLine}\t{Legend2}{Environment.NewLine}");
    Trace.WriteLine($"{Environment.NewLine}\tremaining: {remaining}{Environment.NewLine}\t{Legend3}{Environment.NewLine}");

    foreach (var mismatch in mistakes.AsEnumerable())
    {
        Trace.WriteLine($"{mismatch.PredictedVsActual}, NumFound: {mismatch.Count}");
        var sampleIssues = string.Join(Environment.NewLine,
            mismatch.Items.Select(x => $"\t\tFor #{x.IssueNumbers} was {x.MaxConfidencePercentage:#,0.00}% confident"));
        Trace.WriteLine($"{Environment.NewLine}{sampleIssues}{Environment.NewLine}");
    }
}
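// TestPrediction above only relies on the prediction class exposing the predicted label and the
// per-class score vector. A minimal sketch of that shape, assuming the model's PredictedLabel
// output column maps to Area; the real GitHubIssuePrediction class in this project may carry
// additional columns, so this is illustrative rather than the actual definition.
public class GitHubIssuePredictionSketch
{
    // Label with the highest score, mapped from the model's PredictedLabel output column.
    [ColumnName("PredictedLabel")]
    public string Area;

    // Raw per-class scores; Score.Max() is compared against the confidence threshold above.
    public float[] Score;
}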
public void Test(DataFilePaths files, bool forPrs)
{
    MulticlassExperimentHelper.TestPrediction(_mLContext, files, forPrs: forPrs);
}