예제 #1
0
파일: BPM.cs 프로젝트: mesgarpour/ERMER
    /// <summary>
    /// Creates an untrained Bayes point machine wrapper for a two-label problem.
    /// </summary>
    /// <param name="labels">The two class labels (checked in debug builds only).</param>
    /// <param name="sparsityApproxThresh">
    /// Presumably a sparsity-approximation threshold; must be non-negative. It is only
    /// validated here and not otherwise used by this constructor.
    /// </param>
    public BPM(
        string[] labels,
        double sparsityApproxThresh)
    {
        // Argument checks (active in debug builds only)
        Debug.Assert(labels != null, "The labels must not be null.");
        Debug.Assert(labels.Length == 2, "The labels must have two possible values.");
        Debug.Assert(sparsityApproxThresh >= 0, "The sparsityApproxThresh must be greater than or equal to zero.");

        _validate = new Validate();

        // Inference state: no engines or distribution names yet; the prior and posterior
        // weight slots are registered with null variables.
        _engine = new Dictionary<DistributionType, InferenceEngine>();
        _w = new Dictionary<DistributionType, Variable<Vector>>
        {
            { DistributionType.Prior, null },
            { DistributionType.Posterior, null }
        };
        _d = new Dictionary<DistributionType, DistributionName>();

        // Prediction results start out empty
        _yPredicDistrib = Enumerable.Empty<IDictionary<string, double>>();
        _yPredicLabel = new string[0];

        // Label mapping, and an evaluator built from its evaluation view
        _mapping = new GenericClassifierMapping(labels);
        // TODO: sparsityApproxThresh is validated above but not yet used.
        _evaluator = new ClassifierEvaluator<IList<Vector>, int, IList<string>, string>(
            _mapping.ForEvaluation());

        // Empty dataset name and zero counts until data is supplied
        _availableDatasetName = new DatasetName();
        _numObservations = 0;
        _numFeatures = 0;
    }
예제 #2
0
    /// <summary>
    /// Creates a binary Bayes point machine classifier, together with its evaluator,
    /// from a mapping over the two supplied labels.
    /// </summary>
    /// <param name="labels">The two class labels (checked in debug builds only).</param>
    public BPMMapped(
        string[] labels)
    {
        // Debug-only argument checks: the binary classifier needs exactly two labels
        Debug.Assert(labels != null, "The labels must not be null.");
        Debug.Assert(labels.Length == 2, "The labels must have two possible values.");

        _validate = new Validate();

        // The same mapping drives both the classifier and its evaluator
        _mapping = new GenericClassifierMapping(labels);
        _classifier = BayesPointMachineClassifier.CreateBinaryClassifier(_mapping);
        _evaluator = new ClassifierEvaluator<IList<Vector>, int, IList<string>, string>(
            _mapping.ForEvaluation());

        // Empty dataset name and zero counts until data is supplied
        _availableDatasetName = new DatasetName();
        _numObservations = 0;
        _numFeatures = 0;
    }
예제 #3
0
    /// <summary>
    /// Runs k-fold cross-validation of a binary Bayes point machine classifier, printing the
    /// metrics of every fold and persisting the accumulated metrics to a report file.
    /// </summary>
    /// <param name="x">Feature vectors, one per instance. The array is not modified.</param>
    /// <param name="y">Labels aligned with <paramref name="x"/> by index. The list is not modified.</param>
    /// <param name="mapping">The classifier mapping used to create each fold's classifier and the evaluator.</param>
    /// <param name="reportFileName">The file that receives the accumulated performance metrics after every fold.</param>
    /// <param name="crossValidationFoldCount">
    /// The number of folds. Every fold must get at least one validation instance and leave at least
    /// one training instance. Otherwise "Invalid number of folds" is printed, a key press is awaited
    /// and the process exits with code 1.
    /// </param>
    /// <param name="iterationCount">The number of training iterations per fold.</param>
    /// <param name="computeModelEvidence">Whether to compute and report the log model evidence.</param>
    /// <param name="batchCount">The training batch count.</param>
    /// <remarks>
    /// Adapted from MicrosoftResearch.Infer.Learners.
    /// Instances are shuffled once before being split into contiguous folds, with features and
    /// labels permuted together. The last fold also takes the remainder instances.
    /// </remarks>
    public CrossValidateMapped(
        Vector[] x,
        IList <string> y,
        GenericClassifierMapping mapping,
        string reportFileName,
        int crossValidationFoldCount, //folds
        int iterationCount,
        bool computeModelEvidence,
        int batchCount)
    {
        Debug.Assert(x != null, "The feature vector must not be null.");
        Debug.Assert(y != null, "The target variable must not be null.");
        Debug.Assert(x.Length == y.Count, "The feature vectors and the target variable must have the same length.");
        Debug.Assert(mapping != null, "The classifier map must not be null.");
        Debug.Assert(!string.IsNullOrEmpty(reportFileName), "The report file name must not be null/empty.");
        Debug.Assert(crossValidationFoldCount > 1, "The cross-validation fold count must be greater than one.");
        Debug.Assert(iterationCount > 0, "The iteration count must be greater than zero.");
        Debug.Assert(batchCount > 0, "The batch count must be greater than zero.");

        // Shuffle dataset. Features and labels must be permuted together: shuffling x on its
        // own would detach every feature vector from its label in y.
        Vector[]       features;
        IList <string> labels;
        ShuffleInstances(x, y, out features, out labels);

        // Create evaluator
        var evaluatorMapping = mapping.ForEvaluation();
        var evaluator        = new ClassifierEvaluator <
            IList <Vector>,         // the type of the instance source,
            int,                    // the type of an instance
            IList <string>,         // the type of the label source
            string>(                // the type of a label.
            evaluatorMapping);

        // Create performance metrics (one entry per completed fold)
        var accuracy               = new List <double>();
        var negativeLogProbability = new List <double>();
        var auc                    = new List <double>();
        var evidence               = new List <double>();
        var iterationCounts        = new List <double>();
        var trainingTime           = new List <double>();

        // Fold sizes. A non-positive fold count is routed to the "invalid" branch below,
        // instead of throwing DivideByZeroException (0) or silently running no folds (< 0).
        int validationSetSize = crossValidationFoldCount > 0 ? features.Length / crossValidationFoldCount : 0;
        int trainingSetSize   = features.Length - validationSetSize;

        Console.WriteLine(
            "Running {0}-fold cross-validation", crossValidationFoldCount);

        if (validationSetSize == 0 || trainingSetSize == 0)
        {
            Console.WriteLine("Invalid number of folds");
            Console.ReadKey();
            System.Environment.Exit(1);
        }

        for (int fold = 0; fold < crossValidationFoldCount; fold++)
        {
            // Construct training and validation sets for fold; the last fold takes the remainder
            int validationSetStart = fold * validationSetSize;
            int validationSetEnd   = (fold + 1 == crossValidationFoldCount)
                                       ? features.Length
                                       : (fold + 1) * validationSetSize;

            int validationFoldSetSize = validationSetEnd - validationSetStart;
            int trainingFoldSetSize   = features.Length - validationFoldSetSize;

            Vector[]       trainingSet         = new Vector[trainingFoldSetSize];
            Vector[]       validationSet       = new Vector[validationFoldSetSize];
            IList <string> trainingSetLabels   = new List <string>(trainingFoldSetSize);
            IList <string> validationSetLabels = new List <string>(validationFoldSetSize);

            for (int instance = 0, iv = 0, it = 0; instance < features.Length; instance++)
            {
                if (validationSetStart <= instance && instance < validationSetEnd)
                {
                    validationSet[iv++] = features[instance];
                    validationSetLabels.Add(labels[instance]);
                }
                else
                {
                    trainingSet[it++] = features[instance];
                    trainingSetLabels.Add(labels[instance]);
                }
            }

            // Print info
            Console.WriteLine("   Fold {0} [validation set instances {1} - {2}]", fold + 1, validationSetStart, validationSetEnd - 1);

            // Create classifier
            var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(mapping);
            classifier.Settings.Training.IterationCount       = iterationCount;
            classifier.Settings.Training.BatchCount           = batchCount;
            classifier.Settings.Training.ComputeModelEvidence = computeModelEvidence;

            int currentIterationCount = 0;
            classifier.IterationChanged += (sender, eventArgs) => { currentIterationCount = eventArgs.CompletedIterationCount; };

            // Train classifier
            var stopWatch = new Stopwatch();
            stopWatch.Start();
            classifier.Train(trainingSet, trainingSetLabels);
            stopWatch.Stop();

            // Produce predictions. The distributions are materialized once because several
            // metrics below consume them; re-enumerating a lazy sequence could repeat the
            // prediction work.
            IEnumerable <IDictionary <string, double> > predictions =
                classifier.PredictDistribution(validationSet).ToList();
            var predictedLabels = classifier.Predict(validationSet);
            int predictionCount = predictions.Count();

            // Iteration count
            iterationCounts.Add(currentIterationCount);

            // Training time
            trainingTime.Add(stopWatch.ElapsedMilliseconds);

            // Compute accuracy
            accuracy.Add(1 - (evaluator.Evaluate(validationSet, validationSetLabels, predictedLabels, Metrics.ZeroOneError) / predictionCount));

            // Compute mean negative log probability
            negativeLogProbability.Add(evaluator.Evaluate(validationSet, validationSetLabels, predictions, Metrics.NegativeLogProbability) / predictionCount);

            // Compute M-measure (averaged pairwise AUC)
            auc.Add(evaluator.AreaUnderRocCurve(validationSet, validationSetLabels, predictions));

            // Compute log evidence if desired
            evidence.Add(computeModelEvidence ? classifier.LogModelEvidence : double.NaN);

            // Print this fold's metrics, then persist everything accumulated so far
            Console.WriteLine(
                "      Accuracy = {0,5:0.0000}   NegLogProb = {1,5:0.0000}   AUC = {2,5:0.0000}{3}   Iterations = {4}   Training time = {5}",
                accuracy[fold],
                negativeLogProbability[fold],
                auc[fold],
                computeModelEvidence ? string.Format("   Log evidence = {0,5:0.0000}", evidence[fold]) : string.Empty,
                iterationCounts[fold],
                FormatElapsedTime(trainingTime[fold]));

            SavePerformanceMetrics(
                reportFileName, accuracy, negativeLogProbability, auc, evidence, iterationCounts, trainingTime);
        }
    }

    /// <summary>
    /// Produces copies of <paramref name="x"/> and <paramref name="y"/> reordered by one shared,
    /// uniformly random permutation (Fisher-Yates), so that x[i] keeps its label y[i].
    /// The inputs are left untouched.
    /// </summary>
    /// <param name="x">Feature vectors, one per instance.</param>
    /// <param name="y">Labels aligned with <paramref name="x"/> by index; must have at least x.Length entries.</param>
    /// <param name="shuffledX">Receives the permuted feature vectors.</param>
    /// <param name="shuffledY">Receives the labels permuted in the same order.</param>
    private static void ShuffleInstances(
        Vector[] x,
        IList <string> y,
        out Vector[] shuffledX,
        out IList <string> shuffledY)
    {
        var random = new Random();

        int[] order = new int[x.Length];
        for (int i = 0; i < order.Length; i++)
        {
            order[i] = i;
        }

        for (int i = order.Length - 1; i > 0; i--)
        {
            int j     = random.Next(i + 1);
            int swap  = order[i];
            order[i]  = order[j];
            order[j]  = swap;
        }

        shuffledX = new Vector[x.Length];
        var labels = new List <string>(x.Length);
        for (int i = 0; i < order.Length; i++)
        {
            shuffledX[i] = x[order[i]];
            labels.Add(y[order[i]]);
        }

        shuffledY = labels;
    }
예제 #4
0
    /// <summary>
    /// Trains a binary Bayes point machine classifier on the specified data set. After every
    /// training iteration it traces the maximum change in the weight posteriors, both to the
    /// console and to a report file.
    /// </summary>
    /// <param name="x">Feature vectors, one per instance.</param>
    /// <param name="y">Labels aligned with <paramref name="x"/> by index.</param>
    /// <param name="mapping">The classifier mapping used to create the classifier.</param>
    /// <param name="outputModelFileName">The name of the file to store the trained Bayes point machine model (skipped when null/empty).</param>
    /// <param name="reportFileName">The name of the file to store the maximum parameter differences (skipped when null/empty; overwritten otherwise).</param>
    /// <param name="iterationCount">The number of training iterations.</param>
    /// <param name="computeModelEvidence">Whether to compute and print the log model evidence.</param>
    /// <param name="batchCount">The training batch count.</param>
    /// <remarks>Adapted from MicrosoftResearch.Infer.Learners</remarks>
    public void DiagnoseClassifier(
        Vector[] x,
        IList <string> y,
        GenericClassifierMapping mapping,
        string outputModelFileName,
        string reportFileName,
        int iterationCount,
        bool computeModelEvidence,
        int batchCount)
    {
        Debug.Assert(x != null, "The feature vector must not be null.");
        Debug.Assert(y != null, "The target variable must not be null.");
        Debug.Assert(x.Length == y.Count, "The feature vectors and the target variable must have the same length.");
        Debug.Assert(mapping != null, "The classifier map must not be null.");
        Debug.Assert(!string.IsNullOrEmpty(reportFileName), "The report file name must not be null/empty.");
        Debug.Assert(iterationCount > 0, "The iteration count must be greater than zero.");
        Debug.Assert(batchCount > 0, "The batch count must be greater than zero.");

        // create a BPM from the mapping
        var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(mapping);

        classifier.Settings.Training.ComputeModelEvidence = computeModelEvidence;
        classifier.Settings.Training.IterationCount       = iterationCount;
        classifier.Settings.Training.BatchCount           = batchCount;

        // Reference weight distributions for the parameter-change trace. They start as N(0, 1)
        // and are overwritten by the posteriors after every iteration. They are sized lazily from
        // the first reported posterior, because the weight shape is defined by the classifier and
        // its mapping. x.Length is the number of instances, not the number of features.
        Gaussian[][] priorWeightDistributions = null;

        // Measures the wall-clock time of each iteration (stopped while the handler runs)
        var watch = new Stopwatch();

        classifier.IterationChanged += (sender, eventArgs) =>
        {
            watch.Stop();

            var posteriors = eventArgs.WeightPosteriorDistributions;
            if (priorWeightDistributions == null)
            {
                priorWeightDistributions = Util.ArrayInit(
                    posteriors.Count,
                    c => Util.ArrayInit(posteriors[c].Count, f => new Gaussian(0.0, 1.0)));
            }

            Dictionary <int, double[]> maxMean;
            Dictionary <int, double[]> maxVar;
            double maxParameterChange = MaxDiff(posteriors, priorWeightDistributions, out maxMean, out maxVar);

            if (!string.IsNullOrEmpty(reportFileName))
            {
                SaveMaximumParameterDifference(
                    reportFileName,
                    eventArgs.CompletedIterationCount,
                    maxParameterChange,
                    watch.ElapsedMilliseconds,
                    maxMean,
                    maxVar);
            }

            Console.WriteLine(
                "[{0}] Iteration {1,-4}   dp = {2,-20}   dt = {3,5}ms",
                DateTime.Now.ToLongTimeString(),
                eventArgs.CompletedIterationCount,
                maxParameterChange,
                watch.ElapsedMilliseconds);

            // Copy weight marginals so the next iteration is compared against this one
            for (int c = 0; c < posteriors.Count; c++)
            {
                for (int f = 0; f < posteriors[c].Count; f++)
                {
                    priorWeightDistributions[c][f] = posteriors[c][f];
                }
            }

            watch.Restart();
        };

        // Write file header (truncates any existing report)
        if (!string.IsNullOrEmpty(reportFileName))
        {
            using (var writer = new StreamWriter(reportFileName))
            {
                writer.WriteLine("# time, # iteration, " +
                                 "# maximum absolute parameter difference, " +
                                 "# iteration time in milliseconds, " +
                                 "# Max Mean, # Max Var.");
            }
        }

        // Train the Bayes point machine classifier
        Console.WriteLine("[{0}] Starting training...", DateTime.Now.ToLongTimeString());
        watch.Start();

        classifier.Train(x, y);
        watch.Stop();

        // Report log evidence
        if (classifier.Settings.Training.ComputeModelEvidence)
        {
            Console.WriteLine("Log evidence = {0,10:0.0000}", classifier.LogModelEvidence);
        }

        // Save trained model
        if (!string.IsNullOrEmpty(outputModelFileName))
        {
            classifier.Save(outputModelFileName);
        }
    }