/// <summary>
/// Checks that <c>factor.MaxOut(marginalize)</c> drops exactly the requested variable and that every
/// assignment of the resulting factor holds the maximum over the original assignments that collapse onto it.
/// </summary>
/// <remarks>
/// The check returns without asserting anything when <paramref name="marginalize"/> is not one of the
/// factor's neighbor indices, or when the factor has at most one neighbor.
/// </remarks>
/// <param name="factor">the factor to max a variable out of</param>
/// <param name="marginalize">the neighbor index to max out</param>
public virtual void TestMaxOut(TableFactor factor, int marginalize)
{
    // Nothing to verify unless the variable is actually part of this factor.
    if (System.Array.IndexOf(factor.neighborIndices, marginalize) < 0)
    {
        return;
    }
    if (factor.neighborIndices.Length <= 1)
    {
        return;
    }

    TableFactor maxedOut = factor.MaxOut(marginalize);

    // Exactly one neighbor must be gone, and it must be the one we asked for.
    NUnit.Framework.Assert.AreEqual(factor.neighborIndices.Length - 1, maxedOut.neighborIndices.Length);
    NUnit.Framework.Assert.IsTrue(System.Array.IndexOf(maxedOut.neighborIndices, marginalize) < 0);

    foreach (int[] assignment in factor)
    {
        double value = factor.GetAssignmentValue(assignment);
        // Fails only when the value is NaN: every other double is >= negative infinity.
        NUnit.Framework.Assert.IsTrue(value >= double.NegativeInfinity);
        // No original value may exceed the maxed-out value of its projected assignment.
        NUnit.Framework.Assert.IsTrue(value <= maxedOut.GetAssignmentValue(SubsetAssignment(assignment, factor, maxedOut)));
    }

    IDictionary<IList<int>, IList<int[]>> subsetToSuperset = SubsetToSupersetAssignments(factor, maxedOut);
    foreach (KeyValuePair<IList<int>, IList<int[]>> entry in subsetToSuperset)
    {
        // Recompute the expected maximum by brute force over the superset assignments.
        double max = double.NegativeInfinity;
        foreach (int[] supersetAssignment in entry.Value)
        {
            max = Math.Max(max, factor.GetAssignmentValue(supersetAssignment));
        }

        int[] subsetAssignment = new int[entry.Key.Count];
        entry.Key.CopyTo(subsetAssignment, 0);

        // NUnit order is (expected, actual, delta): compare the brute-force maximum with the
        // factor's value, using 1.0e-5 as the tolerance.
        NUnit.Framework.Assert.AreEqual(max, maxedOut.GetAssignmentValue(subsetAssignment), 1.0e-5);
    }
}
/// <summary>
/// Marginalizes out of a message every variable that is not listed as relevant. This is a key step in
/// message passing.
/// </summary>
/// <remarks>
/// Each neighbor index of <paramref name="message"/> that does not appear in <paramref name="relevant"/>
/// is removed with <c>SumOut</c> or <c>MaxOut</c>, depending on <paramref name="marginalize"/>. Any other
/// method value leaves the factor untouched.
/// </remarks>
/// <param name="message">the message to marginalize</param>
/// <param name="relevant">the variables that are relevant and must be kept</param>
/// <param name="marginalize">whether to use sum or max marginalization, for marginal or MAP inference</param>
/// <returns>the marginalized message</returns>
private static TableFactor MarginalizeMessage(TableFactor message, int[] relevant, CliqueTree.MarginalizationMethod marginalize)
{
    TableFactor marginalized = message;
    // Walk the neighbors of the original message; the running result shrinks as we go.
    foreach (int variable in message.neighborIndices)
    {
        bool isRelevant = false;
        for (int k = 0; k < relevant.Length && !isRelevant; k++)
        {
            isRelevant = relevant[k] == variable;
        }
        if (isRelevant)
        {
            continue;
        }

        if (marginalize == CliqueTree.MarginalizationMethod.Sum)
        {
            marginalized = marginalized.SumOut(variable);
        }
        else if (marginalize == CliqueTree.MarginalizationMethod.Max)
        {
            marginalized = marginalized.MaxOut(variable);
        }
    }
    return marginalized;
}