예제 #1
0
        /// <summary>
        /// Infers the columns of the training file, applies the column customizations supplied by
        /// <paramref name="st"/>, and builds the AutoML multiclass experiment settings.
        /// </summary>
        /// <param name="mlContext">ML.NET context passed to column inference.</param>
        /// <param name="st">Supplies the label column name, the column/trainer setup delegates and the time budget.</param>
        /// <param name="paths">Dataset paths; only <c>TrainPath</c> is read here.</param>
        /// <param name="forPrs">Forwarded to <c>st.ColumnSetup</c>.</param>
        /// <returns>The column inference results and the configured experiment settings.</returns>
        public static (ColumnInferenceResults columnInference, MulticlassExperimentSettings experimentSettings) SetupExperiment(
            MLContext mlContext, ExperimentModifier st, DataFilePaths paths, bool forPrs)
        {
            var columnInference   = InferColumns(mlContext, paths.TrainPath, st.LabelColumnName);
            var columnInformation = columnInference.ColumnInformation;

            // Let the modifier override the inferred column roles (the same object is returned below).
            st.ColumnSetup(columnInformation, forPrs);

            var experimentSettings = new MulticlassExperimentSettings();

            st.TrainerSetup(experimentSettings.Trainers);
            experimentSettings.MaxExperimentTimeInSeconds = st.ExperimentTime;

            // NOTE(review): this source is local and never exposed to callers, so the token can never
            // be cancelled; the experiment is bounded only by MaxExperimentTimeInSeconds.
            var cts = new System.Threading.CancellationTokenSource();

            experimentSettings.CancellationToken = cts.Token;

            // Use the system temp directory as AutoML's model cache, so trained models are written
            // to disk there rather than all being held in memory. (Setting CacheDirectory to null
            // would keep every model in memory instead, which can exhaust memory on large datasets.)
            experimentSettings.CacheDirectory   = new DirectoryInfo(Path.GetTempPath());
            experimentSettings.OptimizingMetric = MulticlassClassificationMetric.MicroAccuracy;
            return(columnInference, experimentSettings);
        }
예제 #2
0
        /// <summary>
        /// Re-fits the best pipeline found by the experiment on the union of the train, validate and
        /// test files, evaluates it on the test file, saves it to <c>paths.FinalModelPath</c> and
        /// returns it.
        /// </summary>
        /// <param name="mlContext">ML.NET context used for loading, evaluation and saving.</param>
        /// <param name="experimentResult">Experiment whose <c>BestRun.Estimator</c> is re-fitted.</param>
        /// <param name="columnInference">Provides the text loader options for reading the data files.</param>
        /// <param name="paths">Train/validate/test input paths and the final model output path.</param>
        /// <param name="fixedBug">Currently has no effect (its branch is empty); kept for callers.</param>
        /// <returns>The re-fitted model.</returns>
        public static ITransformer Retrain(MLContext mlContext, ExperimentResult <MulticlassClassificationMetrics> experimentResult,
                                           ColumnInferenceResults columnInference, DataFilePaths paths, bool fixedBug = false)
        {
            ConsoleHelper.ConsoleWriteHeader("=============== Re-fitting best pipeline ===============");
            var textLoader       = mlContext.Data.CreateTextLoader(columnInference.TextLoaderOptions);
            var combinedDataView = textLoader.Load(new MultiFileSource(paths.TrainPath, paths.ValidatePath, paths.TestPath));
            var bestRun          = experimentResult.BestRun;

            if (fixedBug)
            {
                // TODO: retry: below gave error but I thought it would work:
                //refitModel = MulticlassExperiment.Retrain(experimentResult,
                //    "final model",
                //    new MultiFileSource(paths.TrainPath, paths.ValidatePath, paths.FittedPath),
                //    paths.TestPath,
                //    paths.FinalPath, textLoader, mlContext);
                // but if failed before fixing this maybe the problem was in *EvaluateTrainedModelAndPrintMetrics*
            }

            var refitModel = bestRun.Estimator.Fit(combinedDataView);

            // Load the test split once; it is used both for evaluation and as the schema source when saving.
            var testDataView = textLoader.Load(paths.TestPath);

            // NOTE(review): the test file is also part of combinedDataView, so these metrics are measured
            // on data the re-fitted model has already seen and will be optimistic.
            EvaluateTrainedModelAndPrintMetrics(mlContext, refitModel, "production model", testDataView);

            // Save the re-fit model to a .ZIP file
            SaveModel(mlContext, refitModel, paths.FinalModelPath, testDataView);

            // Trace.WriteLine has no format-args overload: the old two-argument call bound to
            // WriteLine(message, category) and never substituted the path into "{0}".
            Trace.WriteLine($"The model is saved to {paths.FinalModelPath}");
            return refitModel;
        }
예제 #3
0
        /// <summary>
        /// Reads <c>prFiles.InputPath</c> through the shared dataset helper (with file columns included),
        /// filters the rows with <c>OnlyPrs</c>, and writes the result split into the train, validate
        /// and test paths of <paramref name="prFiles"/>.
        /// </summary>
        /// <param name="prFiles">Input path and the three output split paths.</param>
        /// <param name="datasetModifier">Passed through to the column add/remove step.</param>
        public static async Task PrepareAndSaveDatasetsForPrsAsync(DataFilePaths prFiles, DatasetModifier datasetModifier)
        {
            var helper = DatasetHelperInner.Instance;

            var allLines = await helper.AddOrRemoveColumnsPriorToTrainingAsync(
                prFiles.InputPath,
                datasetModifier,
                includeFileColumns: true);

            // Presumably keeps only pull-request rows (per the helper's name) — confirm against OnlyPrs.
            var prLines = helper.OnlyPrs(allLines);

            await helper.BreakIntoTrainValidateTestDatasetsAsync(
                prLines,
                prFiles.TrainPath,
                prFiles.ValidatePath,
                prFiles.TestPath);
        }
예제 #4
0
        /// <summary>
        /// Trains with a default <c>ExperimentModifier</c> built from <paramref name="files"/> and
        /// <paramref name="forPrs"/>, then traces how long the run took in milliseconds.
        /// </summary>
        /// <param name="files">Dataset and model paths for the experiment.</param>
        /// <param name="forPrs">Selects the pull-request configuration instead of the issue one.</param>
        public void Train(DataFilePaths files, bool forPrs)
        {
            var timer = Stopwatch.StartNew();

            Train(new ExperimentModifier(files, forPrs));

            timer.Stop();
            var elapsedMs = timer.ElapsedMilliseconds;
            Trace.WriteLine($"Done creating model in {elapsedMs}ms");
        }
예제 #5
0
        /// <summary>
        /// Runs the AutoML experiment on the training file, evaluates the best run's model on the
        /// validation file, saves that model to <c>paths.ModelPath</c>, and returns the experiment result.
        /// </summary>
        /// <param name="mlContext">ML.NET context used throughout.</param>
        /// <param name="labelColumnName">Name of the label column passed to the experiment.</param>
        /// <param name="experimentSettings">AutoML multiclass settings.</param>
        /// <param name="progressHandler">Receives progress notifications from the experiment.</param>
        /// <param name="paths">Train/validate input paths and the model output path.</param>
        /// <param name="textLoader">Loader used to read the train and validate files.</param>
        /// <returns>The full experiment result.</returns>
        public static ExperimentResult <MulticlassClassificationMetrics> Train(
            MLContext mlContext, string labelColumnName, MulticlassExperimentSettings experimentSettings,
            MulticlassExperimentProgressHandler progressHandler, DataFilePaths paths, TextLoader textLoader)
        {
            var trainingView   = textLoader.Load(paths.TrainPath);
            var validationView = textLoader.Load(paths.ValidatePath);

            var result  = RunAutoMLExperiment(mlContext, labelColumnName, experimentSettings, progressHandler, trainingView);
            var bestRun = result.BestRun;

            EvaluateTrainedModelAndPrintMetrics(mlContext, bestRun.Model, bestRun.TrainerName, validationView);

            // The training view supplies the input schema stored alongside the model.
            SaveModel(mlContext, bestRun.Model, paths.ModelPath, trainingView);

            return result;
        }
        /// <summary>
        /// Creates a modifier with the default settings: a 300-second budget, "Area" as the label,
        /// a fixed column-role layout, and a trainer list that depends on <paramref name="forPrs"/>.
        /// </summary>
        /// <param name="paths">Dataset and model paths stored on the modifier.</param>
        /// <param name="forPrs">
        /// Stored as <c>ForPrs</c> and captured by the trainer setup. The column setup instead uses the
        /// flag it is invoked with.
        /// </param>
        public ExperimentModifier(DataFilePaths paths, bool forPrs)
        {
            // Replace the column roles returned by the InferColumns API with a fixed layout.
            // NOTE: depending on how the data changes over time this might need to get updated too.
            ColumnSetup = (columnInformation, includePrColumns) =>
            {
                columnInformation.CategoricalColumnNames.Clear();
                columnInformation.NumericColumnNames.Clear();
                columnInformation.IgnoredColumnNames.Clear();
                columnInformation.TextColumnNames.Clear();

                columnInformation.TextColumnNames.Add("Title");
                columnInformation.TextColumnNames.Add("Description");

                columnInformation.CategoricalColumnNames.Add("IssueAuthor");
                columnInformation.CategoricalColumnNames.Add("NumMentions");

                columnInformation.IgnoredColumnNames.Add("IsPR");
                columnInformation.IgnoredColumnNames.Add("UserMentions");

                if (!includePrColumns)
                {
                    return;
                }

                // File-related columns are only configured for pull requests.
                columnInformation.NumericColumnNames.Add("FileCount");

                columnInformation.CategoricalColumnNames.Add("Files");
                columnInformation.CategoricalColumnNames.Add("FolderNames");
                columnInformation.CategoricalColumnNames.Add("Folders");

                columnInformation.IgnoredColumnNames.Add("FileExtensions");
                columnInformation.IgnoredColumnNames.Add("Filenames");
            };

            // Both configurations train SdcaMaximumEntropy; the PR configuration also tries FastTreeOva.
            // (LinearSupportVectorMachinesOva and LightGbm were once candidates for issues and are disabled.)
            TrainerSetup = trainers =>
            {
                trainers.Clear();
                trainers.Add(MulticlassClassificationTrainer.SdcaMaximumEntropy);
                if (forPrs)
                {
                    trainers.Add(MulticlassClassificationTrainer.FastTreeOva);
                }
            };

            ExperimentTime  = 300;
            LabelColumnName = "Area";
            ForPrs          = forPrs;
            Paths           = paths;
        }
 /// <summary>
 /// Creates a modifier from explicitly supplied settings; every argument is stored as-is.
 /// </summary>
 /// <param name="forPrs">Stored as <c>ForPrs</c>.</param>
 /// <param name="experimentTime">Stored as <c>ExperimentTime</c>.</param>
 /// <param name="labelColumnName">Stored as <c>LabelColumnName</c>.</param>
 /// <param name="paths">Stored as <c>Paths</c>.</param>
 /// <param name="columnSetup">Stored as <c>ColumnSetup</c>.</param>
 /// <param name="trainerSetup">Stored as <c>TrainerSetup</c>.</param>
 public ExperimentModifier(
     bool forPrs,
     uint experimentTime,
     string labelColumnName,
     DataFilePaths paths,
     Action <ColumnInformation, bool> columnSetup,
     Action <ICollection <MulticlassClassificationTrainer> > trainerSetup)
 {
     // What data the experiment runs on.
     Paths           = paths;
     ForPrs          = forPrs;
     LabelColumnName = labelColumnName;

     // How the experiment is configured.
     ExperimentTime  = experimentTime;
     ColumnSetup     = columnSetup;
     TrainerSetup    = trainerSetup;
 }
예제 #8
0
        /// <summary>
        /// Loads the model at <c>files.FittedModelPath</c>, predicts every item in <c>files.TestPath</c>
        /// (as pull requests or as issues) and traces a summary:
        /// items handled correctly, missed opportunities, and confident mistakes grouped by
        /// predicted/actual pair in descending frequency.
        /// </summary>
        /// <param name="mlContext">ML.NET context used to load the model and create the prediction engine.</param>
        /// <param name="files">Paths of the fitted model and the test dataset.</param>
        /// <param name="forPrs">When true the test file is read as pull requests, otherwise as issues.</param>
        /// <param name="threshold">A prediction is "confident" when its highest score is at least this value.</param>
        public static void TestPrediction(MLContext mlContext, DataFilePaths files, bool forPrs, double threshold = 0.6)
        {
            var trainedModel = mlContext.Model.Load(files.FittedModelPath, out _);
            string       legend1 = $"(includes not labeling issues with confidence lower than threshold. (here {threshold * 100.0f:#,0.00}%))";
            const string Legend2 = "(includes items that could be labeled if threshold was lower.)";
            const string Legend3 = "(those incorrectly labeled)";

            // Materialized once: the results are enumerated several times below, and a deferred
            // Select would re-run PredictionEngine.Predict for every item on each enumeration.
            List<(string knownLabel, GitHubIssuePrediction predictedResult, string issueNumber)> predictions;

            if (forPrs)
            {
                var testData = GetPullRequests(mlContext, files.TestPath);
                Trace.WriteLine($"{Environment.NewLine}Number of PRs tested: {testData.Length}");
                var prEngine = mlContext.Model.CreatePredictionEngine <GitHubPullRequest, GitHubIssuePrediction>(trainedModel);
                predictions = testData
                              .Select(x => (
                                          knownLabel: x.Area,
                                          predictedResult: prEngine.Predict(x),
                                          issueNumber: x.ID.ToString()
                                          ))
                              .ToList();
            }
            else
            {
                var testData = GetIssues(mlContext, files.TestPath);
                Trace.WriteLine($"{Environment.NewLine}\tNumber of issues tested: {testData.Length}");
                var issueEngine = mlContext.Model.CreatePredictionEngine <GitHubIssue, GitHubIssuePrediction>(trainedModel);
                predictions = testData
                              .Select(x => (
                                          knownLabel: x.Area,
                                          predictedResult: issueEngine.Predict(x),
                                          issueNumber: x.ID.ToString()
                                          ))
                              .ToList();
            }

            var analysis = predictions
                           .Select(x =>
                           {
                               var maxScore = x.predictedResult.Score.Max();
                               var predictedArea = x.predictedResult.Area;
                               return (
                                   knownLabel: x.knownLabel,
                                   predictedArea: predictedArea,
                                   maxScore: maxScore,
                                   confidentInPrediction: maxScore >= threshold,
                                   // Static string.Equals tolerates a null label on either side.
                                   isCorrect: string.Equals(predictedArea, x.knownLabel, StringComparison.Ordinal),
                                   issueNumber: x.issueNumber);
                           })
                           .ToList();

            // "Handled correctly": confident and right, or not confident and would have been wrong
            // (declining to label was the right call).
            var countSuccess = analysis.Count(x =>
                                              (x.confidentInPrediction && x.isCorrect) ||
                                              (!x.confidentInPrediction && !x.isCorrect));

            // Right answers that were withheld because the score fell below the threshold.
            var missedOpportunity = analysis.Count(x => !x.confidentInPrediction && x.isCorrect);

            // Confident but wrong predictions, grouped by predicted/actual pair, most frequent first.
            var mistakes = analysis
                           .Where(x => x.confidentInPrediction && !x.isCorrect)
                           .Select(x => new
                           {
                               Pair                    = $"\tPredicted: {x.predictedArea}, Actual:{x.knownLabel}",
                               IssueNumber             = x.issueNumber,
                               MaxConfidencePercentage = x.maxScore * 100.0f,
                           })
                           .GroupBy(x => x.Pair)
                           .Select(g => new
                           {
                               Count             = g.Count(),
                               PredictedVsActual = g.Key,
                               Items             = g,
                           })
                           .OrderByDescending(g => g.Count);

            // Everything else is a confident mistake.
            int remaining = predictions.Count - countSuccess - missedOpportunity;

            Trace.WriteLine($"{Environment.NewLine}\thandled correctly: {countSuccess}{Environment.NewLine}\t{legend1}{Environment.NewLine}");
            Trace.WriteLine($"{Environment.NewLine}\tmissed: {missedOpportunity}{Environment.NewLine}\t{Legend2}{Environment.NewLine}");
            Trace.WriteLine($"{Environment.NewLine}\tremaining: {remaining}{Environment.NewLine}\t{Legend3}{Environment.NewLine}");
            foreach (var mismatch in mistakes)
            {
                Trace.WriteLine($"{mismatch.PredictedVsActual}, NumFound: {mismatch.Count}");
                var sampleIssues = string.Join(Environment.NewLine, mismatch.Items.Select(x => $"\t\tFor #{x.IssueNumber} was {x.MaxConfidencePercentage:#,0.00}% confident"));
                Trace.WriteLine($"{Environment.NewLine}{ sampleIssues }{Environment.NewLine}");
            }
        }
예제 #9
0
 /// <summary>
 /// Evaluates the fitted model on the test split via <c>MulticlassExperimentHelper.TestPrediction</c>,
 /// using this instance's ML context and the helper's default confidence threshold.
 /// </summary>
 /// <param name="files">Paths of the fitted model and the test dataset.</param>
 /// <param name="forPrs">When true the test data is treated as pull requests, otherwise as issues.</param>
 public void Test(DataFilePaths files, bool forPrs) =>
     MulticlassExperimentHelper.TestPrediction(_mLContext, files, forPrs);