Beispiel #1
0
		/// <summary>
		/// Checks <c>TableFactor.MaxOut</c>: the result must drop exactly the marginalized variable, every original
		/// value must be bounded by the maxed-out value of its projection, and every maxed-out value must equal the
		/// maximum over the original assignments that project onto it.
		/// </summary>
		/// <param name="factor">the factor under test</param>
		/// <param name="marginalize">
		/// the variable index to max out; the check is skipped when it is not a neighbor of
		/// <paramref name="factor"/> or when the factor has at most one neighbor
		/// </param>
		public virtual void TestMaxOut(TableFactor factor, int marginalize)
		{
			// Only meaningful when the variable is actually attached to the factor.
			if (Array.IndexOf(factor.neighborIndices, marginalize) < 0)
			{
				return;
			}
			// Maxing out the only remaining variable is not exercised by this check.
			if (factor.neighborIndices.Length <= 1)
			{
				return;
			}
			TableFactor maxedOut = factor.MaxOut(marginalize);
			// Exactly the marginalized variable must disappear from the scope.
			NUnit.Framework.Assert.AreEqual(factor.neighborIndices.Length - 1, maxedOut.neighborIndices.Length);
			NUnit.Framework.Assert.IsTrue(Array.IndexOf(maxedOut.neighborIndices, marginalize) < 0);
			foreach (int[] assignment in factor)
			{
				// ">= -Infinity" is false only for NaN, so this asserts the stored value is a number.
				NUnit.Framework.Assert.IsTrue(factor.GetAssignmentValue(assignment) >= double.NegativeInfinity);
				NUnit.Framework.Assert.IsTrue(factor.GetAssignmentValue(assignment) <= maxedOut.GetAssignmentValue(SubsetAssignment(assignment, factor, maxedOut)));
			}
			IDictionary<IList<int>, IList<int[]>> subsetToSuperset = SubsetToSupersetAssignments(factor, maxedOut);
			foreach (KeyValuePair<IList<int>, IList<int[]>> entry in subsetToSuperset)
			{
				double max = double.NegativeInfinity;
				foreach (int[] supersetAssignment in entry.Value)
				{
					max = Math.Max(max, factor.GetAssignmentValue(supersetAssignment));
				}
				int[] subsetAssignment = new int[entry.Key.Count];
				entry.Key.CopyTo(subsetAssignment, 0);
				// NUnit's argument order is (expected, actual, delta): the tolerance must be the last argument.
				NUnit.Framework.Assert.AreEqual(max, maxedOut.GetAssignmentValue(subsetAssignment), 1.0e-5);
			}
		}
Beispiel #2
0
        /// <summary>This is a key step in message passing.</summary>
        /// <remarks>
        /// This is a key step in message passing. When we are calculating a message, we want to marginalize out all variables
        /// not relevant to the recipient of the message. This function does that.
        /// </remarks>
        /// <param name="message">the message to marginalize</param>
        /// <param name="relevant">the variables that are relevant; every neighbor of the message not listed here is marginalized out</param>
        /// <param name="marginalize">whether to use sum or max marginalization, for marginal or MAP inference</param>
        /// <returns>the marginalized message (the original <paramref name="message"/> if every neighbor is relevant)</returns>
        /// <exception cref="ArgumentOutOfRangeException">if <paramref name="marginalize"/> is neither Sum nor Max</exception>
        private static TableFactor MarginalizeMessage(TableFactor message, int[] relevant, CliqueTree.MarginalizationMethod marginalize)
        {
            TableFactor result = message;
            // Walk the original message's neighbors, not result's: result is reassigned inside the loop.
            foreach (int variable in message.neighborIndices)
            {
                if (Array.IndexOf(relevant, variable) >= 0)
                {
                    continue;
                }
                switch (marginalize)
                {
                    case CliqueTree.MarginalizationMethod.Sum:
                        result = result.SumOut(variable);
                        break;

                    case CliqueTree.MarginalizationMethod.Max:
                        result = result.MaxOut(variable);
                        break;

                    default:
                        // An unknown method would otherwise silently leave irrelevant variables in the message.
                        throw new ArgumentOutOfRangeException("marginalize", marginalize, "Unsupported marginalization method");
                }
            }
            return result;
        }