/// <summary>
        /// Evidence message for EP.
        /// </summary>
        /// <param name="isPositive">Incoming message from 'isPositive'.</param>
        /// <param name="x">Incoming message from 'x'.</param>
        /// <returns>Logarithm of the factor's average value across the given argument distributions</returns>
        /// <remarks><para>
        /// The formula for the result is <c>log(sum_(isPositive,x) p(isPositive,x) factor(isPositive,x))</c>,
        /// i.e. <c>log(p(isPositive=T) p(x &gt; 0) + p(isPositive=F) p(x &lt;= 0))</c>.
        /// </para><para>
        /// It is evaluated by forming the EP message from 'x' to 'isPositive' and taking its log-average
        /// under the incoming 'isPositive' message, so the point-mass/uniform special cases are handled
        /// by <c>IsPositiveAverageConditional</c> and <c>GetLogAverageOf</c> rather than duplicated here.
        /// </para></remarks>
        public static double LogAverageFactor(Bernoulli isPositive, Gaussian x)
        {
            Bernoulli to_isPositive = IsPositiveAverageConditional(x);
            return isPositive.GetLogAverageOf(to_isPositive);
        }
        /// <summary>
        /// VMP message to 'b'
        /// </summary>
        /// <param name="isGreaterThan">Incoming message from 'isGreaterThan'. Must be a proper distribution.  If uniform, the result will be uniform.</param>
        /// <param name="a">Incoming message from 'a'.</param>
        /// <param name="result">Modified to contain the outgoing message</param>
        /// <returns><paramref name="result"/></returns>
        /// <remarks><para>
        /// The outgoing message is the exponential of the average log-factor value, where the average is over all arguments except 'b'.
        /// Because the factor is deterministic, 'isGreaterThan' is integrated out before taking the logarithm.
        /// The formula is <c>exp(sum_(a) p(a) log(sum_isGreaterThan p(isGreaterThan) factor(isGreaterThan,a,b)))</c>.
        /// </para><para>
        /// For each state j of 'b', states i of 'a' with i &lt;= j contribute <c>q(a=i) log p(isGreaterThan=false)</c>
        /// and states with i &gt; j contribute <c>q(a=i) log p(isGreaterThan=true)</c>.
        /// </para></remarks>
        /// <exception cref="ImproperMessageException"><paramref name="isGreaterThan"/> is not a proper distribution</exception>
        public static Discrete BAverageLogarithm([SkipIfUniform] Bernoulli isGreaterThan, Discrete a, Discrete result)
        {
            // Point masses are handled by the dedicated overloads.
            if (a.IsPointMass)
            {
                return BAverageLogarithm(isGreaterThan, a.Point, result);
            }
            if (isGreaterThan.IsPointMass)
            {
                return BAverageLogarithm(isGreaterThan.Point, a, result);
            }

            // f(a,b) = p(c=1) I(a > b) + p(c=0) I(a <= b)
            // message to b = exp(sum_a q(a) log f(a,b))
            Vector probs = result.GetWorkspace();
            double logFalse = isGreaterThan.GetLogProbFalse();
            double logTrue = isGreaterThan.GetLogProbTrue();
            int aDim = a.Dimension;

            for (int bIndex = 0; bIndex < probs.Count; bIndex++)
            {
                double logMessage = 0.0;
                for (int aIndex = 0; aIndex < aDim; aIndex++)
                {
                    // a <= b  => factor value p(c=0); a > b => factor value p(c=1)
                    double logFactor = (aIndex <= bIndex) ? logFalse : logTrue;
                    logMessage += logFactor * a[aIndex];
                }
                probs[bIndex] = Math.Exp(logMessage);
            }

            result.SetProbs(probs);
            return result;
        }
        /// <summary>
        /// VMP message to 'a'
        /// </summary>
        /// <param name="isGreaterThan">Incoming message from 'isGreaterThan'. Must be a proper distribution.  If uniform, the result will be uniform.</param>
        /// <param name="b">Incoming message from 'b'.</param>
        /// <param name="result">Modified to contain the outgoing message</param>
        /// <returns><paramref name="result"/></returns>
        /// <remarks><para>
        /// The outgoing message is the exponential of the average log-factor value, where the average is over all arguments except 'a'.
        /// Because the factor is deterministic, 'isGreaterThan' is integrated out before taking the logarithm.
        /// The formula is <c>exp(sum_(b) p(b) log(sum_isGreaterThan p(isGreaterThan) factor(isGreaterThan,a,b)))</c>.
        /// </para><para>
        /// For each state i of 'a', states j of 'b' with j &lt; i contribute <c>q(b=j) log p(isGreaterThan=true)</c>
        /// and states with j &gt;= i contribute <c>q(b=j) log p(isGreaterThan=false)</c>.
        /// </para></remarks>
        /// <exception cref="ImproperMessageException"><paramref name="isGreaterThan"/> is not a proper distribution</exception>
        public static Discrete AAverageLogarithm([SkipIfUniform] Bernoulli isGreaterThan, Discrete b, Discrete result)
        {
            // Point masses are handled by the dedicated overloads.
            if (b.IsPointMass)
            {
                return AAverageLogarithm(isGreaterThan, b.Point, result);
            }
            if (isGreaterThan.IsPointMass)
            {
                return AAverageLogarithm(isGreaterThan.Point, b, result);
            }

            // f(a,b) = p(c=1) I(a > b) + p(c=0) I(a <= b)
            // message to a = exp(sum_b q(b) log f(a,b))
            Vector probs = result.GetWorkspace();
            double logTrue = isGreaterThan.GetLogProbTrue();
            double logFalse = isGreaterThan.GetLogProbFalse();
            int bDim = b.Dimension;

            for (int aIndex = 0; aIndex < probs.Count; aIndex++)
            {
                double logMessage = 0.0;
                for (int bIndex = 0; bIndex < bDim; bIndex++)
                {
                    // b < a  => factor value p(c=1); b >= a => factor value p(c=0)
                    double logFactor = (bIndex < aIndex) ? logTrue : logFalse;
                    logMessage += logFactor * b[bIndex];
                }
                probs[aIndex] = Math.Exp(logMessage);
            }

            result.SetProbs(probs);
            return result;
        }
 // result.LogOdds = [log p(b=true), log p(b=false)]
 /// <summary>
 /// EP message to 'cases'.
 /// </summary>
 /// <param name="b">Incoming message from 'b'.</param>
 /// <param name="result">Modified to contain the outgoing message. Must have exactly two elements.</param>
 /// <returns><paramref name="result"/></returns>
 /// <remarks><para>
 /// The outgoing message is the integral of the factor times incoming messages, over all arguments except 'cases'.
 /// The formula is <c>int f(cases,x) q(x) dx</c> where <c>x = (b)</c>.
 /// Element 0 receives LogOdds = log p(b=true); element 1 receives LogOdds = log p(b=false).
 /// </para></remarks>
 /// <exception cref="ArgumentException"><paramref name="result"/> does not have exactly two elements.</exception>
 public static BernoulliList CasesAverageConditional <BernoulliList>(Bernoulli b, BernoulliList result)
     where BernoulliList : IList <Bernoulli>
 {
     if (result.Count == 2)
     {
         double logProbTrue = b.GetLogProbTrue();
         double logProbFalse = b.GetLogProbFalse();
         result[0] = Bernoulli.FromLogOdds(logProbTrue);
         result[1] = Bernoulli.FromLogOdds(logProbFalse);
         return result;
     }
     throw new ArgumentException("result.Count != 2");
 }
		/// <summary>
		/// Evidence message for EP.
		/// </summary>
		/// <param name="isPositive">Incoming message from 'isPositive'.</param>
		/// <param name="x">Incoming message from 'x'.</param>
		/// <returns>Logarithm of the factor's average value across the given argument distributions</returns>
		/// <remarks><para>
		/// The formula for the result is <c>log(sum_(isPositive,x) p(isPositive,x) factor(isPositive,x))</c>,
		/// i.e. <c>log(p(isPositive=T) p(x &gt; 0) + p(isPositive=F) p(x &lt;= 0))</c>.
		/// </para><para>
		/// It is evaluated by forming the EP message from 'x' to 'isPositive' and taking its log-average
		/// under the incoming 'isPositive' message, so the point-mass/uniform special cases are handled
		/// by <c>IsPositiveAverageConditional</c> and <c>GetLogAverageOf</c> rather than duplicated here.
		/// </para></remarks>
		public static double LogAverageFactor(Bernoulli isPositive, Gaussian x)
		{
			Bernoulli to_isPositive = IsPositiveAverageConditional(x);
			return isPositive.GetLogAverageOf(to_isPositive);
		}
 /// <summary>
 /// Evidence message for EP
 /// </summary>
 /// <param name="case0">Incoming message from 'case0'.</param>
 /// <param name="case1">Incoming message from 'case1'.</param>
 /// <param name="b">Incoming message from 'b'.</param>
 /// <returns>Logarithm of the factor's contribution to the EP model evidence</returns>
 /// <remarks><para>
 /// The formula for the result is <c>log(sum_(case0,case1,b) p(case0,case1,b) factor(b,case0,case1))</c>.
 /// Adding up these values across all factors and variables gives the log-evidence estimate for EP.
 /// </para><para>
 /// The mixture is evaluated as <c>lo + log(1 + p(b=hi) (exp(hi - lo) - 1))</c>, where hi/lo are the larger/smaller
 /// of the two case log-odds, so that equal case values are returned without rounding error.
 /// </para></remarks>
 public static double LogEvidenceRatio(Bernoulli case0, Bernoulli case1, Bernoulli b)
 {
     // result = log (p(data|b=true) p(b=true) + p(data|b=false) p(b=false))
     // where case0.LogOdds = log p(data|b=true)
     //       case1.LogOdds = log p(data|b=false)
     if (b.IsPointMass)
     {
         return (b.Point ? case0.LogOdds : case1.LogOdds);
     }

     double logIfTrue = case0.LogOdds;
     double logIfFalse = case1.LogOdds;

     // Equal case values are the common situation; since p(b=true) + p(b=false) = 1 the answer is exactly
     // that value. Returning it directly also avoids inf - inf when both are infinite.
     if (logIfTrue == logIfFalse)
     {
         return logIfTrue;
     }

     if (logIfTrue > logIfFalse)
     {
         // Factor out the smaller term p(data|b=false).
         if (double.IsNegativeInfinity(logIfFalse))
         {
             return logIfTrue + b.GetLogProbTrue();
         }
         return logIfFalse + MMath.Log1Plus(b.GetProbTrue() * MMath.ExpMinus1(logIfTrue - logIfFalse));
     }
     else
     {
         // Factor out the smaller term p(data|b=true). NaN inputs also land here and propagate.
         if (double.IsNegativeInfinity(logIfTrue))
         {
             return logIfFalse + b.GetLogProbFalse();
         }
         return logIfTrue + MMath.Log1Plus(b.GetProbFalse() * MMath.ExpMinus1(logIfFalse - logIfTrue));
     }
 }
// Example #7
// 0
        /// <summary>EP message to <c>logOdds</c>.</summary>
        /// <param name="sample">Incoming message from <c>sample</c>.</param>
        /// <param name="logOdds">Incoming message from <c>logOdds</c>. Must be a proper distribution. If uniform, the result will be uniform.</param>
        /// <returns>The outgoing EP message to the <c>logOdds</c> argument.</returns>
        /// <remarks>
        ///   <para>The outgoing message is a distribution matching the moments of <c>logOdds</c> as the random arguments are varied. The formula is <c>proj[p(logOdds) sum_(sample) p(sample) factor(sample,logOdds)]/p(logOdds)</c>.</para>
        ///   <para>The two conditional posteriors (sample = true / false) are mixed with weights
        ///   <c>p(sample) * Z(sample)</c>, projected with <c>SetToSum</c>, and the incoming message is then divided out.</para>
        /// </remarks>
        /// <exception cref="ImproperMessageException">
        ///   <paramref name="logOdds" /> is not a proper distribution.</exception>
        public static Gaussian LogOddsAverageConditional(Bernoulli sample, [SkipIfUniform] Gaussian logOdds)
        {
            // Conditional message and log-weight for each possible value of 'sample'.
            Gaussian messageIfTrue = LogOddsAverageConditional(true, logOdds);
            double logScaleTrue = LogAverageFactor(true, logOdds) + sample.GetLogProbTrue();
            Gaussian messageIfFalse = LogOddsAverageConditional(false, logOdds);
            double logScaleFalse = LogAverageFactor(false, logOdds) + sample.GetLogProbFalse();

            // Shift so the larger weight becomes exp(0) = 1, keeping Math.Exp in range.
            double offset = Math.Max(logScaleTrue, logScaleFalse);
            Gaussian posterior = new Gaussian();
            double weightTrue = Math.Exp(logScaleTrue - offset);
            double weightFalse = Math.Exp(logScaleFalse - offset);

            posterior.SetToSum(weightTrue, messageIfTrue * logOdds, weightFalse, messageIfFalse * logOdds);
            posterior.SetToRatio(posterior, logOdds, ForceProper);
            return posterior;
        }
 /// <summary>
 /// Evidence message for EP.
 /// </summary>
 /// <param name="cases">Incoming message from 'cases'.</param>
 /// <param name="b">Incoming message from 'b'.</param>
 /// <returns>Logarithm of the factor's contribution to the EP model evidence.</returns>
 /// <remarks><para>
 /// <c>cases[0].LogOdds</c> holds log p(data|b=true) and <c>cases[1].LogOdds</c> holds log p(data|b=false);
 /// the result is <c>log(p(data|b=true) p(b=true) + p(data|b=false) p(b=false))</c>.
 /// </para></remarks>
 public static double LogEvidenceRatio(IList <Bernoulli> cases, Bernoulli b)
 {
     // log (p(data|b=true) p(b=true) + p(data|b=false) p(b=false))
     //   = log p(data|b=false) + log (1 + p(b=true) (p(data|b=true)/p(data|b=false) - 1))
     // and symmetrically with the roles of true/false swapped. The smaller case is always the
     // one factored out, so exp(...) - 1 never overflows.
     if (b.IsPointMass)
     {
         return b.Point ? cases[0].LogOdds : cases[1].LogOdds;
     }

     double logIfTrue = cases[0].LogOdds;
     double logIfFalse = cases[1].LogOdds;

     // The common case is logIfTrue == logIfFalse: factoring out one of them makes the correction
     // term exactly Log1Plus(0) = 0, so no rounding error is introduced.
     if (!(logIfTrue >= logIfFalse))
     {
         // p(data|b=false) dominates (or an input is NaN): factor out p(data|b=true).
         return double.IsNegativeInfinity(logIfTrue)
             ? logIfFalse + b.GetLogProbFalse()
             : logIfTrue + MMath.Log1Plus(b.GetProbFalse() * MMath.ExpMinus1(logIfFalse - logIfTrue));
     }

     // p(data|b=true) dominates or ties: factor out p(data|b=false).
     return double.IsNegativeInfinity(logIfFalse)
         ? logIfTrue + b.GetLogProbTrue()
         : logIfFalse + MMath.Log1Plus(b.GetProbTrue() * MMath.ExpMinus1(logIfTrue - logIfFalse));
 }
// Example #9
// 0
        /// <summary>
        /// Tests EP and BP gate enter ops for Bernoulli random variable for correctness given message parameters.
        /// </summary>
        /// <param name="valueProbTrue">Probability of being true for the variable entering the gate.</param>
        /// <param name="enterOneProbTrue">Probability of being true for the variable approximation inside the gate when the selector is true.</param>
        /// <param name="selectorProbTrue">Probability of being true for the selector variable.</param>
        /// <remarks>
        /// Every operator variant (partial / one / partial-two / full enter, with Bernoulli, Discrete or split-case
        /// selectors) is fed an equivalent configuration, and each EP/BP pair must reproduce the same closed-form
        /// probability of 'true' to within 1e-4.
        /// </remarks>
        private void DoBernoulliEnterTest(double valueProbTrue, double enterOneProbTrue, double selectorProbTrue)
        {
            // Incoming messages. The "inverse" selectors swap the two cases, so they are paired with
            // index 1 below and must give the same answer as the originals paired with index 0.
            var value = new Bernoulli(valueProbTrue);
            var enterOne = new Bernoulli(enterOneProbTrue);
            var selector = new Bernoulli(selectorProbTrue);
            var selectorInverse = new Bernoulli(selector.GetProbFalse());
            var discreteSelector = new Discrete(selector.GetProbTrue(), selector.GetProbFalse());
            var discreteSelectorInverse = new Discrete(selector.GetProbFalse(), selector.GetProbTrue());
            var cases = new[] { Bernoulli.FromLogOdds(selector.GetLogProbTrue()), Bernoulli.FromLogOdds(selector.GetLogProbFalse()) };

            // Closed-form expected message: the p(selector=false) mass is added equally to both states,
            // while the p(selector=true) mass is reweighted by enterOne and the normalizer ratio.
            double evidenceScale = Math.Exp(enterOne.GetLogNormalizer() + value.GetLogNormalizer() - (value * enterOne).GetLogNormalizer());
            double unnormalisedTrue = selector.GetProbFalse() + (selector.GetProbTrue() * enterOne.GetProbTrue() * evidenceScale);
            double unnormalisedFalse = selector.GetProbFalse() + (selector.GetProbTrue() * enterOne.GetProbFalse() * evidenceScale);
            double expectedProbTrue = unnormalisedTrue / (unnormalisedTrue + unnormalisedFalse);

            // Asserts the EP result first, then the BP result, against the expected probability.
            Action<Bernoulli, Bernoulli> assertBoth = (epMessage, bpMessage) =>
            {
                Assert.Equal(expectedProbTrue, epMessage.GetProbTrue(), 1e-4);
                Assert.Equal(expectedProbTrue, bpMessage.GetProbTrue(), 1e-4);
            };

            // Enter partial (bernoulli selector, first case)
            assertBoth(
                GateEnterPartialOp<bool>.ValueAverageConditional(new[] { enterOne }, selector, value, new[] { 0 }, new Bernoulli()),
                BeliefPropagationGateEnterPartialOp.ValueAverageConditional(new[] { enterOne }, selector, value, new[] { 0 }, new Bernoulli()));

            // Enter partial (bernoulli selector, second case)
            assertBoth(
                GateEnterPartialOp<bool>.ValueAverageConditional(new[] { enterOne }, selectorInverse, value, new[] { 1 }, new Bernoulli()),
                BeliefPropagationGateEnterPartialOp.ValueAverageConditional(new[] { enterOne }, selectorInverse, value, new[] { 1 }, new Bernoulli()));

            // Enter partial (bernoulli selector, both cases)
            assertBoth(
                GateEnterPartialOp<bool>.ValueAverageConditional(new[] { enterOne, Bernoulli.Uniform() }, selector, value, new[] { 0, 1 }, new Bernoulli()),
                BeliefPropagationGateEnterPartialOp.ValueAverageConditional(new[] { enterOne, Bernoulli.Uniform() }, selector, value, new[] { 0, 1 }, new Bernoulli()));

            // Enter partial (discrete selector, first case)
            assertBoth(
                GateEnterPartialOp<bool>.ValueAverageConditional(new[] { enterOne }, discreteSelector, value, new[] { 0 }, new Bernoulli()),
                BeliefPropagationGateEnterPartialOp.ValueAverageConditional(new[] { enterOne }, discreteSelector, value, new[] { 0 }, new Bernoulli()));

            // Enter partial (discrete selector, second case)
            assertBoth(
                GateEnterPartialOp<bool>.ValueAverageConditional(new[] { enterOne }, discreteSelectorInverse, value, new[] { 1 }, new Bernoulli()),
                BeliefPropagationGateEnterPartialOp.ValueAverageConditional(new[] { enterOne }, discreteSelectorInverse, value, new[] { 1 }, new Bernoulli()));

            // Enter partial (discrete selector, both cases)
            assertBoth(
                GateEnterPartialOp<bool>.ValueAverageConditional(new[] { enterOne, Bernoulli.Uniform() }, discreteSelector, value, new[] { 0, 1 }, new Bernoulli()),
                BeliefPropagationGateEnterPartialOp.ValueAverageConditional(new[] { enterOne, Bernoulli.Uniform() }, discreteSelector, value, new[] { 0, 1 }, new Bernoulli()));

            // Enter one (discrete selector, first case)
            assertBoth(
                GateEnterOneOp<bool>.ValueAverageConditional(enterOne, discreteSelector, value, 0, new Bernoulli()),
                BeliefPropagationGateEnterOneOp.ValueAverageConditional(enterOne, discreteSelector, value, 0, new Bernoulli()));

            // Enter one (discrete selector, second case)
            assertBoth(
                GateEnterOneOp<bool>.ValueAverageConditional(enterOne, discreteSelectorInverse, value, 1, new Bernoulli()),
                BeliefPropagationGateEnterOneOp.ValueAverageConditional(enterOne, discreteSelectorInverse, value, 1, new Bernoulli()));

            // Enter partial two  (first case)
            assertBoth(
                GateEnterPartialTwoOp.ValueAverageConditional(new[] { enterOne }, cases[0], cases[1], value, new[] { 0 }, new Bernoulli()),
                BeliefPropagationGateEnterPartialTwoOp.ValueAverageConditional(new[] { enterOne }, cases[0], cases[1], value, new[] { 0 }, new Bernoulli()));

            // Enter partial two  (second case)
            assertBoth(
                GateEnterPartialTwoOp.ValueAverageConditional(new[] { enterOne }, cases[1], cases[0], value, new[] { 1 }, new Bernoulli()),
                BeliefPropagationGateEnterPartialTwoOp.ValueAverageConditional(new[] { enterOne }, cases[1], cases[0], value, new[] { 1 }, new Bernoulli()));

            // Enter partial two (both cases)
            assertBoth(
                GateEnterPartialTwoOp.ValueAverageConditional(new[] { enterOne, Bernoulli.Uniform() }, cases[0], cases[1], value, new[] { 0, 1 }, new Bernoulli()),
                BeliefPropagationGateEnterPartialTwoOp.ValueAverageConditional(new[] { enterOne, Bernoulli.Uniform() }, cases[0], cases[1], value, new[] { 0, 1 }, new Bernoulli()));

            // Enter (discrete selector)
            assertBoth(
                GateEnterOp<bool>.ValueAverageConditional(new[] { enterOne, Bernoulli.Uniform() }, discreteSelector, value, new Bernoulli()),
                BeliefPropagationGateEnterOp.ValueAverageConditional(new[] { enterOne, Bernoulli.Uniform() }, discreteSelector, value, new Bernoulli()));
        }
// Example #10
// 0
        /// <summary>EP message to <c>value</c>.</summary>
        /// <param name="enterPartial">Incoming message from <c>enterPartial</c>. Must be a proper distribution. If any element is uniform, the result will be uniform.</param>
        /// <param name="selector">Incoming message from <c>selector</c>. Must be a proper distribution. If uniform, the result will be uniform.</param>
        /// <param name="value">Incoming message from <c>value</c>.</param>
        /// <param name="indices">Constant value for <c>indices</c>. Index 0 selects the selector's 'true' case, any other index its 'false' case.</param>
        /// <param name="result">Modified to contain the outgoing message.</param>
        /// <returns>
        ///   <paramref name="result" />
        /// </returns>
        /// <remarks>
        ///   <para>The outgoing message is a distribution matching the moments of <c>value</c> as the random arguments are varied. The formula is <c>proj[p(value) sum_(enterPartial,selector) p(enterPartial,selector) factor(enterPartial,selector,value,indices)]/p(value)</c>.</para>
        ///   <para>Each supplied case i contributes <c>enterPartial[i]</c> with log-weight
        ///   <c>log p(selector = case_i) - enterPartial[i].GetLogAverageOf(value)</c>. When only one case is supplied,
        ///   the remaining selector mass is assigned to a uniform distribution.</para>
        /// </remarks>
        /// <exception cref="ImproperMessageException">
        ///   <paramref name="enterPartial" /> is not a proper distribution.</exception>
        /// <exception cref="ImproperMessageException">
        ///   <paramref name="selector" /> is not a proper distribution.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="enterPartial"/> or <paramref name="indices"/> is null.</exception>
        /// <exception cref="ArgumentException">The lengths disagree, are zero, or exceed 2.</exception>
        /// <exception cref="AllZeroException">All cases have zero weight.</exception>
        /// <typeparam name="TDist">The type of the distribution over the variable entering the gate.</typeparam>
        public static TDist ValueAverageConditional<TDist>(
            [SkipIfUniform] IList<TDist> enterPartial, [SkipIfUniform] Bernoulli selector, TDist value, int[] indices, TDist result)
            where TDist : ICloneable, SettableToUniform, SettableTo<TDist>, SettableToWeightedSum<TDist>, CanGetLogAverageOf<TDist>, CanGetLogNormalizer
        {
            if (enterPartial == null)
            {
                throw new ArgumentNullException("enterPartial");
            }

            if (indices == null)
            {
                throw new ArgumentNullException("indices");
            }

            if (indices.Length != enterPartial.Count)
            {
                throw new ArgumentException("indices.Length != enterPartial.Count");
            }

            if (2 < enterPartial.Count)
            {
                throw new ArgumentException("enterPartial.Count should be 2 or 1");
            }

            if (indices.Length == 0)
            {
                throw new ArgumentException("indices.Length == 0");
            }

            // TODO: use pre-allocated buffers
            // logProbSum:   log of the total selector probability covered by 'indices'.
            // logWeightSum: log of the total (unnormalised) weight already accumulated into 'result'.
            double logProbSum = (indices[0] == 0) ? selector.GetLogProbTrue() : selector.GetLogProbFalse();
            double logWeightSum = logProbSum;

            // A zero-probability first case contributes nothing; 'result' is then only reached with weight 0.
            if (!double.IsNegativeInfinity(logProbSum))
            {
                logWeightSum -= enterPartial[0].GetLogAverageOf(value);
                result.SetTo(enterPartial[0]);
            }

            for (int i = 1; i < indices.Length; i++)
            {
                double logProb = (indices[i] == 0) ? selector.GetLogProbTrue() : selector.GetLogProbFalse();

                // Probabilities of distinct cases add, so accumulate them in log space.
                logProbSum = MMath.LogSumExp(logProbSum, logProb);
                double shift = Math.Max(logWeightSum, logProb);

                // Avoid (-Infinity) - (-Infinity)
                if (double.IsNegativeInfinity(shift))
                {
                    if (i == 1)
                    {
                        throw new AllZeroException();
                    }

                    // Do nothing
                }
                else
                {
                    double logWeightShifted = logProb - shift;
                    if (!double.IsNegativeInfinity(logWeightShifted))
                    {
                        logWeightShifted -= enterPartial[i].GetLogAverageOf(value);
                        result.SetToSum(Math.Exp(logWeightSum - shift), result, Math.Exp(logWeightShifted), enterPartial[i]);
                        logWeightSum = MMath.LogSumExp(logWeightSum, logWeightShifted + shift);
                    }
                }
            }

            // Only one case supplied: the complementary selector mass goes to a uniform distribution.
            if (indices.Length < 2)
            {
                double logProb = MMath.Log1MinusExp(logProbSum);
                double shift = Math.Max(logWeightSum, logProb);
                if (double.IsNegativeInfinity(shift))
                {
                    throw new AllZeroException();
                }

                var uniform = (TDist)result.Clone();
                uniform.SetToUniform();
                double logWeight = logProb + uniform.GetLogNormalizer();
                result.SetToSum(Math.Exp(logWeightSum - shift), result, Math.Exp(logWeight - shift), uniform);
            }

            return result;
        }
 /// <summary>
 /// EP message to 'case0'
 /// </summary>
 /// <param name="b">Incoming message from 'b'.</param>
 /// <returns>The outgoing EP message to the 'case0' argument</returns>
 /// <remarks><para>
 /// The outgoing message is a distribution matching the moments of 'case0' as the random arguments are varied.
 /// The formula is <c>proj[p(case0) sum_(b) p(b) factor(b,case0,case1)]/p(case0)</c>.
 /// The returned Bernoulli stores log p(b=true) in its LogOdds.
 /// </para></remarks>
 public static Bernoulli Case0AverageConditional(Bernoulli b)
 {
     double logProbTrue = b.GetLogProbTrue();
     Bernoulli message = Bernoulli.FromLogOdds(logProbTrue);
     return message;
 }
// Example #12
// 0
		/// <summary>
		/// EP message to 'logOdds'.
		/// </summary>
		/// <param name="sample">Incoming message from 'sample'.</param>
		/// <param name="logOdds">Incoming message from 'logOdds'. Must be a proper distribution.  If uniform, the result will be uniform.</param>
		/// <returns>The outgoing EP message to the 'logOdds' argument.</returns>
		/// <remarks><para>
		/// The outgoing message is the moment matched Gaussian approximation to the factor.
		/// The conditional posteriors for sample = true / false are mixed with weights p(sample) * Z(sample),
		/// combined with <c>SetToSum</c>, and the incoming 'logOdds' message is then divided out.
		/// </para></remarks>
		public static Gaussian LogOddsAverageConditional(Bernoulli sample, [SkipIfUniform] Gaussian logOdds)
		{
			Gaussian messageWhenTrue = LogOddsAverageConditional(true, logOdds);
			double logMassTrue = LogAverageFactor(true, logOdds) + sample.GetLogProbTrue();
			Gaussian messageWhenFalse = LogOddsAverageConditional(false, logOdds);
			double logMassFalse = LogAverageFactor(false, logOdds) + sample.GetLogProbFalse();

			// Normalise against the larger log-weight so Math.Exp stays in range.
			double reference = Math.Max(logMassTrue, logMassFalse);
			Gaussian mixture = new Gaussian();
			mixture.SetToSum(
				Math.Exp(logMassTrue - reference), messageWhenTrue * logOdds,
				Math.Exp(logMassFalse - reference), messageWhenFalse * logOdds);

			// Divide out the incoming message to turn the projected posterior into an outgoing message.
			mixture /= logOdds;
			return mixture;
		}
 /// <summary>
 /// Evidence message for EP.
 /// </summary>
 /// <param name="not">Constant value for 'not'.</param>
 /// <param name="b">Incoming message from 'b'.</param>
 /// <returns><c>log(int f(x) qnotf(x) dx)</c></returns>
 /// <remarks><para>
 /// The formula for the result is <c>log(int f(x) qnotf(x) dx)</c>
 /// where <c>x = (not,b)</c>. This evaluates to <c>log p(b = !not)</c>.
 /// </para></remarks>
 public static double LogAverageFactor(bool not, Bernoulli b)
 {
     if (not)
     {
         return b.GetLogProbFalse();
     }
     return b.GetLogProbTrue();
 }
		/// <summary>
		/// Evidence message for EP.
		/// </summary>
		/// <param name="not">Constant value for 'not'.</param>
		/// <param name="b">Incoming message from 'b'.</param>
		/// <returns><c>log(int f(x) qnotf(x) dx)</c></returns>
		/// <remarks><para>
		/// The formula for the result is <c>log(int f(x) qnotf(x) dx)</c>
		/// where <c>x = (not,b)</c>. This evaluates to <c>log p(b = !not)</c>.
		/// </para></remarks>
		public static double LogAverageFactor(bool not, Bernoulli b)
		{
			// Only the state of b opposite to 'not' contributes.
			bool requiredB = !not;
			return requiredB ? b.GetLogProbTrue() : b.GetLogProbFalse();
		}