예제 #1
0
    /// <summary>
    /// Configures the training settings of <c>_classifier</c>, trains it on the
    /// instance data held in <c>_x</c>/<c>_y</c> and saves the resulting model.
    /// </summary>
    /// <param name="outputModelFileName">File the trained model is saved to.</param>
    /// <param name="iterationCount">Value assigned to the training iteration count setting.</param>
    /// <param name="computeModelEvidence">Value assigned to the compute-model-evidence setting.</param>
    /// <param name="batchCount">Value assigned to the training batch count setting.</param>
    private void Train(
        string outputModelFileName,
        int iterationCount,
        bool computeModelEvidence,
        int batchCount)
    {
        // Apply the requested options; assignment order (evidence, iterations, batches) is kept
        // in case a setter validates against the current state.
        var trainingSettings = _classifier.Settings.Training;
        trainingSettings.ComputeModelEvidence = computeModelEvidence;
        trainingSettings.IterationCount = iterationCount;
        trainingSettings.BatchCount = batchCount;

        // Fit on the stored data, then persist the model.
        _classifier.Train(_x, _y);
        _classifier.Save(outputModelFileName);
    }
예제 #2
0
    /// <summary>
    /// Loads a previously saved binary Bayes point machine classifier, continues training it
    /// incrementally on the data held in <c>_x</c>/<c>_y</c> and saves the updated model.
    /// </summary>
    /// <remarks>
    /// <paramref name="distributionName"/>, <paramref name="inferenceEngineAlgorithm"/> and
    /// <paramref name="noise"/> are not used by this implementation; a warning is logged saying so.
    /// </remarks>
    /// <param name="inputModelFileName">File the existing model is loaded from.</param>
    /// <param name="outputModelFileName">File the updated model is saved to.</param>
    /// <param name="iterationCount">Value assigned to the training iteration count setting.</param>
    /// <param name="computeModelEvidence">Value assigned to the compute-model-evidence setting.</param>
    /// <param name="batchCount">Value assigned to the training batch count setting.</param>
    /// <param name="distributionName">Ignored.</param>
    /// <param name="inferenceEngineAlgorithm">Ignored.</param>
    /// <param name="noise">Ignored.</param>
    public override void TrainIncremental(
        string inputModelFileName,
        string outputModelFileName,
        int iterationCount,
        bool computeModelEvidence,
        int batchCount,
        DistributionName distributionName,
        InferenceAlgorithm inferenceEngineAlgorithm,
        double noise)
    {
        // The advanced options are never read below, so tell the user up front.
        const string ignoredSettingsWarning =
            "Advanced setting will not be used: " +
            "distributionName, inferenceEngineAlgorithm & noise.";
        TraceListeners.Log(TraceEventType.Warning, 0, ignoredSettingsWarning, false, true);

        _validate.TrainIncremental(
            inputModelFileName: inputModelFileName,
            outputModelFileName: outputModelFileName,
            iterationCount: iterationCount,
            batchCount: batchCount);

        // Restore the previously trained binary classifier from disk.
        IBayesPointMachineClassifier<
            IList<Vector>, int, IList<string>, string, IDictionary<string, double>,
            BayesPointMachineClassifierTrainingSettings,
            BinaryBayesPointMachineClassifierPredictionSettings<string>> loadedClassifier =
                BayesPointMachineClassifier.LoadBinaryClassifier<
                    IList<Vector>, int, IList<string>, string, IDictionary<string, double>>(inputModelFileName);

        // Override the stored training options with the ones requested for this run
        // (same assignment order as the full-training path).
        var trainingSettings = loadedClassifier.Settings.Training;
        trainingSettings.ComputeModelEvidence = computeModelEvidence;
        trainingSettings.IterationCount = iterationCount;
        trainingSettings.BatchCount = batchCount;

        // Continue training on the current data, then persist the updated model.
        loadedClassifier.TrainIncremental(_x, _y);
        loadedClassifier.Save(outputModelFileName);
    }
예제 #3
0
        /// <summary>
        /// Diagnoses the Bayes point machine classifier on the specified data set.
        /// </summary>
        /// <remarks>
        /// Trains <paramref name="classifier"/> on <paramref name="trainingSet"/>. After every completed
        /// iteration it reports the maximum parameter difference (computed by <c>MaxDiff</c>) and the
        /// iteration time. The difference is measured between the current weight posteriors and those of
        /// the previous iteration; for the first iteration the reference is Gaussian(0, 1) weight priors.
        /// The iteration-changed handler installed here is removed again before the method returns, also
        /// when training throws. Later training of the same classifier is therefore not affected.
        /// </remarks>
        /// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
        /// <param name="classifier">The Bayes point machine classifier.</param>
        /// <param name="trainingSet">
        /// The dataset. Must be non-empty; its first instance determines the class and feature counts.
        /// </param>
        /// <param name="maxParameterChangesFileName">
        /// The name of the file to store the maximum parameter differences. If null or empty, no file is written.
        /// </param>
        /// <param name="modelFileName">
        /// The name of the file to store the trained Bayes point machine model. If null or empty, the model is not saved.
        /// </param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="classifier"/> or <paramref name="trainingSet"/> is null.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="trainingSet"/> is empty.</exception>
        public static void DiagnoseClassifier <TTrainingSettings>(
            IBayesPointMachineClassifier <IList <LabeledFeatureValues>, LabeledFeatureValues, IList <LabelDistribution>, string, IDictionary <string, double>, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings <string> > classifier,
            IList <LabeledFeatureValues> trainingSet,
            string maxParameterChangesFileName,
            string modelFileName)
            where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
        {
            if (classifier == null)
            {
                throw new ArgumentNullException(nameof(classifier));
            }

            if (trainingSet == null)
            {
                throw new ArgumentNullException(nameof(trainingSet));
            }

            if (trainingSet.Count == 0)
            {
                throw new ArgumentOutOfRangeException(
                    nameof(trainingSet), "The training set must contain at least one instance.");
            }

            // Create prior distributions over weights. After each iteration they are overwritten with the
            // latest posteriors, so every iteration is compared with its predecessor.
            int classCount               = trainingSet[0].LabelDistribution.LabelSet.Count;
            int featureCount             = trainingSet[0].GetDenseFeatureVector().Count;
            var priorWeightDistributions = Util.ArrayInit(classCount, c => Util.ArrayInit(featureCount, f => new Gaussian(0.0, 1.0)));

            var watch = new Stopwatch();

            // Kept in a local so that it can be unsubscribed once training is over.
            EventHandler<BayesPointMachineClassifierIterationChangedEventArgs> iterationChangedHandler = (sender, eventArgs) =>
            {
                // Stop first so the reporting below is not counted towards the iteration time.
                watch.Stop();
                double maxParameterChange = MaxDiff(eventArgs.WeightPosteriorDistributions, priorWeightDistributions);

                if (!string.IsNullOrEmpty(maxParameterChangesFileName))
                {
                    SaveMaximumParameterDifference(
                        maxParameterChangesFileName,
                        eventArgs.CompletedIterationCount,
                        maxParameterChange,
                        watch.ElapsedMilliseconds);
                }

                Console.WriteLine(
                    "[{0}] Iteration {1,-4}   dp = {2,-20}   dt = {3,5}ms",
                    DateTime.Now.ToLongTimeString(),
                    eventArgs.CompletedIterationCount,
                    maxParameterChange,
                    watch.ElapsedMilliseconds);

                // Copy weight marginals so the next iteration is compared against this one.
                for (int c = 0; c < eventArgs.WeightPosteriorDistributions.Count; c++)
                {
                    for (int f = 0; f < eventArgs.WeightPosteriorDistributions[c].Count; f++)
                    {
                        priorWeightDistributions[c][f] = eventArgs.WeightPosteriorDistributions[c][f];
                    }
                }

                watch.Restart();
            };

            // Write file header (StreamWriter(string) overwrites any existing file).
            if (!string.IsNullOrEmpty(maxParameterChangesFileName))
            {
                using (var writer = new StreamWriter(maxParameterChangesFileName))
                {
                    writer.WriteLine("# time, # iteration, # maximum absolute parameter difference, # iteration time in milliseconds");
                }
            }

            // Train the Bayes point machine classifier while the diagnostics handler is attached.
            classifier.IterationChanged += iterationChangedHandler;
            try
            {
                Console.WriteLine("[{0}] Starting training...", DateTime.Now.ToLongTimeString());
                watch.Start();

                classifier.Train(trainingSet);
            }
            finally
            {
                // Detach the handler. Otherwise later training of this classifier would keep writing
                // diagnostics to the file and console from this (stale) state, and keep it alive.
                classifier.IterationChanged -= iterationChangedHandler;
            }

            // Compute evidence
            if (classifier.Settings.Training.ComputeModelEvidence)
            {
                Console.WriteLine("Log evidence = {0,10:0.0000}", classifier.LogModelEvidence);
            }

            // Save trained model
            if (!string.IsNullOrEmpty(modelFileName))
            {
                classifier.Save(modelFileName);
            }
        }