/// <summary>
/// Verifies clique-tree marginals against brute force: once on the freshly constructed model, and
/// again after each of ten random mutations applied to that same model.
/// </summary>
/// <param name="model">Model to run inference over; the random mutations are applied to this instance.</param>
/// <param name="weights">Weights passed both to the clique tree and to the brute-force check.</param>
public virtual void TestCalculateMarginals(GraphicalModel model, ConcatVector weights)
{
    const int mutationRounds = 10;
    var inference = new CliqueTree(model, weights);

    // Baseline: inference must agree with brute force on the model as it was handed to us.
    CheckMarginalsAgainstBruteForce(model, weights, inference);

    // The same CliqueTree instance is reused after every mutation, so this also checks that it
    // stays consistent when the model changes underneath it.
    var random = new Random();
    for (int round = 0; round < mutationRounds; round++)
    {
        RandomlyMutateGraphicalModel(model, random);
        CheckMarginalsAgainstBruteForce(model, weights, inference);
    }
}
/// <summary>
/// Checks the MAP assignment returned by <paramref name="inference"/> against an exhaustive search
/// over the product of every factor in <paramref name="model"/>.
/// </summary>
/// <remarks>
/// Each factor is instantiated as a <c>TableFactor</c> under <paramref name="weights"/>, and all of them
/// are multiplied together. This is slow, but obviously correct.
/// Variables whose metadata contains <c>CliqueTree.VariableObservedValue</c> are observed out of that
/// product and forced to their observed value in the expected assignment. If an observed variable would
/// be the last one left in the product, the method returns without asserting anything.
/// Only variables still present in the product are compared. On a mismatch, diagnostics are written to
/// standard error before the assertion fails.
/// </remarks>
/// <param name="model">Model whose factors define the brute-force joint.</param>
/// <param name="weights">Weights used to turn each factor into a table.</param>
/// <param name="inference">Clique tree whose <c>CalculateMAP</c> result is being verified.</param>
public virtual void CheckMAPAgainstBruteForce(GraphicalModel model, ConcatVector weights, CliqueTree inference)
{
    int[] map = inference.CalculateMAP();

    // Materialize every factor as a table under the given weights; the brute-force product is built
    // from these tables, so the weights must be applied here.
    ICollection<TableFactor> tableFactors = new System.Collections.Generic.HashSet<TableFactor>();
    foreach (GraphicalModel.Factor factor in model.factors)
    {
        tableFactors.Add(new TableFactor(weights, factor));
    }

    // This is the super slow but obviously correct way to get the global joint.
    TableFactor bruteForce = null;
    foreach (TableFactor table in tableFactors)
    {
        bruteForce = bruteForce == null ? table : bruteForce.Multiply(table);
    }
    System.Diagnostics.Debug.Assert(bruteForce != null);

    // Observe out all variables that have a registered observed value.
    TableFactor observed = bruteForce;
    foreach (int n in bruteForce.neighborIndices)
    {
        if (model.GetVariableMetaDataByReference(n).Contains(CliqueTree.VariableObservedValue))
        {
            int value = System.Convert.ToInt32(model.GetVariableMetaDataByReference(n)[CliqueTree.VariableObservedValue]);
            if (observed.neighborIndices.Length > 1)
            {
                observed = observed.Observe(n, value);
            }
            else
            {
                // If we've observed everything, then just quit.
                return;
            }
        }
    }
    bruteForce = observed;

    // Largest variable index referenced by any factor, so arrays below can be indexed by variable.
    int largestVariableNum = 0;
    foreach (GraphicalModel.Factor f in model.factors)
    {
        foreach (int neighbor in f.neigborIndices)
        {
            if (neighbor > largestVariableNum)
            {
                largestVariableNum = neighbor;
            }
        }
    }

    // Indexed by variable number: entry 0 corresponds to variable 0.
    int[] mapValueAssignment = new int[largestVariableNum + 1];
    // Kept in the order in which the brute-force factor lists its neighbors.
    int[] highestValueAssignment = new int[bruteForce.neighborIndices.Length];
    foreach (int[] assignment in bruteForce)
    {
        if (bruteForce.GetAssignmentValue(assignment) > bruteForce.GetAssignmentValue(highestValueAssignment))
        {
            highestValueAssignment = assignment;
            for (int k = 0; k < assignment.Length; k++)
            {
                mapValueAssignment[bruteForce.neighborIndices[k]] = assignment[k];
            }
        }
    }

    // Observed variables are forced to their observed values in the expected assignment.
    int[] forcedAssignments = new int[largestVariableNum + 1];
    for (int v = 0; v < mapValueAssignment.Length; v++)
    {
        if (model.GetVariableMetaDataByReference(v).Contains(CliqueTree.VariableObservedValue))
        {
            mapValueAssignment[v] = System.Convert.ToInt32(model.GetVariableMetaDataByReference(v)[CliqueTree.VariableObservedValue]);
            forcedAssignments[v] = mapValueAssignment[v];
        }
    }

    if (!Arrays.Equals(mapValueAssignment, map))
    {
        System.Console.Error.WriteLine("---");
        System.Console.Error.WriteLine("Relevant variables: " + Arrays.ToString(bruteForce.neighborIndices));
        System.Console.Error.WriteLine("Var Sizes: " + Arrays.ToString(bruteForce.GetDimensions()));
        System.Console.Error.WriteLine("MAP: " + Arrays.ToString(map));
        System.Console.Error.WriteLine("Brute force map: " + Arrays.ToString(mapValueAssignment));
        System.Console.Error.WriteLine("Forced assignments: " + Arrays.ToString(forcedAssignments));
    }

    // Only check variables that remain in the brute-force table.
    foreach (int checkedVariable in bruteForce.neighborIndices)
    {
        NUnit.Framework.Assert.AreEqual(mapValueAssignment[checkedVariable], map[checkedVariable]);
    }
}
/// <summary>
/// Checks the single-variable marginals, the partition function and the per-factor joint marginals
/// computed by <paramref name="inference"/> against values obtained by multiplying every factor of
/// <paramref name="model"/> together. This is slow, but obviously correct.
/// </summary>
/// <remarks>
/// Variables whose metadata contains <c>CliqueTree.VariableObservedValue</c> must have a marginal
/// that puts all mass on the observed value. They are then observed out of the brute-force product.
/// If an observed variable would be the last one left in that product, the remaining checks are skipped.
/// A model with no factors must yield (approximately) uniform marginals.
/// </remarks>
/// <param name="model">Model whose factors define the brute-force joint.</param>
/// <param name="weights">Weights used to turn each factor into a table.</param>
/// <param name="inference">Clique tree whose <c>CalculateMarginals</c> result is being verified.</param>
private void CheckMarginalsAgainstBruteForce(GraphicalModel model, ConcatVector weights, CliqueTree inference)
{
    CliqueTree.MarginalResult result = inference.CalculateMarginals();
    double[][] marginals = result.marginals;

    // Materialize every factor as a table under the given weights.
    ICollection<TableFactor> tableFactors = new System.Collections.Generic.HashSet<TableFactor>();
    foreach (GraphicalModel.Factor factor in model.factors)
    {
        tableFactors.Add(new TableFactor(weights, factor));
    }
    System.Diagnostics.Debug.Assert(tableFactors.Count == model.factors.Count);

    // This is the super slow but obviously correct way to get global marginals.
    TableFactor bruteForce = null;
    foreach (TableFactor table in tableFactors)
    {
        bruteForce = bruteForce == null ? table : bruteForce.Multiply(table);
    }

    if (bruteForce == null)
    {
        // No factors at all: every marginal should be (approximately) uniform.
        foreach (double[] marginal in marginals)
        {
            foreach (double d in marginal)
            {
                NUnit.Framework.Assert.AreEqual(1.0 / marginal.Length, d, 3.0e-2);
            }
        }
        return;
    }

    // Observe out all variables that have been registered as observed.
    TableFactor observed = bruteForce;
    foreach (int n in bruteForce.neighborIndices)
    {
        if (!model.GetVariableMetaDataByReference(n).Contains(CliqueTree.VariableObservedValue))
        {
            continue;
        }
        int value = System.Convert.ToInt32(model.GetVariableMetaDataByReference(n)[CliqueTree.VariableObservedValue]);

        // Check that the marginals reflect the observation.
        for (int j = 0; j < marginals[n].Length; j++)
        {
            NUnit.Framework.Assert.AreEqual(j == value ? 1.0 : 0.0, marginals[n][j], 1.0e-9);
        }

        if (observed.neighborIndices.Length > 1)
        {
            observed = observed.Observe(n, value);
        }
        else
        {
            // If we've observed everything, then just quit.
            return;
        }
    }
    bruteForce = observed;

    // Spot check each of the marginals in the brute force calculation.
    double[][] bruteMarginals = bruteForce.GetSummedMarginals();
    for (int index = 0; index < bruteForce.neighborIndices.Length; index++)
    {
        int variable = bruteForce.neighborIndices[index];
        double[] brute = bruteMarginals[index];
        System.Diagnostics.Debug.Assert(brute != null);
        System.Diagnostics.Debug.Assert(marginals[variable] != null);

        bool isEqual = true;
        for (int j = 0; j < brute.Length; j++)
        {
            if (double.IsNaN(brute[j]) || Math.Abs(brute[j] - marginals[variable][j]) > 3.0e-2)
            {
                isEqual = false;
                break;
            }
        }
        if (!isEqual)
        {
            System.Console.Error.WriteLine("Arrays not equal! Variable " + variable);
            System.Console.Error.WriteLine("\tGold: " + Arrays.ToString(brute));
            System.Console.Error.WriteLine("\tResult: " + Arrays.ToString(marginals[variable]));
        }
        // NOTE(review): project-provided helper; its (actual, delta, expected) argument order is kept
        // as translated — confirm against its declaration.
        Assert.AssertArrayEquals(marginals[variable], 3.0e-2, brute);
    }

    // Spot check the partition function: correct to within 3%.
    double goldPartitionFunction = bruteForce.ValueSum();
    NUnit.Framework.Assert.AreEqual(goldPartitionFunction, result.partitionFunction, goldPartitionFunction * 3.0e-2);

    // Check the joint marginals of every factor.
    foreach (GraphicalModel.Factor f in model.factors)
    {
        CheckJointMarginalAgainstBruteForce(model, result, f, bruteForce);
    }
}

/// <summary>
/// Checks the joint marginal reported for factor <paramref name="f"/> against the brute-force joint
/// <paramref name="bruteForce"/>. Observed variables have already been removed from that joint.
/// </summary>
/// <param name="model">Model supplying observation metadata for the variables of <paramref name="f"/>.</param>
/// <param name="result">Marginal result whose <c>jointMarginals</c> entry for <paramref name="f"/> is checked.</param>
/// <param name="f">Factor whose joint marginal is checked.</param>
/// <param name="bruteForce">Product of all factors after observed variables were observed out.</param>
private void CheckJointMarginalAgainstBruteForce(GraphicalModel model, CliqueTree.MarginalResult result, GraphicalModel.Factor f, TableFactor bruteForce)
{
    NUnit.Framework.Assert.IsTrue(result.jointMarginals.Contains(f));
    TableFactor jointMarginal = result.jointMarginals[f];

    // Sum out every brute-force variable that f does not touch.
    TableFactor bruteForceJointMarginal = bruteForce;
    foreach (int n in bruteForce.neighborIndices)
    {
        if (System.Array.IndexOf(f.neigborIndices, n) >= 0)
        {
            continue;
        }
        if (bruteForceJointMarginal.neighborIndices.Length > 1)
        {
            bruteForceJointMarginal = bruteForceJointMarginal.SumOut(n);
        }
        else
        {
            // The only variable left in the table is one that f does not touch, so none of f's
            // variables remain in the table. Treat them all as observed.
            CheckJointMarginalIsObservedPointMass(model, jointMarginal, f);
            return;
        }
    }

    // Find the correspondence between the brute force joint marginal and the output joint marginal.
    // The brute-force joint may be missing variables, because they were observed out of the table.
    // The output joint marginal always matches the original factor exactly.
    int[] backPointers = new int[f.neigborIndices.Length];
    int[] observedValue = new int[f.neigborIndices.Length];
    for (int i = 0; i < backPointers.Length; i++)
    {
        backPointers[i] = -1;
        if (model.GetVariableMetaDataByReference(f.neigborIndices[i]).Contains(CliqueTree.VariableObservedValue))
        {
            observedValue[i] = System.Convert.ToInt32(model.GetVariableMetaDataByReference(f.neigborIndices[i])[CliqueTree.VariableObservedValue]);
        }
        else
        {
            observedValue[i] = -1;
            for (int j = 0; j < bruteForceJointMarginal.neighborIndices.Length; j++)
            {
                if (bruteForceJointMarginal.neighborIndices[j] == f.neigborIndices[i])
                {
                    backPointers[i] = j;
                }
            }
            System.Diagnostics.Debug.Assert(backPointers[i] != -1);
        }
    }

    // Normalizer for the brute-force joint; guard against dividing by zero.
    double sum = bruteForceJointMarginal.ValueSum();
    if (sum == 0.0)
    {
        sum = 1;
    }

    foreach (int[] assignment in jointMarginal)
    {
        int[] bruteForceMarginalAssignment = new int[bruteForceJointMarginal.neighborIndices.Length];
        bool contradictsObservation = false;
        for (int i = 0; i < assignment.Length; i++)
        {
            if (backPointers[i] != -1)
            {
                bruteForceMarginalAssignment[backPointers[i]] = assignment[i];
                continue;
            }

            // Make sure all assignments that don't square with observations get 0 weight.
            System.Diagnostics.Debug.Assert(observedValue[i] != -1);
            if (assignment[i] != observedValue[i])
            {
                if (jointMarginal.GetAssignmentValue(assignment) != 0)
                {
                    System.Console.Error.WriteLine("Joint marginals: " + Arrays.ToString(jointMarginal.neighborIndices));
                    System.Console.Error.WriteLine("Assignment: " + Arrays.ToString(assignment));
                    System.Console.Error.WriteLine("Observed Value: " + Arrays.ToString(observedValue));
                    foreach (int[] assn in jointMarginal)
                    {
                        System.Console.Error.WriteLine("\t" + Arrays.ToString(assn) + ":" + jointMarginal.GetAssignmentValue(assn));
                    }
                }
                NUnit.Framework.Assert.AreEqual(0.0, jointMarginal.GetAssignmentValue(assignment), 1.0e-7);
                contradictsObservation = true;
                break;
            }
        }
        if (contradictsObservation)
        {
            continue;
        }
        NUnit.Framework.Assert.AreEqual(bruteForceJointMarginal.GetAssignmentValue(bruteForceMarginalAssignment) / sum, jointMarginal.GetAssignmentValue(assignment), 1.0e-3);
    }
}

/// <summary>
/// Asserts that <paramref name="jointMarginal"/> puts all of its mass on the assignment formed by the
/// observed values of the variables of <paramref name="f"/>.
/// </summary>
/// <param name="model">Model supplying the observed value of each variable of <paramref name="f"/>.</param>
/// <param name="jointMarginal">Joint marginal reported for <paramref name="f"/>.</param>
/// <param name="f">Factor whose variables are expected to all be observed.</param>
private void CheckJointMarginalIsObservedPointMass(GraphicalModel model, TableFactor jointMarginal, GraphicalModel.Factor f)
{
    int[] fixedAssignment = new int[f.neigborIndices.Length];
    for (int i = 0; i < fixedAssignment.Length; i++)
    {
        fixedAssignment[i] = System.Convert.ToInt32(model.GetVariableMetaDataByReference(f.neigborIndices[i])[CliqueTree.VariableObservedValue]);
    }

    foreach (int[] assn in jointMarginal)
    {
        if (Arrays.Equals(assn, fixedAssignment))
        {
            NUnit.Framework.Assert.AreEqual(1.0, jointMarginal.GetAssignmentValue(assn), 1.0e-7);
            continue;
        }
        if (jointMarginal.GetAssignmentValue(assn) != 0)
        {
            foreach (int[] assignment in jointMarginal)
            {
                System.Console.Error.WriteLine(Arrays.ToString(assignment) + ": " + jointMarginal.GetAssignmentValue(assignment));
            }
        }
        NUnit.Framework.Assert.AreEqual(0.0, jointMarginal.GetAssignmentValue(assn), 1.0e-7);
    }
}