/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="LogisticOp_JJ96"]/message_doc[@name="XAverageLogarithm(Beta, Gaussian, Gaussian)"]/*'/>
public static Gaussian XAverageLogarithm([SkipIfUniform] Beta logistic, [Proper, SkipIfUniform] Gaussian x, Gaussian result)
{
    // A point-mass Beta is delegated to the fixed-probability overload in LogisticOp.
    if (logistic.IsPointMass)
    {
        return LogisticOp.XAverageLogarithm(logistic.Point);
    }

    // The Beta(a, b) density evaluated at p = sigma(x) is
    //   f(x) = sigma(x)^(a-1) * sigma(-x)^(b-1)
    //        = sigma(x)^(a+b-2) * exp(-x*(b-1)),
    // using the identity sigma(-x) = sigma(x) * exp(-x).
    double trueCount = logistic.TrueCount;
    double falseCount = logistic.FalseCount;
    double exponent = trueCount + falseCount - 2;
    if (exponent == 0.0)
    {
        // No sigma(x) factor remains, so a uniform message is returned.
        // NOTE(review): the linear term exp(-x*(b-1)) is dropped here as well; confirm this is intended when b != 1.
        return Gaussian.Uniform();
    }
    double linearCoefficient = -(falseCount - 1);

    // Jaakkola-Jordan lower bound, with lambda(t) = tanh(t/2) / (4t):
    //   log sigma(x) >= log sigma(t) + (x - t)/2 - lambda(t) * (x^2 - t^2)
    // The tangent point t is chosen as sqrt(E[x^2]) = sqrt(mean^2 + variance) under q(x).
    // ('result' is not read by this operator.)
    double mean, variance;
    x.GetMeanAndVariance(out mean, out variance);
    double tangent = Math.Sqrt(mean * mean + variance);

    // curvature = 2 * lambda(t) = tanh(t/2) / (2t); its limit as t -> 0 is 1/4.
    double curvature;
    if (tangent == 0)
    {
        curvature = 0.25;
    }
    else
    {
        curvature = Math.Tanh(tangent / 2) / (2 * tangent);
    }

    // The log-message is exponent * (x/2 - lambda(t) * x^2) + linearCoefficient * x, so in natural
    // parameters meanTimesPrecision = exponent/2 + linearCoefficient and precision = exponent * 2*lambda(t).
    return Gaussian.FromNatural(exponent * 0.5 + linearCoefficient, exponent * curvature);
}
#pragma warning disable 429
#endif
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="LogisticOp_SJ99"]/message_doc[@name="XAverageLogarithm(Beta, Gaussian, Gaussian, double)"]/*'/>
public static Gaussian XAverageLogarithm([SkipIfUniform] Beta logistic, /*[Proper, SkipIfUniform]*/ Gaussian x, Gaussian to_x, double a)
{
    // A point-mass Beta is delegated to the fixed-probability overload in LogisticOp.
    if (logistic.IsPointMass)
    {
        return LogisticOp.XAverageLogarithm(logistic.Point);
    }

    // f(x) = sigma(x)^(a-1) sigma(-x)^(b-1)
    //      = sigma(x)^(a+b-2) exp(-x(b-1))
    // since sigma(-x) = sigma(x) exp(-x)
    double scale = logistic.TrueCount + logistic.FalseCount - 2;
    if (scale == 0.0)
    {
        // NOTE(review): the linear term exp(-x(b-1)) is also dropped here; confirm this is intended when b != 1.
        return Gaussian.Uniform();
    }
    double shift = -(logistic.FalseCount - 1);

    double m, v;
    x.GetMeanAndVariance(out m, out v);

    // Work on a local copy of the variational parameter instead of overwriting the caller's 'a'.
    // With infinite variance, (1 - 2a) * v would be NaN for a = 0.5 and unbounded otherwise,
    // so a is pinned to 0.5 and the variance term is dropped from the logistic argument.
    double effectiveA = a;
    double sa;
    if (double.IsPositiveInfinity(v))
    {
        effectiveA = 0.5;
        sa = MMath.Logistic(m);
    }
    else
    {
        sa = MMath.Logistic(m + (1 - 2 * effectiveA) * v * 0.5);
    }

    // Single-factor message: precision = a^2 + (1-2a)*sa and mean = m + (1 - sa)/precision,
    // hence meanTimesPrecision = m*precision + (1 - sa). It is then raised to the power 'scale'
    // and multiplied by exp(shift * x).
    double precision = effectiveA * effectiveA + (1 - 2 * effectiveA) * sa;
    double meanTimesPrecision = m * precision + 1 - sa;
    Gaussian result = Gaussian.FromNatural(scale * meanTimesPrecision + shift, scale * precision);

    // Random damping helps convergence, especially with parallel updates.
    // When global_step is 0, no damping is applied.
    double step = (LogisticOp_SJ99.global_step == 0.0) ? 1.0 : (Rand.Double() * LogisticOp_SJ99.global_step);
    if (step != 1.0)
    {
        // Weighted combination, in natural parameters, of the new message and the incoming to_x
        // (presumably the previous outgoing message; confirm against the caller).
        result.Precision = step * result.Precision + (1 - step) * to_x.Precision;
        result.MeanTimesPrecision = step * result.MeanTimesPrecision + (1 - step) * to_x.MeanTimesPrecision;
    }
    return result;
}