/// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            // Values filled in by the command line handlers registered below.
            string dataPath = string.Empty;
            string modelPath = string.Empty;
            string outputPath = string.Empty;
            int itemLimit = 5;
            int poolSizeFloor = 5;

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--data",
                "FILE",
                "Dataset to make predictions for",
                path => dataPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--model",
                "FILE",
                "File with trained model",
                path => modelPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--predictions",
                "FILE",
                "File with generated predictions",
                path => outputPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--max-items",
                "NUM",
                "Maximum number of items to recommend; defaults to 5",
                count => itemLimit = count,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--min-pool-size",
                "NUM",
                "Minimum size of the recommendation pool for a single user; defaults to 5",
                size => poolSizeFloor = size,
                CommandLineParameterType.Optional);

            // Nothing is loaded or written unless the arguments parse.
            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            // The dataset is loaded before the model, as in the original flow.
            RecommenderDataset dataset = RecommenderDataset.Load(dataPath);
            var recommender = MatchboxRecommender.Load<RecommenderDataset, User, Item, DummyFeatureSource>(modelPath);

            var evaluationMapping = Mappings.StarRatingRecommender.ForEvaluation();
            var evaluator = new RecommenderEvaluator<RecommenderDataset, User, Item, int, int, Discrete>(evaluationMapping);

            // Recommend items per user and persist the result to the predictions file.
            IDictionary<User, IEnumerable<Item>> recommendations =
                evaluator.RecommendRatedItems(recommender, dataset, itemLimit, poolSizeFloor);
            RecommenderPersistenceUtils.SaveRecommendedItems(outputPath, recommendations);

            return true;
        }
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            // Paths supplied on the command line; all three are required.
            string testDataPath = string.Empty;
            string modelPath = string.Empty;
            string outputPath = string.Empty;

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--test-set",
                "FILE",
                "File with test data",
                path => testDataPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--model",
                "FILE",
                "File with a trained multi-class Bayes point machine model",
                path => modelPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--predictions",
                "FILE",
                "File to store predictions for the test data",
                path => outputPath = path,
                CommandLineParameterType.Required);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            // Load the test data and report on it before the model is touched.
            var testData = ClassifierPersistenceUtils.LoadLabeledFeatureValues(testDataPath);
            BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(testData);

            var model = BayesPointMachineClassifier.LoadMulticlassClassifier<
                IList<LabeledFeatureValues>,
                LabeledFeatureValues,
                IList<LabelDistribution>,
                string,
                IDictionary<string, double>>(modelPath);

            // Predict a label distribution for the test data and store the result.
            var labelDistributions = model.PredictDistribution(testData);
            ClassifierPersistenceUtils.SaveLabelDistributions(outputPath, labelDistributions);

            return true;
        }
Exemple #3
0
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            // Paths supplied on the command line; all three are required.
            string groundTruthPath = string.Empty;
            string predictionsPath = string.Empty;
            string reportPath = string.Empty;

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--test-data",
                "FILE",
                "Test dataset used to obtain ground truth",
                path => groundTruthPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--predictions",
                "FILE",
                "Predictions to evaluate",
                path => predictionsPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--report",
                "FILE",
                "Evaluation report file",
                path => reportPath = path,
                CommandLineParameterType.Required);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            RecommenderDataset groundTruth = RecommenderDataset.Load(groundTruthPath);
            IDictionary<User, IDictionary<Item, int>> predictedRatings =
                RecommenderPersistenceUtils.LoadPredictedRatings(predictionsPath);

            var mapping = Mappings.StarRatingRecommender.ForEvaluation();
            var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(mapping);

            using (var report = new StreamWriter(reportPath))
            {
                // Each metric is computed right before its line is written, so a failure
                // in the second metric still leaves the first line in the report.
                var absoluteError = evaluator.RatingPredictionMetric(groundTruth, predictedRatings, Metrics.AbsoluteError);
                report.WriteLine("Mean absolute error: {0:0.000}", absoluteError);

                var squaredError = evaluator.RatingPredictionMetric(groundTruth, predictedRatings, Metrics.SquaredError);
                report.WriteLine("Root mean squared error: {0:0.000}", Math.Sqrt(squaredError));
            }

            return true;
        }
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            // Paths supplied on the command line; both are required.
            string inputDatasetFile  = string.Empty;
            string outputDatasetFile = string.Empty;

            var parser = new CommandLineParser();

            parser.RegisterParameterHandler("--input-data", "FILE", "Input dataset, treated as if all the ratings are positive", v => inputDatasetFile = v, CommandLineParameterType.Required);

            // Help text typo fixed: "posisitve" -> "positive".
            parser.RegisterParameterHandler("--output-data", "FILE", "Output dataset with both positive and negative data", v => outputDatasetFile = v, CommandLineParameterType.Required);

            // Nothing is loaded or written unless the arguments parse.
            if (!parser.TryParse(args, usagePrefix))
            {
                return false;
            }

            var generatorMapping = Mappings.StarRatingRecommender.WithGeneratedNegativeData();

            // Build the output dataset from the instances and rating info that the
            // generator mapping exposes for the input dataset.
            var inputDataset  = RecommenderDataset.Load(inputDatasetFile);
            var outputDataset = new RecommenderDataset(
                generatorMapping.GetInstances(inputDataset).Select(i => new RatedUserItem(i.User, i.Item, i.Rating)),
                generatorMapping.GetRatingInfo(inputDataset));

            outputDataset.Save(outputDatasetFile);

            return true;
        }
Exemple #5
0
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            // Required paths.
            string sourcePath = string.Empty;
            string trainPath = string.Empty;
            string testPath = string.Empty;

            // Optional split settings with their defaults.
            double trainingOnlyUsers = 0.5;
            double testUserTrainingRatings = 0.25;
            double coldUsers = 0;
            double coldItems = 0;
            double ignoredUsers = 0;
            double ignoredItems = 0;
            bool dropOccasionalColdItems = false;

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--input-data",
                "FILE",
                "Dataset to split",
                path => sourcePath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--output-data-train",
                "FILE",
                "Training part of the split dataset",
                path => trainPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--output-data-test",
                "FILE",
                "Test part of the split dataset",
                path => testPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--training-users",
                "NUM",
                "Fraction of training-only users; defaults to 0.5",
                (double fraction) => trainingOnlyUsers = fraction,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--test-user-training-ratings",
                "NUM",
                "Fraction of test user ratings for training; defaults to 0.25",
                (double fraction) => testUserTrainingRatings = fraction,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--cold-users",
                "NUM",
                "Fraction of cold (test-only) users; defaults to 0",
                (double fraction) => coldUsers = fraction,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--cold-items",
                "NUM",
                "Fraction of cold (test-only) items; defaults to 0",
                (double fraction) => coldItems = fraction,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--ignored-users",
                "NUM",
                "Fraction of ignored users; defaults to 0",
                (double fraction) => ignoredUsers = fraction,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--ignored-items",
                "NUM",
                "Fraction of ignored items; defaults to 0",
                (double fraction) => ignoredItems = fraction,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--remove-occasional-cold-items",
                "Remove occasionally produced cold items",
                () => dropOccasionalColdItems = true);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            var splitter = Mappings.StarRatingRecommender.SplitToTrainTest(
                trainingOnlyUsers,
                testUserTrainingRatings,
                coldUsers,
                coldItems,
                ignoredUsers,
                ignoredItems,
                dropOccasionalColdItems);

            var source = RecommenderDataset.Load(sourcePath);

            // The training part is built and saved before the test part is built.
            var trainingPart = new RecommenderDataset(
                splitter.GetInstances(SplitInstanceSource.Training(source)),
                source.StarRatingInfo);
            trainingPart.Save(trainPath);

            var testPart = new RecommenderDataset(
                splitter.GetInstances(SplitInstanceSource.Test(source)),
                source.StarRatingInfo);
            testPart.Save(testPath);

            return true;
        }
Exemple #6
0
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            string trainingDataPath = string.Empty;
            string modelPath = string.Empty;

            // Training settings start from the library defaults.
            int iterations = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
            int batches = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
            bool wantEvidence = BayesPointMachineClassifierTrainingSettings.ComputeModelEvidenceDefault;

            // Help strings embed the default values, so they are built before parsing.
            string iterationsHelp = "Number of training algorithm iterations (defaults to " + iterations + ")";
            string batchesHelp = "Number of batches to split the training data into (defaults to " + batches + ")";
            string evidenceHelp = "Compute model evidence (defaults to " + wantEvidence + ")";

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--training-set",
                "FILE",
                "File with training data",
                path => trainingDataPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--model",
                "FILE",
                "File to store the trained binary Bayes point machine model",
                path => modelPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--iterations",
                "NUM",
                iterationsHelp,
                count => iterations = count,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--batches",
                "NUM",
                batchesHelp,
                count => batches = count,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--compute-evidence",
                evidenceHelp,
                () => wantEvidence = true);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            var trainingData = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingDataPath);
            BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingData);

            // With no training instances there is no feature set to take; pass null instead.
            var features = trainingData.Count == 0 ? null : trainingData.First().FeatureSet;
            var model = BayesPointMachineClassifier.CreateBinaryClassifier(new ClassifierMapping(features));

            var trainingSettings = model.Settings.Training;
            trainingSettings.IterationCount = iterations;
            trainingSettings.BatchCount = batches;
            trainingSettings.ComputeModelEvidence = wantEvidence;

            model.Train(trainingData);

            // Evidence is printed only when the classifier's settings report it as enabled.
            if (model.Settings.Training.ComputeModelEvidence)
            {
                Console.WriteLine("Log evidence = {0,10:0.0000}", model.LogModelEvidence);
            }

            model.Save(modelPath);

            return true;
        }
Exemple #7
0
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            // Paths supplied on the command line; both are required.
            string modelPath = string.Empty;
            string samplesPath = string.Empty;

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--model",
                "FILE",
                "File with a trained binary Bayes point machine model",
                path => modelPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--samples",
                "FILE",
                "File to store samples of the weights",
                path => samplesPath = path,
                CommandLineParameterType.Required);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            var model = BayesPointMachineClassifier.LoadBinaryClassifier<
                IList<LabeledFeatureValues>,
                LabeledFeatureValues,
                IList<LabelDistribution>,
                string,
                IDictionary<string, double>>(modelPath);

            // The sampling itself is delegated to the shared module utilities.
            BayesPointMachineClassifierModuleUtilities.SampleWeights(model, samplesPath);

            return true;
        }
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            // Paths supplied on the command line; all three are required.
            string dataPath = string.Empty;
            string modelPath = string.Empty;
            string outputPath = string.Empty;

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--data",
                "FILE",
                "Dataset to make predictions for",
                path => dataPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--model",
                "FILE",
                "File with trained model",
                path => modelPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--predictions",
                "FILE",
                "File with generated predictions",
                path => outputPath = path,
                CommandLineParameterType.Required);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            // The dataset is loaded before the model, as in the original flow.
            RecommenderDataset dataset = RecommenderDataset.Load(dataPath);
            var recommender = MatchboxRecommender.Load<RecommenderDataset, User, Item, DummyFeatureSource>(modelPath);

            // Predict ratings for the dataset and persist them to the predictions file.
            IDictionary<User, IDictionary<Item, int>> predictedRatings = recommender.Predict(dataset);
            RecommenderPersistenceUtils.SavePredictedRatings(outputPath, predictedRatings);

            return true;
        }
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            // Paths supplied on the command line; all three are required.
            string groundTruthPath = string.Empty;
            string predictionsPath = string.Empty;
            string reportPath = string.Empty;

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--test-data",
                "FILE",
                "Test dataset used to obtain ground truth",
                path => groundTruthPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--predictions",
                "FILE",
                "Predictions to evaluate",
                path => predictionsPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--report",
                "FILE",
                "Evaluation report file",
                path => reportPath = path,
                CommandLineParameterType.Required);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            RecommenderDataset groundTruth = RecommenderDataset.Load(groundTruthPath);

            // The lowest star rating is used below to shift ratings so the minimum maps to 1.
            int lowestRating = Mappings.StarRatingRecommender.GetRatingInfo(groundTruth).MinStarRating;

            IDictionary<User, IEnumerable<Item>> recommendations =
                RecommenderPersistenceUtils.LoadRecommendedItems(predictionsPath);

            var mapping = Mappings.StarRatingRecommender.ForEvaluation();
            var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(mapping);

            using (var report = new StreamWriter(reportPath))
            {
                var ndcg = evaluator.ItemRecommendationMetric(
                    groundTruth,
                    recommendations,
                    Metrics.Ndcg,
                    r => Convert.ToDouble(r) - lowestRating + 1);

                report.WriteLine("NDCG: {0:0.000}", ndcg);
            }

            return true;
        }
Exemple #10
0
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            string groundTruthPath = string.Empty;
            string predictionsPath = string.Empty;
            string reportPath = string.Empty;
            int commonRatingFloor = 5;

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--test-data",
                "FILE",
                "Test dataset used to obtain ground truth",
                path => groundTruthPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--predictions",
                "FILE",
                "Predictions to evaluate",
                path => predictionsPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--report",
                "FILE",
                "Evaluation report file",
                path => reportPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--min-common-items",
                "NUM",
                "Minimum number of users that the query item and the related item should have been rated by in common; defaults to 5",
                count => commonRatingFloor = count,
                CommandLineParameterType.Optional);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            RecommenderDataset groundTruth = RecommenderDataset.Load(groundTruthPath);
            IDictionary<Item, IEnumerable<Item>> related =
                RecommenderPersistenceUtils.LoadRelatedItems(predictionsPath);

            var mapping = Mappings.StarRatingRecommender.ForEvaluation();
            var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(mapping);

            using (var report = new StreamWriter(reportPath))
            {
                // Each metric is computed right before its line is written, so a failure
                // in the second metric still leaves the first line in the report.
                var manhattanNdcg = evaluator.RelatedItemsMetric(
                    groundTruth, related, commonRatingFloor, Metrics.Ndcg, Metrics.NormalizedManhattanSimilarity);
                report.WriteLine("L1 Sim NDCG: {0:0.000}", manhattanNdcg);

                var euclideanNdcg = evaluator.RelatedItemsMetric(
                    groundTruth, related, commonRatingFloor, Metrics.Ndcg, Metrics.NormalizedEuclideanSimilarity);
                report.WriteLine("L2 Sim NDCG: {0:0.000}", euclideanNdcg);
            }

            return true;
        }
Exemple #11
0
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            string trainingDataPath = string.Empty;
            string sourceModelPath = string.Empty;
            string targetModelPath = string.Empty;

            // Training settings start from the library defaults.
            int iterations = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
            int batches = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;

            // Help strings embed the default values, so they are built before parsing.
            string iterationsHelp = "Number of training algorithm iterations (defaults to " + iterations + ")";
            string batchesHelp = "Number of batches to split the training data into (defaults to " + batches + ")";

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--training-set",
                "FILE",
                "File with training data",
                path => trainingDataPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--input-model",
                "FILE",
                "File with the trained multi-class Bayes point machine model",
                path => sourceModelPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--model",
                "FILE",
                "File to store the incrementally trained multi-class Bayes point machine model",
                path => targetModelPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--iterations",
                "NUM",
                iterationsHelp,
                count => iterations = count,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--batches",
                "NUM",
                batchesHelp,
                count => batches = count,
                CommandLineParameterType.Optional);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            // Load the training data and report on it before the model is touched.
            var trainingData = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingDataPath);
            BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingData);

            var model = BayesPointMachineClassifier.LoadMulticlassClassifier<
                IList<LabeledFeatureValues>,
                LabeledFeatureValues,
                IList<LabelDistribution>,
                string,
                IDictionary<string, double>>(sourceModelPath);

            var trainingSettings = model.Settings.Training;
            trainingSettings.IterationCount = iterations;
            trainingSettings.BatchCount = batches;

            // Continue training the loaded model, then store it under the output path.
            model.TrainIncremental(trainingData);
            model.Save(targetModelPath);

            return true;
        }
Exemple #12
0
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false otherwise.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            // Required inputs.
            string truthPath = string.Empty;
            string predictionsPath = string.Empty;

            // Optional outputs; an empty path means the corresponding output is skipped.
            string reportPath = string.Empty;
            string calibrationPath = string.Empty;
            string rocPath = string.Empty;
            string precisionRecallPath = string.Empty;
            string positiveLabel = string.Empty;

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--ground-truth",
                "FILE",
                "File with ground truth labels",
                path => truthPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--predictions",
                "FILE",
                "File with label predictions",
                path => predictionsPath = path,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--report",
                "FILE",
                "File to store the evaluation report",
                path => reportPath = path,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--calibration-curve",
                "FILE",
                "File to store the empirical calibration curve",
                path => calibrationPath = path,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--roc-curve",
                "FILE",
                "File to store the receiver operating characteristic curve",
                path => rocPath = path,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--precision-recall-curve",
                "FILE",
                "File to store the precision-recall curve",
                path => precisionRecallPath = path,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--positive-class",
                "STRING",
                "Label of the positive class to use in curves",
                label => positiveLabel = label,
                CommandLineParameterType.Optional);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            // Ground truth comes first: its label set is needed to read the predictions.
            var truth = ClassifierPersistenceUtils.LoadLabeledFeatureValues(truthPath);
            var truthLabelSet = truth.First().LabelDistribution.LabelSet;
            var predicted = ClassifierPersistenceUtils.LoadLabelDistributions(predictionsPath, truthLabelSet);

            // Fewer than two class labels is rejected as a file format error.
            if (predicted.First().LabelSet.Count < 2)
            {
                throw new InvalidFileFormatException("Ground truth and predictions must contain at least two distinct class labels.");
            }

            // Per-instance distributions (as dictionaries) and their modes as point estimates.
            var distributions = predicted.Select(d => d.ToDictionary()).ToList();
            var pointEstimates = predicted.Select(d => d.GetMode()).ToList();

            var mapping = Mappings.Classifier.ForEvaluation();
            var evaluator = new ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string>(mapping);

            // Evaluation report (optional).
            if (!string.IsNullOrEmpty(reportPath))
            {
                using (var report = new StreamWriter(reportPath))
                {
                    this.WriteReportHeader(report, truthPath, predictionsPath);
                    this.WriteReport(report, evaluator, truth, distributions, pointEstimates);
                }
            }

            // The positive class label is checked after the report is written and
            // regardless of whether any curve was requested.
            positiveLabel = this.CheckPositiveClassLabel(truth, positiveLabel);

            // Empirical probability calibration curve (optional).
            if (!string.IsNullOrEmpty(calibrationPath))
            {
                this.WriteCalibrationCurve(calibrationPath, evaluator, truth, distributions, positiveLabel);
            }

            // Precision-recall curve (optional).
            if (!string.IsNullOrEmpty(precisionRecallPath))
            {
                this.WritePrecisionRecallCurve(precisionRecallPath, evaluator, truth, distributions, positiveLabel);
            }

            // Receiver operating characteristic curve (optional).
            if (!string.IsNullOrEmpty(rocPath))
            {
                this.WriteRocCurve(rocPath, evaluator, truth, distributions, positiveLabel);
            }

            return true;
        }
Exemple #13
0
        /// <summary>
        /// Runs the module: cross-validates a binary Bayes point machine classifier on a labeled
        /// data set and reports performance metrics for every fold.
        /// </summary>
        /// <remarks>
        /// The data set is shuffled once using a fixed random seed and then cut into
        /// <c>--folds</c> contiguous validation ranges of equal size; the last fold additionally
        /// receives the instances left over by the integer division. For each fold a fresh
        /// classifier is trained on all instances outside the validation range, and accuracy,
        /// mean negative log probability, AUC, (optionally) log model evidence, the number of
        /// completed iterations and the training time in milliseconds are recorded.
        /// All metrics collected so far are handed to
        /// <c>BayesPointMachineClassifierModuleUtilities.SavePerformanceMetrics</c> after every fold.
        /// </remarks>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>
        /// True if the run was successful; false if the command line could not be parsed or the
        /// requested number of folds is not usable with the loaded data set.
        /// </returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            string dataSetFile = string.Empty;
            string resultsFile = string.Empty;
            int crossValidationFoldCount = 5;
            int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
            int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
            bool computeModelEvidence = BayesPointMachineClassifierTrainingSettings.ComputeModelEvidenceDefault;

            var parser = new CommandLineParser();

            // The help texts embed the defaults, so the handlers must be registered before any value is overwritten.
            parser.RegisterParameterHandler("--data-set", "FILE", "File with training data", v => dataSetFile = v, CommandLineParameterType.Required);
            parser.RegisterParameterHandler("--results", "FILE", "File with cross-validation results", v => resultsFile = v, CommandLineParameterType.Required);
            parser.RegisterParameterHandler("--folds", "NUM", "Number of cross-validation folds (defaults to " + crossValidationFoldCount + ")", v => crossValidationFoldCount = v, CommandLineParameterType.Optional);
            parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
            parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
            parser.RegisterParameterHandler("--compute-evidence", "Compute model evidence (defaults to " + computeModelEvidence + ")", () => computeModelEvidence = true);

            if (!parser.TryParse(args, usagePrefix))
            {
                return false;
            }

            // Reject unusable fold counts before doing any work: zero would cause a division by zero
            // when the validation set size is computed, and a negative value would silently skip the
            // whole cross-validation loop while still reporting success.
            // NOTE(review): a single fold is accepted here for backward compatibility, but it puts every
            // instance into the validation set and leaves the training set empty - confirm whether
            // the minimum should be raised to 2.
            if (crossValidationFoldCount < 1)
            {
                Console.WriteLine("The number of cross-validation folds must be at least 1, but was {0}.", crossValidationFoldCount);
                return false;
            }

            // Load and shuffle data
            var dataSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(dataSetFile);

            BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(dataSet);

            // With more folds than instances the per-fold validation size rounds down to zero, so all but
            // the last fold would be evaluated on an empty validation set (and the per-instance averages
            // below would divide by zero). This also covers an empty data set.
            if (crossValidationFoldCount > dataSet.Count)
            {
                Console.WriteLine(
                    "The number of cross-validation folds ({0}) must not exceed the number of instances in the data set ({1}).",
                    crossValidationFoldCount,
                    dataSet.Count);
                return false;
            }

            // Fixed seed - presumably so that the assignment of instances to folds is reproducible across runs.
            Rand.Restart(562);
            Rand.Shuffle(dataSet);

            // Create evaluator
            var evaluatorMapping = Mappings.Classifier.ForEvaluation();
            var evaluator = new ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string>(evaluatorMapping);

            // Create performance metrics (one entry per completed fold)
            var accuracy = new List<double>();
            var negativeLogProbability = new List<double>();
            var auc = new List<double>();
            var evidence = new List<double>();
            var iterationCounts = new List<double>();
            var trainingTime = new List<double>();

            // Run cross-validation
            int validationSetSize = dataSet.Count / crossValidationFoldCount;

            Console.WriteLine("Running {0}-fold cross-validation on {1}", crossValidationFoldCount, dataSetFile);

            // TODO: Use chained mapping to implement cross-validation
            for (int fold = 0; fold < crossValidationFoldCount; fold++)
            {
                // Construct training and validation sets for fold.
                // The validation range is [validationSetStart, validationSetEnd); the last fold is
                // extended to the end of the data set so that the division remainder is not dropped.
                int validationSetStart = fold * validationSetSize;
                int validationSetEnd = (fold + 1 == crossValidationFoldCount)
                                           ? dataSet.Count
                                           : (fold + 1) * validationSetSize;

                var trainingSet = new List<LabeledFeatureValues>();
                var validationSet = new List<LabeledFeatureValues>();

                for (int instance = 0; instance < dataSet.Count; instance++)
                {
                    bool isValidationInstance = validationSetStart <= instance && instance < validationSetEnd;
                    if (isValidationInstance)
                    {
                        validationSet.Add(dataSet[instance]);
                    }
                    else
                    {
                        trainingSet.Add(dataSet[instance]);
                    }
                }

                // Print info (the displayed range is inclusive, hence the "- 1")
                Console.WriteLine("   Fold {0} [validation set instances {1} - {2}]", fold + 1, validationSetStart, validationSetEnd - 1);

                // Create classifier
                var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(Mappings.Classifier);
                classifier.Settings.Training.IterationCount = iterationCount;
                classifier.Settings.Training.BatchCount = batchCount;
                classifier.Settings.Training.ComputeModelEvidence = computeModelEvidence;

                // Track the number of iterations actually completed, as reported by the classifier.
                int currentIterationCount = 0;
                classifier.IterationChanged += (sender, eventArgs) => { currentIterationCount = eventArgs.CompletedIterationCount; };

                // Train classifier (only the call to Train is timed)
                var stopWatch = Stopwatch.StartNew();
                classifier.Train(trainingSet);
                stopWatch.Stop();

                // Produce predictions; the predicted label is the one with the highest probability
                // (on ties the label encountered first is kept).
                var predictions = classifier.PredictDistribution(validationSet).ToList();
                var predictedLabels = predictions.Select(
                    prediction => prediction.Aggregate((aggregate, next) => next.Value > aggregate.Value ? next : aggregate).Key).ToList();

                // Iteration count
                iterationCounts.Add(currentIterationCount);

                // Training time
                trainingTime.Add(stopWatch.ElapsedMilliseconds);

                // Compute accuracy
                accuracy.Add(1 - (evaluator.Evaluate(validationSet, predictedLabels, Metrics.ZeroOneError) / predictions.Count));

                // Compute mean negative log probability
                negativeLogProbability.Add(evaluator.Evaluate(validationSet, predictions, Metrics.NegativeLogProbability) / predictions.Count);

                // Compute M-measure (averaged pairwise AUC)
                auc.Add(evaluator.AreaUnderRocCurve(validationSet, predictions));

                // Compute log evidence if desired (NaN keeps the list aligned with the fold index otherwise)
                evidence.Add(computeModelEvidence ? classifier.LogModelEvidence : double.NaN);

                // Report the metrics of this fold
                Console.WriteLine(
                    "      Accuracy = {0,5:0.0000}   NegLogProb = {1,5:0.0000}   AUC = {2,5:0.0000}{3}   Iterations = {4}   Training time = {5}",
                    accuracy[fold],
                    negativeLogProbability[fold],
                    auc[fold],
                    computeModelEvidence ? string.Format("   Log evidence = {0,5:0.0000}", evidence[fold]) : string.Empty,
                    iterationCounts[fold],
                    BayesPointMachineClassifierModuleUtilities.FormatElapsedTime(trainingTime[fold]));

                // Persist performance metrics after every fold, passing everything collected so far
                BayesPointMachineClassifierModuleUtilities.SavePerformanceMetrics(
                    resultsFile, accuracy, negativeLogProbability, auc, evidence, iterationCounts, trainingTime);
            }

            return true;
        }