示例#1
0
		/// <summary>
		/// Checks the result of <c>factor.GetMaxedMarginals()</c> for a single variable against a
		/// brute-force maximum taken over every assignment the factor enumerates.
		/// </summary>
		/// <remarks>
		/// Returns without asserting anything when <paramref name="marginalizeTo"/> does not occur
		/// in <c>factor.neighborIndices</c>.
		/// </remarks>
		/// <param name="factor">Factor whose maxed marginals are being verified.</param>
		/// <param name="marginalizeTo">Variable index that is looked up in <c>factor.neighborIndices</c>.</param>
		public virtual void TestGetMaxedMarginals(TableFactor factor, int marginalizeTo)
		{
			// A single scan answers both "is the variable a neighbor?" and "at which position?",
			// so no separate membership set or follow-up assumption is needed.
			int indexOfVariable = System.Array.IndexOf(factor.neighborIndices, marginalizeTo);
			if (indexOfVariable < 0)
			{
				return;
			}
			double[] gold = new double[factor.GetDimensions()[indexOfVariable]];
			for (int i = 0; i < gold.Length; i++)
			{
				// Start below every real value so the first value seen always wins the max.
				gold[i] = double.NegativeInfinity;
			}
			foreach (int[] assignment in factor)
			{
				int value = assignment[indexOfVariable];
				gold[value] = Math.Max(gold[value], factor.GetAssignmentValue(assignment));
			}
			Normalize(gold);
			Assert.AssertArrayEquals(factor.GetMaxedMarginals()[indexOfVariable], 1.0e-5, gold);
		}
        /// <summary>
        /// Compares the MAP assignment returned by <paramref name="inference"/> with one found by
        /// brute force: every factor of <paramref name="model"/> is combined into one table via
        /// <c>Multiply</c>, variables carrying an observed value are passed to <c>Observe</c>, and
        /// the highest-valued remaining assignment is located by exhaustive enumeration.
        /// </summary>
        /// <remarks>
        /// Returns without asserting anything when an observed variable is encountered while the
        /// combined table has at most one variable left. When the two assignments differ, a
        /// diagnostic dump is written to standard error before the per-variable assertions run.
        /// </remarks>
        /// <param name="model">Model supplying the factors and the per-variable metadata.</param>
        /// <param name="weights">Weights used to turn each model factor into a table factor.</param>
        /// <param name="inference">Inference object whose <c>CalculateMAP()</c> result is checked.</param>
        public virtual void CheckMAPAgainstBruteForce(GraphicalModel model, ConcatVector weights, CliqueTree inference)
        {
            int[] map = inference.CalculateMAP();

            // The original stream pipeline had lost its mapping function (it called Map(null)),
            // so it could not produce any table factors and left `weights` unused.
            // NOTE(review): assumes a TableFactor(ConcatVector, GraphicalModel.Factor) constructor,
            // as in the upstream Java code this was ported from - confirm against TableFactor.
            ICollection<TableFactor> tableFactors = new HashSet<TableFactor>();
            foreach (GraphicalModel.Factor modelFactor in model.factors)
            {
                tableFactors.Add(new TableFactor(weights, modelFactor));
            }

            // this is the super slow but obviously correct way to get the full joint table
            TableFactor bruteForce = null;
            foreach (TableFactor factor in tableFactors)
            {
                bruteForce = (bruteForce == null) ? factor : bruteForce.Multiply(factor);
            }
            System.Diagnostics.Debug.Assert((bruteForce != null));

            // observe out all variables that have been registered
            TableFactor observed = bruteForce;
            foreach (int n in bruteForce.neighborIndices)
            {
                if (!model.GetVariableMetaDataByReference(n).Contains(CliqueTree.VariableObservedValue))
                {
                    continue;
                }
                int value = System.Convert.ToInt32(model.GetVariableMetaDataByReference(n)[CliqueTree.VariableObservedValue]);
                if (observed.neighborIndices.Length <= 1)
                {
                    // If we've observed everything, then just quit
                    return;
                }
                observed = observed.Observe(n, value);
            }
            bruteForce = observed;

            // The largest variable index mentioned by any factor fixes the size of the
            // variable-indexed arrays below.
            int largestVariableNum = 0;
            foreach (GraphicalModel.Factor f in model.factors)
            {
                foreach (int i in f.neigborIndices)
                {
                    if (i > largestVariableNum)
                    {
                        largestVariableNum = i;
                    }
                }
            }

            // this is presented in true order, where 0 corresponds to var 0
            int[] mapValueAssignment = new int[largestVariableNum + 1];
            // this is kept in the order that the factor presents to us
            int[] highestValueAssignment = new int[bruteForce.neighborIndices.Length];
            foreach (int[] assignment in bruteForce)
            {
                if (bruteForce.GetAssignmentValue(assignment) > bruteForce.GetAssignmentValue(highestValueAssignment))
                {
                    highestValueAssignment = assignment;
                    // Translate from factor-local positions to global variable indices.
                    for (int i = 0; i < assignment.Length; i++)
                    {
                        mapValueAssignment[bruteForce.neighborIndices[i]] = assignment[i];
                    }
                }
            }

            // Observed variables take their observed value; forcedAssignments records them
            // separately and is only used in the diagnostic output below.
            int[] forcedAssignments = new int[largestVariableNum + 1];
            for (int v = 0; v < mapValueAssignment.Length; v++)
            {
                if (model.GetVariableMetaDataByReference(v).Contains(CliqueTree.VariableObservedValue))
                {
                    mapValueAssignment[v] = System.Convert.ToInt32(model.GetVariableMetaDataByReference(v)[CliqueTree.VariableObservedValue]);
                    forcedAssignments[v] = mapValueAssignment[v];
                }
            }

            if (!Arrays.Equals(mapValueAssignment, map))
            {
                System.Console.Error.WriteLine("---");
                System.Console.Error.WriteLine("Relevant variables: " + Arrays.ToString(bruteForce.neighborIndices));
                System.Console.Error.WriteLine("Var Sizes: " + Arrays.ToString(bruteForce.GetDimensions()));
                System.Console.Error.WriteLine("MAP: " + Arrays.ToString(map));
                System.Console.Error.WriteLine("Brute force map: " + Arrays.ToString(mapValueAssignment));
                System.Console.Error.WriteLine("Forced assignments: " + Arrays.ToString(forcedAssignments));
            }

            foreach (int variable in bruteForce.neighborIndices)
            {
                // Only check defined variables
                NUnit.Framework.Assert.AreEqual(mapValueAssignment[variable], map[variable]);
            }
        }