Example #1
0
        /// <summary>
        /// Runs the generated prediction algorithm for the specified features.
        /// </summary>
        /// <param name="featureValues">The feature values, one array per instance.</param>
        /// <param name="featureIndexes">The feature indexes, one array per instance (only observed when sparse features are used).</param>
        /// <param name="iterationCount">The number of iterations to run the prediction algorithm for.</param>
        /// <returns>The predictive distributions over labels, one per instance, produced lazily.</returns>
        /// <remarks>
        /// The arguments are validated immediately when this method is called. Prediction itself is deferred:
        /// the weight priors are observed when enumeration starts, and each instance's label distribution is
        /// inferred only when the returned sequence reaches that instance.
        /// NOTE(review): enumeration mutates the shared <c>PredictionAlgorithm</c> (observed values are set per
        /// instance), so do not assume the returned sequence can safely be enumerated concurrently.
        /// </remarks>
        public IEnumerable<TLabelDistribution> PredictDistribution(double[][] featureValues, int[][] featureIndexes, int iterationCount)
        {
            // Validate here, outside the iterator: the body of an iterator method does not run until the first
            // MoveNext(), so checks placed there would be silently skipped if the result is never enumerated.
            InferenceAlgorithmUtilities.CheckIterationCount(iterationCount);
            InferenceAlgorithmUtilities.CheckFeatures(this.UseSparseFeatures, this.FeatureCount, featureValues, featureIndexes);

            return this.PredictDistributionIterator(featureValues, featureIndexes, iterationCount);
        }

        /// <summary>
        /// Lazily infers the label distribution of each instance. The arguments must already have been validated
        /// by <see cref="PredictDistribution"/>.
        /// </summary>
        /// <param name="featureValues">The feature values, one array per instance.</param>
        /// <param name="featureIndexes">The feature indexes, one array per instance (only observed when sparse features are used).</param>
        /// <param name="iterationCount">The number of iterations to run the prediction algorithm for.</param>
        /// <returns>The predictive distributions over labels, one per instance.</returns>
        private IEnumerable<TLabelDistribution> PredictDistributionIterator(double[][] featureValues, int[][] featureIndexes, int iterationCount)
        {
            // Update prior weight distributions of the prediction algorithm to the posterior weight distributions of the training algorithm
            this.PredictionAlgorithm.SetObservedValue(InferenceQueryVariableNames.WeightPriors, this.WeightMarginals);

            // Infer posterior distribution over labels, one instance after the other
            for (int i = 0; i < featureValues.Length; i++)
            {
                // Observe a single feature vector. The instance count is re-observed for every instance rather than
                // hoisted, so each step sets all of the per-instance observations it relies on.
                this.PredictionAlgorithm.SetObservedValue(InferenceQueryVariableNames.InstanceCount, 1);

                if (this.UseSparseFeatures)
                {
                    this.PredictionAlgorithm.SetObservedValue(InferenceQueryVariableNames.InstanceFeatureCounts, new[] { featureValues[i].Length });
                    this.PredictionAlgorithm.SetObservedValue(InferenceQueryVariableNames.FeatureIndexes, new[] { featureIndexes[i] });
                }

                this.PredictionAlgorithm.SetObservedValue(InferenceQueryVariableNames.FeatureValues, new[] { featureValues[i] });

                // Infer the posterior distribution over the label and hand out a copy
                // (presumably because the algorithm's result object is reused across runs — confirm).
                yield return this.CopyLabelDistribution(this.RunPredictionAlgorithm(iterationCount));
            }
        }
Example #2
0
        /// <summary>
        /// Observes the given features and labels on the generated training algorithm and then runs it,
        /// either over the whole data set (single batch) or over one batch of it.
        /// </summary>
        /// <param name="featureValues">The feature values, one array per instance.</param>
        /// <param name="featureIndexes">The feature indexes, one array per instance (only observed when sparse features are used).</param>
        /// <param name="labels">The labels, one per instance.</param>
        /// <param name="iterationCount">The number of iterations to run the training algorithm for.</param>
        /// <param name="batchNumber">
        /// An optional zero-based batch number. Defaults to 0 and is used only if the training data is divided into batches.
        /// </param>
        protected virtual void TrainInternal(double[][] featureValues, int[][] featureIndexes, TLabel[] labels, int iterationCount, int batchNumber = 0)
        {
            InferenceAlgorithmUtilities.CheckIterationCount(iterationCount);
            InferenceAlgorithmUtilities.CheckBatchNumber(batchNumber, this.BatchCount);
            InferenceAlgorithmUtilities.CheckFeatures(this.UseSparseFeatures, this.FeatureCount, featureValues, featureIndexes);

            int instanceCount = featureValues.Length;

            // NOTE(review): Debug.Assert is compiled out of release builds, so this consistency check only runs in debug builds.
            Debug.Assert(instanceCount == labels.Length, "There must be the same number of feature values and labels.");

            // Observe the number of instances, the sparse layout (if any), the feature values and the labels
            this.TrainingAlgorithm.SetObservedValue(InferenceQueryVariableNames.InstanceCount, instanceCount);

            if (this.UseSparseFeatures)
            {
                var featureCountPerInstance = Util.ArrayInit(instanceCount, n => featureValues[n].Length);
                this.TrainingAlgorithm.SetObservedValue(InferenceQueryVariableNames.InstanceFeatureCounts, featureCountPerInstance);
                this.TrainingAlgorithm.SetObservedValue(InferenceQueryVariableNames.FeatureIndexes, featureIndexes);
            }

            this.TrainingAlgorithm.SetObservedValue(InferenceQueryVariableNames.FeatureValues, featureValues);
            this.Labels = labels;

            // The batch decision is taken once, before training, exactly as a single if/else would take it
            bool isBatched = this.BatchCount != 1;

            if (isBatched)
            {
                if (this.ShowProgress)
                {
                    string pluralSuffix = instanceCount == 1 ? string.Empty : "s";
                    Console.WriteLine("Batch {0} [{1} instance{2}]", batchNumber + 1, instanceCount, pluralSuffix);
                }

                // Derive this batch's weight constraints from the current marginals-over-priors and the batch's stored output message
                this.WeightConstraints.SetToRatio(this.WeightMarginalsDividedByPriors, this.BatchWeightOutputMessages[batchNumber]);
            }
            else
            {
                // Required for incremental training
                this.WeightConstraints = this.WeightMarginalsDividedByPriors;
            }

            this.RunTrainingAlgorithm(iterationCount);

            if (isBatched)
            {
                // Refresh the stored output message of this batch from the new marginals-over-priors and the constraints used
                this.BatchWeightOutputMessages[batchNumber].SetToRatio(this.WeightMarginalsDividedByPriors, this.WeightConstraints);
            }
        }