//-- VMP -------------------------------------------------------------------------------------------------

/// <summary>
/// Evidence message for VMP.
/// </summary>
/// <param name="x">Incoming message from <c>x</c> (annotated <c>[Proper, SkipIfUniform]</c>).</param>
/// <param name="logistic">Incoming Beta message from <c>logistic</c>.</param>
/// <param name="to_logistic">Outgoing message to <c>logistic</c>; its average log under <paramref name="logistic"/> is subtracted from the result.</param>
/// <returns>
/// <c>(TrueCount-1)*E[log(sigma(x))] + (FalseCount-1)*E[log(1-sigma(x))] - logistic.GetLogNormalizer() - to_logistic.GetAverageLog(logistic)</c>,
/// where the expectations are taken under <paramref name="x"/> and the counts are those of <paramref name="logistic"/>.
/// </returns>
/// <remarks><para>
/// Uses <c>log(sigma(x)) = x - log(1+exp(x))</c> and <c>log(1-sigma(x)) = -log(1+exp(x))</c>, so only
/// <c>E[log(1+exp(x))]</c> is needed. It is evaluated exactly when <paramref name="x"/> has zero variance,
/// and otherwise with <c>MMath.Log1PlusExpGaussian(mean, variance)</c>.
/// Adding up these values across all factors and variables gives the log-evidence estimate for VMP.
/// </para></remarks>
//[Skip]
public static double AverageLogFactor([Proper, SkipIfUniform] Gaussian x, Beta logistic, Beta to_logistic)
{
    double mean, variance;
    x.GetMeanAndVariance(out mean, out variance);

    // E[log(1 + exp(x))]; a zero-variance x is handled directly without the Gaussian integral.
    double expectedLog1PlusExp;
    if (variance == 0)
    {
        expectedLog1PlusExp = MMath.Log1PlusExp(mean);
    }
    else
    {
        expectedLog1PlusExp = MMath.Log1PlusExpGaussian(mean, variance);
    }

    double expectedLogSigma = mean - expectedLog1PlusExp;
    double expectedLogOneMinusSigma = -expectedLog1PlusExp;

    return (logistic.TrueCount - 1.0) * expectedLogSigma
        + (logistic.FalseCount - 1.0) * expectedLogOneMinusSigma
        - logistic.GetLogNormalizer()
        - to_logistic.GetAverageLog(logistic);
}
/// <summary>
/// Evidence message for VMP, using a Saul-Jordan style bound on <c>E[log(1+exp(x))]</c>.
/// </summary>
/// <param name="logistic">Incoming Beta message from <c>logistic</c>.</param>
/// <param name="x">Incoming message from <c>x</c> (annotated <c>[Proper, SkipIfUniform]</c>).</param>
/// <param name="to_logistic">Outgoing message to <c>logistic</c>; its average log under <paramref name="logistic"/> is subtracted from the result.</param>
/// <param name="a">Free parameter of the bound <c>E[log(1+exp(x))] &lt;= a^2 v/2 + log(1 + exp(m + (1-2a) v/2))</c>, which holds for any <c>a</c>.</param>
/// <returns>
/// <c>(TrueCount+FalseCount-2)*(m - B) - (FalseCount-1)*m - logistic.GetLogNormalizer() - to_logistic.GetAverageLog(logistic)</c>,
/// where <c>(m, v)</c> are the mean and variance of <paramref name="x"/> and <c>B</c> is the bound above.
/// </returns>
/// <remarks><para>
/// Algebraically this equals <c>(TrueCount-1)*m - (TrueCount+FalseCount-2)*B</c>, i.e. the exact VMP average log factor
/// with <c>E[log(1+exp(x))]</c> replaced by <c>B</c>. When <c>TrueCount + FalseCount &gt;= 2</c> the result is therefore
/// a lower bound on the exact value.
/// </para></remarks>
public static double AverageLogFactor(Beta logistic, [Proper, SkipIfUniform] Gaussian x, Beta to_logistic, double a)
{
    double falseCount = logistic.FalseCount;
    double coefficient = logistic.TrueCount + falseCount - 2;
    double linearCoefficient = -(falseCount - 1);

    double mean, variance;
    x.GetMeanAndVariance(out mean, out variance);

    // Upper bound on E[log(1 + exp(x))], and the resulting lower bound on E[log(sigma(x))].
    double quadraticPart = a * a * variance / 2.0;
    double softplusPart = MMath.Log1PlusExp(mean + (1.0 - 2.0 * a) * variance / 2.0);
    double logSigmaLowerBound = mean - (quadraticPart + softplusPart);

    return coefficient * logSigmaLowerBound
        + linearCoefficient * mean
        - logistic.GetLogNormalizer()
        - to_logistic.GetAverageLog(logistic);
}
/// <summary>
/// Computes the Bernoulli gating function in the log-odds domain.
/// </summary>
/// <param name="x">Log-odds value that the result tends to as <paramref name="gate"/> goes to -infinity.</param>
/// <param name="gate">Gating log-odds; the result tends to <paramref name="x"/> as this goes to -infinity and to 0 as it goes to +infinity.</param>
/// <returns><c>log((1 + exp(-gate)) / (1 + exp(-gate - x)))</c>.</returns>
/// <remarks>
/// This is one of the messages sent by a logical OR factor.
/// Three algebraically equivalent forms are used depending on the signs of <c>gate</c> and <c>gate + x</c>.
/// </remarks>
public static double Gate(double x, double gate)
{
    if (x == -gate)
    {
        // The denominator is 1 + exp(0) = 2. Handling this separately avoids computing inf - inf
        // when x and gate are opposite infinities.
        return MMath.Log1PlusExp(-gate) - MMath.Ln2;
    }

    bool gateAndSumNegative = gate < 0 && x + gate < 0;
    if (!gateAndSumNegative)
    {
        return MMath.Log1PlusExp(-gate) - MMath.Log1PlusExp(-gate - x);
    }

    // Both -gate and -gate-x are positive. Use log(1+exp(-t)) = -t + log(1+exp(t)) on numerator and
    // denominator so that the -gate terms cancel symbolically, leaving x plus two Log1PlusExp terms.
    return x + MMath.Log1PlusExp(gate) - MMath.Log1PlusExp(gate + x);
}
/// <summary>Evidence message for VMP.</summary>
/// <param name="sample">Incoming message from <c>sample</c>.</param>
/// <param name="logOdds">Constant value for <c>logOdds</c>.</param>
/// <returns>Average of the factor's log-value across the given argument distributions.</returns>
/// <remarks>
/// <para>The formula for the result is <c>sum_(sample) p(sample) log(factor(sample,logOdds))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
/// </remarks>
public static double AverageLogFactor(Bernoulli sample, double logOdds)
{
    if (sample.IsPointMass)
    {
        // Delegate to the overload that takes the sample's point value.
        return AverageLogFactor(sample.Point, logOdds);
    }

    // pTrue*log(sigma(l)) + pFalse*log(sigma(-l)) has two equal closed forms:
    //   -pFalse*l - log(1+exp(-l))   and   pTrue*l - log(1+exp(l)).
    // Pick the one whose Log1PlusExp argument is not positive.
    return logOdds >= 0
        ? -sample.GetProbFalse() * logOdds - MMath.Log1PlusExp(-logOdds)
        : sample.GetProbTrue() * logOdds - MMath.Log1PlusExp(logOdds);
}
/// <summary>Evidence message for VMP.</summary>
/// <param name="sample">Incoming message from <c>sample</c>.</param>
/// <param name="logOdds">Incoming message from <c>logOdds</c>.</param>
/// <returns>
/// <c>p(true)*m - (a^2 v/2 + log(1 + exp(m + (1-2a) v/2)))</c>, where <c>(m, v)</c> are the mean and variance of
/// <paramref name="logOdds"/> and <c>a</c> minimises the bracketed Saul and Jordan (1999) upper bound on
/// <c>E[log(1+exp(logOdds))]</c>. The result is therefore a lower bound on the exact average log factor.
/// </returns>
/// <remarks><para>
/// This is the non-conjugate VMP update. The optimal <c>a</c> is the unique root in [0, 1] of
/// <c>h(a) = a - s(a)</c> with <c>s(a) = logistic(m + (1-2a) v/2)</c>. <c>h</c> is strictly increasing, with
/// <c>h'(a) = 1 + v s (1-s)</c>, and satisfies <c>h(0) &lt;= 0 &lt;= h(1)</c>.
/// The plain fixed-point iteration <c>a = s(a)</c> is only guaranteed to converge for <c>v &lt; 4</c>. For larger
/// variances it can oscillate between 0 and 1 and produce a needlessly loose bound. This method therefore uses
/// Newton's method on <c>h</c>, safeguarded by bisection on the bracket [0, 1].
/// </para></remarks>
public static double AverageLogFactor(Bernoulli sample, Gaussian logOdds)
{
    double m, v;
    logOdds.GetMeanAndVariance(out m, out v);

    // TODO: use a buffer to store the value of 'a', so it doesn't need to be re-optimised each time.
    double a = 0.5;
    // Bracket on the root of h: h(lower) <= 0 <= h(upper).
    double lower = 0.0;
    double upper = 1.0;
    for (int iter = 0; iter < 50; iter++)
    {
        double s = MMath.Logistic(m + (1 - 2 * a) * v * 0.5);
        double h = a - s;
        if (h == 0)
        {
            break;
        }
        // h is increasing, so its sign tells which side of the root 'a' lies on.
        if (h < 0)
        {
            lower = a;
        }
        else
        {
            upper = a;
        }
        double aNew = a - h / (1 + v * s * (1 - s));
        // Fall back to bisection when the Newton step leaves the bracket (or is NaN).
        if (!(aNew >= lower && aNew <= upper))
        {
            aNew = 0.5 * (lower + upper);
        }
        bool converged = Math.Abs(aNew - a) < 1e-8;
        a = aNew;
        if (converged)
        {
            break;
        }
    }
    return sample.GetProbTrue() * m - .5 * a * a * v - MMath.Log1PlusExp(m + (1 - 2 * a) * v * 0.5);
}
/// <summary>
/// Gets the log normalizer of the distribution.
/// </summary>
/// <returns>
/// <c>log(1 + exp(LogOdds))</c>, which equals <c>-log(1-p)</c> for <c>p = logistic(LogOdds)</c>.
/// </returns>
public double GetLogNormalizer()
{
    // 1 - p = 1 / (1 + exp(LogOdds)), hence -log(1 - p) = log(1 + exp(LogOdds)).
    double logOdds = LogOdds;
    return MMath.Log1PlusExp(logOdds);
}
/// <summary>
/// Upper bound on <c>E[log(1 + exp(x))]</c> for <c>x ~ N(m, v)</c>.
/// </summary>
/// <param name="m">Mean of <c>x</c>.</param>
/// <param name="v">Variance of <c>x</c>.</param>
/// <param name="a">Free parameter of the bound; the inequality holds for every value of <paramref name="a"/>.</param>
/// <returns><c>v a^2 / 2 + log(1 + exp(m + (1/2 - a) v))</c>.</returns>
/// <remarks>
/// Follows from writing <c>log(1+exp(x)) = a x + log(exp(-a x) + exp((1-a) x))</c> and applying Jensen's inequality
/// to the second term.
/// </remarks>
public static double logSumExpBound(double m, double v, double a)
{
    double quadraticTerm = 0.5 * v * a * a;
    double shiftedMean = m + (0.5 - a) * v;
    return quadraticTerm + MMath.Log1PlusExp(shiftedMean);
}