/// <summary>
/// EP message to 'logistic'.
/// </summary>
/// <param name="logistic">Incoming message from 'logistic'.</param>
/// <param name="x">Incoming message from 'x'. Must be a proper distribution. If uniform, the result will be uniform.</param>
/// <param name="falseMsg">Buffer 'falseMsg'.</param>
/// <returns>The outgoing EP message to the 'logistic' argument.</returns>
/// <remarks><para>
/// The outgoing message is a distribution matching the moments of 'logistic' as the random arguments are varied.
/// The formula is <c>proj[p(logistic) sum_(x) p(x) factor(logistic,x)]/p(logistic)</c>.
/// </para><para>
/// Two strategies are used: a direct moment-matching shortcut when 'logistic' is uniform or has
/// counts (2,1) or (1,2), and a "stabilized" message built from the normalizer and the posterior mean otherwise.
/// </para></remarks>
/// <exception cref="ImproperMessageException"><paramref name="x"/> is not a proper distribution</exception>
public static Beta LogisticAverageConditional(Beta logistic, [Proper] Gaussian x, Gaussian falseMsg)
{
    // A point-mass x is pushed through the logistic function directly.
    if (x.IsPointMass)
    {
        return Beta.PointMass(MMath.Logistic(x.Point));
    }

    // Degenerate inputs: reply with a uniform message.
    if (logistic.IsPointMass || x.IsUniform())
    {
        return Beta.Uniform();
    }

    double priorMean, priorVariance;
    x.GetMeanAndVariance(out priorMean, out priorVariance);

    // Kept as one short-circuit expression so IsUniform() is only consulted when the count tests fail.
    bool canUseShortcut = (logistic.TrueCount == 2 && logistic.FalseCount == 1)
        || (logistic.TrueCount == 1 && logistic.FalseCount == 2)
        || logistic.IsUniform();

    if (canUseShortcut)
    {
        // Shortcut for the common case.
        // The result is a Beta distribution satisfying:
        //   int_p to_p(p) p   dp = int_x sigma(x)   qnoti(x) dx
        //   int_p to_p(p) p^2 dp = int_x sigma(x)^2 qnoti(x) dx
        // The second constraint can be rewritten as:
        //   int_p to_p(p) p (1-p) dp = int_x sigma(x) (1 - sigma(x)) qnoti(x) dx
        // so the constraints are unchanged if p is replaced with (1-p).
        double firstMoment = MMath.LogisticGaussian(priorMean, priorVariance);
        // meanTF = E[p] - E[p^2]
        double meanTF = MMath.LogisticGaussianDerivative(priorMean, priorVariance);
        double secondMoment = firstMoment - meanTF;
        return Beta.FromMeanAndVariance(firstMoment, secondMoment - firstMoment * firstMoment);
    }

    // Stabilized EP message.
    // Choose a normalized distribution to_p such that:
    //   int_p to_p(p)   qnoti(p) dp = int_x qnoti(sigma(x))          qnoti(x) dx
    //   int_p to_p(p) p qnoti(p) dp = int_x qnoti(sigma(x)) sigma(x) qnoti(x) dx

    // log int_x logistic(sigma(x)) N(x;m,v) dx
    double logZ = LogAverageFactor(logistic, x, falseMsg) + logistic.GetLogNormalizer();

    Gaussian posterior = XAverageConditional(logistic, falseMsg) * x;
    double posteriorMean, posteriorVariance;
    posterior.GetMeanAndVariance(out posteriorMean, out posteriorVariance);

    double trueCountMinus1 = logistic.TrueCount - 1;
    double falseCountMinus1 = logistic.FalseCount - 1;
    double countSum = trueCountMinus1 + falseCountMinus1;

    double expectedP;
    if (countSum == 0)
    {
        // The posterior-mean formula below would divide by zero, so obtain E[p] as a
        // ratio of normalizers, using a Beta with one extra true count.
        Beta shifted = new Beta(logistic.TrueCount + 1, logistic.FalseCount);
        double logZShifted = LogAverageFactor(shifted, x, falseMsg) + shifted.GetLogNormalizer();
        expectedP = Math.Exp(logZShifted - logZ);
    }
    else
    {
        // Ep = int_p to_p(p) p qnoti(p) dp / int_p to_p(p) qnoti(p) dp
        // mp = m + v (a - (a+b) Ep)
        expectedP = (trueCountMinus1 - (posteriorMean - priorMean) / priorVariance) / countSum;
    }

    return BetaFromMeanAndIntegral(expectedP, logZ, trueCountMinus1, falseCountMinus1);
}
/// <summary>
/// EP message to 'logOdds'.
/// </summary>
/// <param name="sample">Constant value for sample.</param>
/// <param name="logOdds">Incoming message from 'logOdds'. Must be a proper distribution. If uniform, the result will be uniform.</param>
/// <returns>The outgoing EP message to the 'logOdds' argument.</returns>
/// <remarks><para>
/// The outgoing message is the moment matched Gaussian approximation to the factor.
/// </para><para>
/// Throws a plain <see cref="Exception"/> when an intermediate quantity is NaN or when the
/// matched variance falls outside [0, v], where v is the variance of <paramref name="logOdds"/>.
/// </para></remarks>
public static Gaussian LogOddsAverageConditional(bool sample, [SkipIfUniform] Gaussian logOdds)
{
    double mean, variance;
    logOdds.GetMeanAndVariance(out mean, out variance);

    // Work with the mean reflected by the sign of the observation; the sign is restored on output.
    double sign = sample ? 1 : -1;
    double signedMean = mean * sign;

    // Far-tail regime: use a closed-form expression instead of the logistic-Gaussian integrals.
    if (signedMean + 1.5 * variance < -38)
    {
        double beta2 = Math.Exp(signedMean + 1.5 * variance);
        return Gaussian.FromMeanAndVariance(sign * (signedMean + variance), variance * (1 - variance * beta2)) / logOdds;
    }

    double sigma0 = MMath.LogisticGaussian(signedMean, variance);
    double sigma1 = MMath.LogisticGaussianDerivative(signedMean, variance);
    double sigma2 = MMath.LogisticGaussianDerivative2(signedMean, variance);

    double alpha = sigma1 / sigma0;
    if (Double.IsNaN(alpha))
    {
        throw new Exception("alpha is NaN");
    }

    // The general form is algebraically (sigma1*sigma1 - sigma2*sigma0)/(sigma0*sigma0);
    // an exponential expression replaces it when m + 2v is very negative.
    double beta = (signedMean + 2 * variance < -19)
        ? Math.Exp(3 * signedMean + 2.5 * variance) / (sigma0 * sigma0)
        : alpha * alpha - sigma2 / sigma0;
    if (Double.IsNaN(beta))
    {
        throw new Exception("beta is NaN");
    }

    double matchedMean = sign * (signedMean + variance * alpha);
    double matchedVariance = variance * (1 - variance * beta);

    // Sanity checks on the moment-matched variance.
    if (matchedVariance > variance)
    {
        throw new Exception("v2 > v");
    }

    if (matchedVariance < 0)
    {
        throw new Exception("v2 < 0");
    }

    // Divide out the incoming message to obtain the outgoing one.
    return Gaussian.FromMeanAndVariance(matchedMean, matchedVariance) / logOdds;
}
/// <summary>
/// Gradient matching VMP message from factor to logOdds variable.
/// </summary>
/// <param name="sample">Constant value for 'sample'.</param>
/// <param name="logOdds">Incoming message. Must be a proper distribution. If uniform, the result will be uniform.</param>
/// <param name="to_LogOdds">Previous message sent, used for damping.</param>
/// <returns>The outgoing VMP message.</returns>
/// <remarks><para>
/// The outgoing message is the Gaussian approximation to the factor which results in the
/// same derivatives of the KL(q||p) divergence with respect to the parameters of the posterior
/// as if the true factor had been used.
/// </para><para>
/// The result is blended with <paramref name="to_LogOdds"/> using a randomly drawn step size,
/// so repeated calls with the same arguments generally return different messages.
/// </para></remarks>
/// <exception cref="ImproperMessageException"><paramref name="logOdds"/> is not a proper distribution</exception>
public static Gaussian LogOddsAverageLogarithm(bool sample, [Proper, SkipIfUniform] Gaussian logOdds, Gaussian to_LogOdds)
{
    double sign = sample ? 1 : -1;

    // Prior mean and variance.
    double mean, variance;
    logOdds.GetMeanAndVariance(out mean, out variance);

    // E = \int q log f dx
    // Match the gradients of E with respect to the mean and the variance.
    double gradientWrtMean = sign * MMath.LogisticGaussian(-sign * mean, variance);
    double gradientWrtVariance = -.5 * MMath.LogisticGaussianDerivative(sign * mean, variance);

    double precision = -2.0 * gradientWrtVariance;
    double meanTimesPrecision = mean * precision + gradientWrtMean;
    Gaussian message = Gaussian.FromNatural(meanTimesPrecision, precision);

    // Random damping helps convergence, especially with parallel updates.
    double step = Rand.Double() * 0.5;

    // NOTE(review): step is half of Rand.Double(); if that returns values in [0,1) this guard
    // is always taken -- confirm before simplifying.
    if (step != 1.0)
    {
        double retained = 1 - step;
        message.Precision = step * message.Precision + retained * to_LogOdds.Precision;
        message.MeanTimesPrecision = step * message.MeanTimesPrecision + retained * to_LogOdds.MeanTimesPrecision;
    }

    return message;
}
// /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="LogisticOp"]/message_doc[@name="LogisticAverageConditional(Beta, Gaussian, Gaussian)"]/*'/>
/// <summary>
/// EP message to 'logistic', computed from the product of <paramref name="to_x"/> and <paramref name="x"/>.
/// </summary>
/// <param name="logistic">Incoming message from 'logistic'.</param>
/// <param name="x">Incoming message from 'x'. Returns a point mass if it is a point mass, and uniform if it is uniform.</param>
/// <param name="falseMsg">Buffer 'falseMsg', forwarded to LogAverageFactor.</param>
/// <param name="to_x">Message multiplied with <paramref name="x"/> to form the Gaussian whose moments are used.</param>
/// <returns>The outgoing EP message to the 'logistic' argument.</returns>
public static Beta LogisticAverageConditional(Beta logistic, [Proper] Gaussian x, Gaussian falseMsg, Gaussian to_x)
{
    // A point-mass x is pushed through the logistic function directly.
    if (x.IsPointMass)
    {
        return Beta.PointMass(MMath.Logistic(x.Point));
    }

    // Degenerate inputs: reply with a uniform message.
    if (logistic.IsPointMass || x.IsUniform())
    {
        return Beta.Uniform();
    }

    Gaussian posterior = to_x * x;
    double posteriorMean, posteriorVariance;
    posterior.GetMeanAndVariance(out posteriorMean, out posteriorVariance);
    double expectedP = MMath.LogisticGaussian(posteriorMean, posteriorVariance);

    // Matching the variance gives lower accuracy on tests, but is required for the uniform case,
    // so the mean-and-integral construction is preferred whenever 'logistic' is not uniform.
    if (!logistic.IsUniform())
    {
        // log int_x logistic(sigma(x)) N(x;m,v) dx
        double logZ = LogAverageFactor(logistic, x, falseMsg) + logistic.GetLogNormalizer();
        double trueCountMinus1 = logistic.TrueCount - 1;
        double falseCountMinus1 = logistic.FalseCount - 1;
        return BetaFromMeanAndIntegral(expectedP, logZ, trueCountMinus1, falseCountMinus1);
    }

    // Uniform 'logistic': match mean and variance.
    // meanTF = E[p] - E[p^2]
    double meanTF = MMath.LogisticGaussianDerivative(posteriorMean, posteriorVariance);
    double secondMoment = expectedP - meanTF;
    Beta matched = Beta.FromMeanAndVariance(expectedP, secondMoment - expectedP * expectedP);
    matched.SetToRatio(matched, logistic, true);
    return matched;
}
/// <summary>EP message to <c>logOdds</c>.</summary>
/// <param name="sample">Constant value for <c>sample</c>.</param>
/// <param name="logOdds">Incoming message from <c>logOdds</c>. Must be a proper distribution. If uniform, the result will be uniform.</param>
/// <returns>The outgoing EP message to the <c>logOdds</c> argument.</returns>
/// <remarks>
/// <para>The outgoing message is the factor viewed as a function of <c>logOdds</c> conditioned on the given values.</para>
/// <para>Four numerical regimes are distinguished by the (sign-reflected) mean m and variance v of
/// <paramref name="logOdds"/>; all but the first reduce to a pair (alpha, beta) that is handed to
/// <c>GaussianProductOp_Laplace.GaussianFromAlphaBeta</c>. A plain <see cref="Exception"/> is thrown
/// if alpha or beta evaluates to NaN.</para>
/// </remarks>
/// <exception cref="ImproperMessageException">
/// <paramref name="logOdds" /> is not a proper distribution.</exception>
public static Gaussian LogOddsAverageConditional(bool sample, [SkipIfUniform] Gaussian logOdds)
{
    double mean, v;
    logOdds.GetMeanAndVariance(out mean, out v);

    // Work with the mean reflected by the sign of the observation; the sign is restored on output.
    double sign = sample ? 1 : -1;
    double m = mean * sign;

    // Regime 1: catch cases when sigma0 would evaluate to 0.
    // This check catches sigma0=0 when v <= 200.
    if (m + 1.5 * v < -38)
    {
        double beta2 = Math.Exp(m + 1.5 * v);
        return Gaussian.FromMeanAndVariance(sign * (m + v), v * (1 - v * beta2)) / logOdds;
    }

    double sigma0, sigma1, sigma2;
    bool useExponentialBeta = false;

    if (m + v < 0)
    {
        // Regime 2: factor out exp(m+v/2) in the following formulas:
        //   sigma(m,v)   = exp(m+v/2)(1-sigma(m+v,v))
        //   sigma'(m,v)  = d/dm sigma(m,v)  = exp(m+v/2)(1-sigma(m+v,v) - sigma'(m+v,v))
        //   sigma''(m,v) = d/dm sigma'(m,v) = exp(m+v/2)(1-sigma(m+v,v) - 2 sigma'(m+v,v) - sigma''(m+v,v))
        // This approach is always safe if sigma(-m-v,v)>0, which is guaranteed by m+v<0.
        sigma0 = MMath.LogisticGaussian(-m - v, v);
        double shiftedDerivative = MMath.LogisticGaussianDerivative(m + v, v);
        sigma1 = sigma0 - shiftedDerivative;
        sigma2 = sigma1 - shiftedDerivative - MMath.LogisticGaussianDerivative2(m + v, v);
    }
    else if (v > 1488 && m < 0)
    {
        // Regime 3: very large variance; use the ratio form of the integrals.
        sigma0 = MMath.LogisticGaussianRatio(m, v, 0);
        sigma1 = MMath.LogisticGaussianRatio(m, v, 1);
        sigma2 = MMath.LogisticGaussianRatio(m, v, 2);
    }
    else
    {
        // Regime 4: the following only works when sigma0 > 0.
        // sigma0=0 can only happen here if v > 1488.
        sigma0 = MMath.LogisticGaussian(m, v);
        sigma1 = MMath.LogisticGaussianDerivative(m, v);
        sigma2 = MMath.LogisticGaussianDerivative2(m, v);

        // NOTE(review): reaching this branch means m + v < 0 was false, so for non-negative v
        // this test looks unsatisfiable -- confirm before removing.
        useExponentialBeta = m + 2 * v < -19;
    }

    // Shared tail: in regime 2 alpha equals 1 - sd/sigma0.
    double alpha = sigma1 / sigma0;
    if (Double.IsNaN(alpha))
    {
        throw new Exception("alpha is NaN");
    }

    // The general form is algebraically (sigma1*sigma1 - sigma2*sigma0)/(sigma0*sigma0).
    double beta = useExponentialBeta
        ? Math.Exp(3 * m + 2.5 * v) / (sigma0 * sigma0)
        : alpha * alpha - sigma2 / sigma0;
    if (Double.IsNaN(beta))
    {
        throw new Exception("beta is NaN");
    }

    return GaussianProductOp_Laplace.GaussianFromAlphaBeta(logOdds, sign * alpha, beta, ForceProper);
}