/// <summary>
/// Update the buffer 'falseMsg'
/// </summary>
/// <param name="logistic">Incoming message from 'logistic'. Must be a proper distribution. If uniform, the result will be uniform.</param>
/// <param name="x">Incoming message from 'x'. Must be a proper distribution. If uniform, the result will be uniform.</param>
/// <param name="falseMsg">Current value of buffer 'falseMsg': the Gaussian approximation to a single sigma(-x) term from the previous update.</param>
/// <returns>New value of buffer 'falseMsg'</returns>
/// <remarks><para>
/// With a = logistic.TrueCount and b = logistic.FalseCount, the factor times the incoming Gaussian N(x;m,v) equals
/// sigma(-x)^(a+b-2) N(x; m+(a-1)v, v), up to a constant in x. The buffer approximates one sigma(-x) term, and it is
/// refined by a power EP update that depends on the sign of a+b-2:
/// </para><para>
/// a+b-2 == 0: there is no sigma(-x) dependence, so the buffer is set to uniform.
/// </para><para>
/// a+b-2 &lt; 0: 1/sigma(-x) = 1+exp(x) is used as the factor. The tilted moments under the cavity are computed in closed form,
/// and the result is cavity / tilted posterior.
/// </para><para>
/// a+b-2 &gt; 0: the update comes from BernoulliFromLogOddsOp.LogOddsAverageConditional(false, cavity). A fractional step is
/// then taken if a full step would move the cavity's MeanTimesPrecision across zero.
/// </para></remarks>
/// <exception cref="ImproperMessageException"><paramref name="logistic"/> is not a proper distribution</exception>
/// <exception cref="ImproperMessageException"><paramref name="x"/> is not a proper distribution</exception>
public static Gaussian FalseMsg([SkipIfUniform] Beta logistic, [Proper] Gaussian x, Gaussian falseMsg)
{
    // falseMsg approximates sigma(-x)
    // logistic(sigma(x)) N(x;m,v)
    // = sigma(x)^(a-1) sigma(-x)^(b-1) N(x;m,v)
    // = e^((a-1)x) sigma(-x)^(a+b-2) N(x;m,v)
    // = sigma(-x)^(a+b-2) N(x;m+(a-1)v,v) exp((a-1)m + (a-1)^2 v/2)
    // = sigma(-x) (prior)
    // where prior = sigma(-x)^(a+b-3) N(x;m+(a-1)v,v)
    double tc1 = logistic.TrueCount - 1;
    double fc1 = logistic.FalseCount - 1;
    double m, v;
    x.GetMeanAndVariance(out m, out v);

    // Total exponent of sigma(-x) in the tilted factor (a+b-2).
    // Exact comparison with zero is intentional: the exponent is built directly from the Beta counts.
    double power = tc1 + fc1;
    if (power == 0)
    {
        falseMsg.SetToUniform();
        return falseMsg;
    }
    else if (power < 0)
    {
        // Power EP update, using 1/sigma(-x) = 1 + exp(x) as the factor.
        // The cavity keeps power+1 copies of the approximation (the approximated factor is falseMsg^-1).
        Gaussian prior = new Gaussian(m + tc1 * v, v) * (falseMsg ^ (power + 1));
        double mprior, vprior;
        prior.GetMeanAndVariance(out mprior, out vprior);
        // Posterior moments can be computed exactly: prior(x)*(1+exp(x)) is a two-component Gaussian mixture
        // with weight w on the component shifted by vprior.
        double w = MMath.Logistic(mprior + 0.5 * vprior);
        Gaussian post = new Gaussian(mprior + w * vprior, vprior * (1 + w * (1 - w) * vprior));
        return prior / post;
    }
    else
    {
        // Power EP update: the cavity keeps the remaining power-1 copies of the approximation.
        double cavityPower = power - 1;
        Gaussian prior = new Gaussian(m + tc1 * v, v) * (falseMsg ^ cavityPower);
        Gaussian newMsg = BernoulliFromLogOddsOp.LogOddsAverageConditional(false, prior);

        // Adaptive damping scheme.
        // If the update would change the sign of the cavity's mean, take a fractional step chosen
        // so that the new cavity has exactly zero mean:
        //   newMsg   = falseMsg * (ratio^step)
        //   newPrior = prior * (ratio^step)^cavityPower
        //   0 = prior.mp + ratio.mp * step * cavityPower
        // When cavityPower == 0 the cavity does not depend on falseMsg, so no damping is needed.
        // This also avoids dividing by zero below.
        Gaussian ratio = newMsg / falseMsg;
        bool oppositeSigns = (ratio.MeanTimesPrecision < 0 && prior.MeanTimesPrecision > 0) ||
                             (ratio.MeanTimesPrecision > 0 && prior.MeanTimesPrecision < 0);
        if (oppositeSigns && cavityPower != 0)
        {
            double step = -prior.MeanTimesPrecision / (ratio.MeanTimesPrecision * cavityPower);
            if (step > 0 && step < 1)
            {
                newMsg = falseMsg * (ratio ^ step);
            }
        }
        return newMsg;
    }
}
/// <summary>
/// VMP message to 'logistic'
/// </summary>
/// <param name="x">Incoming message from 'x'. Must be a proper distribution. If uniform, the result will be uniform.</param>
/// <returns>The outgoing VMP message to the 'logistic' argument</returns>
/// <remarks><para>
/// Returns the Beta distribution whose mean logarithms E[log p] and E[log(1-p)] are set to
/// E[log sigma(x)] and E[log sigma(-x)] under the incoming Gaussian on 'x'.
/// The two are related by log sigma(x) = x + log sigma(-x).
/// </para></remarks>
/// <exception cref="ImproperMessageException"><paramref name="x"/> is not a proper distribution</exception>
public static Beta LogisticAverageLogarithm([Proper] Gaussian x)
{
    double mean, variance;
    x.GetMeanAndVariance(out mean, out variance);

    // E[log(1 - sigma(x))] = E[log sigma(-x)].
    // It is taken from BernoulliFromLogOddsOp for consistency with XAverageLogarithm, which also delegates
    // to BernoulliFromLogOddsOp. The closed form -MMath.Log1PlusExpGaussian(mean, variance) is an alternative.
    double meanLogOneMinusP = BernoulliFromLogOddsOp.AverageLogFactor(false, x);

    // E[log sigma(x)] = E[log sigma(-x)] + E[x]
    double meanLogP = meanLogOneMinusP + mean;

    return Beta.FromMeanLogs(meanLogP, meanLogOneMinusP);
}
/// <summary>
/// Evidence message for VMP
/// </summary>
/// <param name="sample">Constant value for 'sample': the index into <paramref name="logProbs"/> that was observed.</param>
/// <param name="logProbs">Incoming message from 'logProbs'. Must be a proper distribution. If any element is uniform, the result will be uniform.</param>
/// <returns>Average of the factor's log-value across the given argument distributions</returns>
/// <remarks><para>
/// For every index k other than <paramref name="sample"/>, the messages for logProbs[sample] and logProbs[k] are combined
/// into a Gaussian with mean m_sample - m_k and variance v_sample + v_k. The result sums
/// BernoulliFromLogOddsOp.AverageLogFactor(true, ·) over those Gaussians.
/// Adding up these values across all factors and variables gives the log-evidence estimate for VMP.
/// </para></remarks>
/// <exception cref="ImproperMessageException"><paramref name="logProbs"/> is not a proper distribution</exception>
public static double AverageLogFactor(int sample, [Proper] IList<Gaussian> logProbs)
{
    double total = 0;
    double sampleMean, sampleVariance;
    logProbs[sample].GetMeanAndVariance(out sampleMean, out sampleVariance);

    for (int other = 0; other < logProbs.Count; other++)
    {
        if (other != sample)
        {
            double otherMean, otherVariance;
            logProbs[other].GetMeanAndVariance(out otherMean, out otherVariance);
            // The two are treated as independent: means subtract and variances add.
            Gaussian difference = new Gaussian(sampleMean - otherMean, sampleVariance + otherVariance);
            total += BernoulliFromLogOddsOp.AverageLogFactor(true, difference);
        }
    }

    return total;
}
/// <summary>
/// VMP message to 'x'
/// </summary>
/// <param name="logistic">Incoming message from 'logistic'. Must be a proper distribution. If uniform, the result will be uniform.</param>
/// <param name="x">Incoming message from 'x'. Must be a proper distribution. If uniform, the result will be uniform.</param>
/// <param name="to_X">Previous outgoing message to 'x'.</param>
/// <returns>The outgoing VMP message to the 'x' argument</returns>
/// <remarks><para>
/// With a = logistic.TrueCount and b = logistic.FalseCount, the log-factor is rewritten as
/// (a+b-2) log sigma(x) - (b-1) x, using sigma(-x) = sigma(x) exp(-x).
/// The log sigma(x) term is handled by BernoulliFromLogOddsOp.LogOddsAverageLogarithm. Its result is then
/// raised to the power a+b-2 and multiplied by exp(-(b-1) x), both done in natural parameters.
/// A point-mass 'logistic' is delegated to the overload that takes a fixed value.
/// If a+b-2 is zero, the message is uniform.
/// </para></remarks>
/// <exception cref="ImproperMessageException"><paramref name="logistic"/> is not a proper distribution</exception>
/// <exception cref="ImproperMessageException"><paramref name="x"/> is not a proper distribution</exception>
public static Gaussian XAverageLogarithm([SkipIfUniform] Beta logistic, [Proper, SkipIfUniform] Gaussian x, Gaussian to_X)
{
    if (logistic.IsPointMass)
    {
        return XAverageLogarithm(logistic.Point);
    }

    // f(x) = sigma(x)^(a-1) sigma(-x)^(b-1) = sigma(x)^(a+b-2) exp(-x(b-1))
    double trueCount = logistic.TrueCount;
    double falseCount = logistic.FalseCount;
    double exponent = trueCount + falseCount - 2;
    if (exponent == 0.0)
    {
        return Gaussian.Uniform();
    }
    double linearCoefficient = -(falseCount - 1);

    // Invert the final affine map on the natural parameters so that to_X becomes the previous
    // message in the log-odds domain.
    double previousMeanTimesPrecision = (to_X.MeanTimesPrecision - linearCoefficient) / exponent;
    double previousPrecision = to_X.Precision / exponent;
    Gaussian previousLogOddsMsg = Gaussian.FromNatural(previousMeanTimesPrecision, previousPrecision);

    Gaussian logOddsMsg = BernoulliFromLogOddsOp.LogOddsAverageLogarithm(true, x, previousLogOddsMsg);

    // Message = logOddsMsg^exponent * exp(linearCoefficient * x), written in natural parameters.
    double meanTimesPrecision = exponent * logOddsMsg.MeanTimesPrecision + linearCoefficient;
    double precision = exponent * logOddsMsg.Precision;
    return Gaussian.FromNatural(meanTimesPrecision, precision);
}