/// <summary>
        /// Update the buffer 'falseMsg'
        /// </summary>
        /// <param name="logistic">Incoming message from 'logistic'. Must be a proper distribution.  If uniform, the result will be uniform.</param>
        /// <param name="x">Incoming message from 'x'. Must be a proper distribution.  If uniform, the result will be uniform.</param>
        /// <param name="falseMsg">Buffer 'falseMsg'.</param>
        /// <returns>New value of buffer 'falseMsg'</returns>
        /// <remarks><para>
        /// 'falseMsg' is a Gaussian approximation to the factor sigma(-x).  With tc1 = logistic.TrueCount - 1 and
        /// fc1 = logistic.FalseCount - 1, the combined factor is sigma(-x)^(tc1+fc1) times a shifted Gaussian in x.
        /// If tc1+fc1 is zero, the buffer becomes uniform.
        /// If tc1+fc1 is negative, a power EP step uses 1/sigma(-x) as the base factor, and the posterior
        /// moments are computed in closed form.
        /// Otherwise, a power EP step delegates to <c>BernoulliFromLogOddsOp.LogOddsAverageConditional</c>.
        /// That step is damped so that a single update cannot flip the sign of the prior's MeanTimesPrecision.
        /// </para></remarks>
        /// <exception cref="ImproperMessageException"><paramref name="logistic"/> is not a proper distribution</exception>
        /// <exception cref="ImproperMessageException"><paramref name="x"/> is not a proper distribution</exception>
        public static Gaussian FalseMsg([SkipIfUniform] Beta logistic, [Proper] Gaussian x, Gaussian falseMsg)
        {
            // falseMsg approximates sigma(-x)
            // logistic(sigma(x)) N(x;m,v)
            // = sigma(x)^(a-1) sigma(-x)^(b-1) N(x;m,v)
            // = e^((a-1)x) sigma(-x)^(a+b-2) N(x;m,v)
            // = sigma(-x)^(a+b-2) N(x;m+(a-1)v,v) exp((a-1)m + (a-1)^2 v/2)
            // = sigma(-x) (prior)
            // where prior = sigma(-x)^(a+b-3) N(x;m+(a-1)v,v)
            double tc1 = logistic.TrueCount - 1;
            double fc1 = logistic.FalseCount - 1;
            double m, v;

            x.GetMeanAndVariance(out m, out v);
            if (tc1 + fc1 == 0)
            {
                // sigma(-x)^0 == 1: the factor carries no information about x.
                falseMsg.SetToUniform();
                return falseMsg;
            }
            else if (tc1 + fc1 < 0)
            {
                // power EP update, using 1/sigma(-x) as the factor
                Gaussian prior = new Gaussian(m + tc1 * v, v) * (falseMsg ^ (tc1 + fc1 + 1));
                double mprior, vprior;
                prior.GetMeanAndVariance(out mprior, out vprior);
                // posterior moments can be computed exactly:
                // (1 + e^x) N(x; mprior, vprior) is a two-component Gaussian mixture whose
                // second component has mean mprior + vprior and relative weight exp(mprior + vprior/2).
                double w = MMath.Logistic(mprior + 0.5 * vprior);
                Gaussian post = new Gaussian(mprior + w * vprior, vprior * (1 + w * (1 - w) * vprior));
                return prior / post;
            }
            else
            {
                // power EP update
                Gaussian prior = new Gaussian(m + tc1 * v, v) * (falseMsg ^ (tc1 + fc1 - 1));
                Gaussian newMsg = BernoulliFromLogOddsOp.LogOddsAverageConditional(false, prior);

                // adaptive damping scheme
                Gaussian ratio = newMsg / falseMsg;
                if ((ratio.MeanTimesPrecision < 0 && prior.MeanTimesPrecision > 0) ||
                    (ratio.MeanTimesPrecision > 0 && prior.MeanTimesPrecision < 0))
                {
                    // if the update would change the sign of the mean, take a fractional step so that the new prior has exactly zero mean
                    // newMsg = falseMsg * (ratio^step)
                    // newPrior = prior * (ratio^step)^(tc1+fc1-1)
                    // 0 = prior.mp + ratio.mp*step*(tc1+fc1-1)
                    double step = -prior.MeanTimesPrecision / (ratio.MeanTimesPrecision * (tc1 + fc1 - 1));
                    // The range check also rejects NaN/Infinity (e.g. when tc1+fc1 == 1 makes the divisor zero).
                    if (step > 0 && step < 1)
                    {
                        newMsg = falseMsg * (ratio ^ step);
                    }
                }
                return newMsg;
            }
        }
        /// <summary>
        /// VMP message to 'logistic'
        /// </summary>
        /// <param name="x">Incoming message from 'x'. Must be a proper distribution.  If uniform, the result will be uniform.</param>
        /// <returns>The outgoing VMP message to the 'logistic' argument</returns>
        /// <remarks><para>
        /// Returns the Beta distribution built by <c>Beta.FromMeanLogs</c> from E[log sigma(x)] and
        /// E[log(1 - sigma(x))], where both expectations are taken under <paramref name="x"/>.
        /// E[log(1 - sigma(x))] = E[log sigma(-x)] is obtained from <c>BernoulliFromLogOddsOp.AverageLogFactor</c>.
        /// E[log sigma(x)] is then derived from it by adding E[x].
        /// </para></remarks>
        /// <exception cref="ImproperMessageException"><paramref name="x"/> is not a proper distribution</exception>
        public static Beta LogisticAverageLogarithm([Proper] Gaussian x)
        {
            double m, v;

            // Only the mean is needed below; the variance is implicit in the AverageLogFactor call.
            x.GetMeanAndVariance(out m, out v);

            // E[log (1-sigma(x))] = E[log sigma(-x)] = -E[log(1+exp(x))].
            // This could be computed directly as -MMath.Log1PlusExpGaussian(m, v). Going through
            // BernoulliFromLogOddsOp keeps this message consistent with XAverageLogarithm.
            double eLogOneMinusP = BernoulliFromLogOddsOp.AverageLogFactor(false, x);

            // E[log sigma(x)] = -E[log(1+exp(-x))] = -E[log(1+exp(x))-x] = -E[log(1+exp(x))] + E[x]
            double eLogP = eLogOneMinusP + m;
            return Beta.FromMeanLogs(eLogP, eLogOneMinusP);
        }
        // Example #3
        /// <summary>
        /// Evidence message for VMP
        /// </summary>
        /// <param name="sample">Constant value for 'sample'; an index into <paramref name="logProbs"/>.</param>
        /// <param name="logProbs">Incoming message from 'logProbs'. Must be a proper distribution.</param>
        /// <returns>Average of the factor's log-value across the given argument distributions</returns>
        /// <remarks><para>
        /// The formula for the result is <c>sum_(logProbs) p(logProbs) log(factor(sample,logProbs))</c>.
        /// Adding up these values across all factors and variables gives the log-evidence estimate for VMP.
        /// </para><para>
        /// The method visits every index k other than <paramref name="sample"/>. For each one it forms the Gaussian
        /// of logProbs[sample] - logProbs[k], with means subtracted and variances added.
        /// The results of <c>BernoulliFromLogOddsOp.AverageLogFactor(true, ...)</c> for these Gaussians are summed.
        /// </para></remarks>
        /// <exception cref="ImproperMessageException"><paramref name="logProbs"/> is not a proper distribution</exception>
        public static double AverageLogFactor(int sample, [Proper] IList <Gaussian> logProbs)
        {
            double sampleMean, sampleVariance;
            logProbs[sample].GetMeanAndVariance(out sampleMean, out sampleVariance);

            double total = 0;
            for (int k = 0; k < logProbs.Count; k++)
            {
                if (k != sample)
                {
                    double otherMean, otherVariance;
                    logProbs[k].GetMeanAndVariance(out otherMean, out otherVariance);

                    // Log-odds of 'sample' against competitor k.
                    Gaussian difference = new Gaussian(sampleMean - otherMean, sampleVariance + otherVariance);
                    total += BernoulliFromLogOddsOp.AverageLogFactor(true, difference);
                }
            }
            return total;
        }
        /// <summary>
        /// VMP message to 'x'
        /// </summary>
        /// <param name="logistic">Incoming message from 'logistic'. Must be a proper distribution.  If uniform, the result will be uniform.</param>
        /// <param name="x">Incoming message from 'x'. Must be a proper distribution.  If uniform, the result will be uniform.</param>
        /// <param name="to_X">Previous outgoing message to 'X'.</param>
        /// <returns>The outgoing VMP message to the 'x' argument</returns>
        /// <remarks><para>
        /// Let a = logistic.TrueCount and b = logistic.FalseCount. The log-factor is
        /// (a-1) log sigma(x) + (b-1) log sigma(-x) = (a+b-2) log sigma(x) - (b-1) x.
        /// The sigma(x) term is handled by <c>BernoulliFromLogOddsOp.LogOddsAverageLogarithm</c> and scaled by a+b-2.
        /// The linear term -(b-1) x is added exactly to MeanTimesPrecision.
        /// When a+b-2 is zero, only the linear term remains, and the result has zero precision.
        /// </para></remarks>
        /// <exception cref="ImproperMessageException"><paramref name="logistic"/> is not a proper distribution</exception>
        /// <exception cref="ImproperMessageException"><paramref name="x"/> is not a proper distribution</exception>
        public static Gaussian XAverageLogarithm([SkipIfUniform] Beta logistic, [Proper, SkipIfUniform] Gaussian x, Gaussian to_X)
        {
            if (logistic.IsPointMass)
            {
                return XAverageLogarithm(logistic.Point);
            }
            // f(x) = sigma(x)^(a-1) sigma(-x)^(b-1)
            //      = sigma(x)^(a+b-2) exp(-x(b-1))
            // since sigma(-x) = sigma(x) exp(-x)

            double a = logistic.TrueCount;
            double b = logistic.FalseCount;
            double scale = a + b - 2;
            double shift = -(b - 1);

            if (scale == 0.0)
            {
                // f(x) reduces to exp(shift * x). It is uniform only when shift is also zero (a == b == 1).
                // Otherwise the exact message has precision 0 and MeanTimesPrecision == shift,
                // which is also the scale -> 0 limit of the general formula below.
                // Previously this case always returned Uniform() and silently dropped the linear term.
                return shift == 0.0 ? Gaussian.Uniform() : Gaussian.FromNatural(shift, 0.0);
            }

            // Undo the scale/shift on the previous message, so that the sigma(x) operator
            // sees its own previous output.
            Gaussian toLogOddsPrev = Gaussian.FromNatural((to_X.MeanTimesPrecision - shift) / scale, to_X.Precision / scale);
            Gaussian toLogOdds = BernoulliFromLogOddsOp.LogOddsAverageLogarithm(true, x, toLogOddsPrev);

            // Raise the sigma(x) message to the power 'scale' and multiply by exp(shift * x).
            return Gaussian.FromNatural(scale * toLogOdds.MeanTimesPrecision + shift, scale * toLogOdds.Precision);
        }