/// <summary>
/// Runs the generated prediction algorithm for the specified features.
/// </summary>
/// <param name="featureValues">The feature values.</param>
/// <param name="featureIndexes">The feature indexes.</param>
/// <param name="iterationCount">The number of iterations to run the prediction algorithm for.</param>
/// <returns>The predictive distributions over labels.</returns>
/// <remarks>
/// The argument checks run immediately when this method is called. The returned sequence
/// itself is lazy: the prediction algorithm is run for one instance at a time, as the
/// sequence is enumerated.
/// </remarks>
public IEnumerable<TLabelDistribution> PredictDistribution(double[][] featureValues, int[][] featureIndexes, int iterationCount)
{
    // These checks must live outside the iterator method. Inside a method that uses
    // 'yield return' they would not execute until the first element is requested,
    // so invalid arguments would be reported far away from the faulty call.
    InferenceAlgorithmUtilities.CheckIterationCount(iterationCount);
    InferenceAlgorithmUtilities.CheckFeatures(this.UseSparseFeatures, this.FeatureCount, featureValues, featureIndexes);

    return this.PredictDistributionIterator(featureValues, featureIndexes, iterationCount);
}

/// <summary>
/// Lazily infers the predictive distribution over labels for each instance.
/// Expects its arguments to have been checked by <see cref="PredictDistribution"/>.
/// </summary>
/// <param name="featureValues">The feature values.</param>
/// <param name="featureIndexes">The feature indexes.</param>
/// <param name="iterationCount">The number of iterations to run the prediction algorithm for.</param>
/// <returns>The predictive distributions over labels, one per instance.</returns>
private IEnumerable<TLabelDistribution> PredictDistributionIterator(double[][] featureValues, int[][] featureIndexes, int iterationCount)
{
    // Update prior weight distributions of the prediction algorithm to the posterior weight distributions
    // of the training algorithm. This happens when enumeration starts, not when the sequence is created.
    this.PredictionAlgorithm.SetObservedValue(InferenceQueryVariableNames.WeightPriors, this.WeightMarginals);

    // Infer posterior distribution over labels, one instance after the other
    for (int i = 0; i < featureValues.Length; i++)
    {
        // Observe a single feature vector
        this.PredictionAlgorithm.SetObservedValue(InferenceQueryVariableNames.InstanceCount, 1);
        if (this.UseSparseFeatures)
        {
            // Sparse representation: the number of features present and their indexes are observed as well
            this.PredictionAlgorithm.SetObservedValue(InferenceQueryVariableNames.InstanceFeatureCounts, new[] { featureValues[i].Length });
            this.PredictionAlgorithm.SetObservedValue(InferenceQueryVariableNames.FeatureIndexes, new[] { featureIndexes[i] });
        }

        this.PredictionAlgorithm.SetObservedValue(InferenceQueryVariableNames.FeatureValues, new[] { featureValues[i] });

        // Infer the posterior distribution over the label and hand out a copy of it
        // (presumably because the algorithm's result object is reused between runs - TODO confirm)
        yield return this.CopyLabelDistribution(this.RunPredictionAlgorithm(iterationCount));
    }
}
/// <summary>
/// Observes the given features and labels and runs the generated training algorithm on them.
/// </summary>
/// <param name="featureValues">The feature values.</param>
/// <param name="featureIndexes">The feature indexes.</param>
/// <param name="labels">The labels.</param>
/// <param name="iterationCount">The number of iterations to run the training algorithm for.</param>
/// <param name="batchNumber">
/// An optional batch number. Defaults to 0 and is used only if the training data is divided into batches.
/// </param>
protected virtual void TrainInternal(double[][] featureValues, int[][] featureIndexes, TLabel[] labels, int iterationCount, int batchNumber = 0)
{
    InferenceAlgorithmUtilities.CheckIterationCount(iterationCount);
    InferenceAlgorithmUtilities.CheckBatchNumber(batchNumber, this.BatchCount);
    InferenceAlgorithmUtilities.CheckFeatures(this.UseSparseFeatures, this.FeatureCount, featureValues, featureIndexes);
    Debug.Assert(featureValues.Length == labels.Length, "There must be the same number of feature values and labels.");

    int instanceCount = featureValues.Length;

    // Hand the training data to the algorithm: instance count first, then features, then labels
    this.TrainingAlgorithm.SetObservedValue(InferenceQueryVariableNames.InstanceCount, instanceCount);
    if (this.UseSparseFeatures)
    {
        // Sparse representation: per-instance feature counts and feature indexes are observed too
        this.TrainingAlgorithm.SetObservedValue(
            InferenceQueryVariableNames.InstanceFeatureCounts,
            Util.ArrayInit(instanceCount, index => featureValues[index].Length));
        this.TrainingAlgorithm.SetObservedValue(InferenceQueryVariableNames.FeatureIndexes, featureIndexes);
    }

    this.TrainingAlgorithm.SetObservedValue(InferenceQueryVariableNames.FeatureValues, featureValues);
    this.Labels = labels;

    // Single batch: no per-batch bookkeeping is needed
    if (this.BatchCount == 1)
    {
        // Required for incremental training
        this.WeightConstraints = this.WeightMarginalsDividedByPriors;

        this.RunTrainingAlgorithm(iterationCount);
        return;
    }

    // Multiple batches: optionally report which batch is being processed
    if (this.ShowProgress)
    {
        string pluralSuffix = instanceCount == 1 ? string.Empty : "s";
        Console.WriteLine("Batch {0} [{1} instance{2}]", batchNumber + 1, instanceCount, pluralSuffix);
    }

    // Constraints for this batch = (marginals / priors) / output messages previously stored for this batch
    this.WeightConstraints.SetToRatio(this.WeightMarginalsDividedByPriors, this.BatchWeightOutputMessages[batchNumber]);

    this.RunTrainingAlgorithm(iterationCount);

    // Store the new output messages for this batch = (updated marginals / priors) / constraints
    this.BatchWeightOutputMessages[batchNumber].SetToRatio(this.WeightMarginalsDividedByPriors, this.WeightConstraints);
}