Ejemplo n.º 1
0
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="LogisticOp_JJ96"]/message_doc[@name="XAverageLogarithm(Beta, Gaussian, Gaussian)"]/*'/>
        public static Gaussian XAverageLogarithm([SkipIfUniform] Beta logistic, [Proper, SkipIfUniform] Gaussian x, Gaussian result)
        {
            // A point-mass Beta is delegated to the single-value overload.
            if (logistic.IsPointMass)
            {
                return LogisticOp.XAverageLogarithm(logistic.Point);
            }

            // Viewed as a function of x, the Beta(a, b) density on sigma(x) is
            //   f(x) = sigma(x)^(a-1) * sigma(-x)^(b-1)
            //        = sigma(x)^(a+b-2) * exp(-x*(b-1)),
            // because sigma(-x) = sigma(x) * exp(-x).
            double trueCount = logistic.TrueCount;
            double falseCount = logistic.FalseCount;
            double exponent = trueCount + falseCount - 2;
            if (exponent == 0.0)
            {
                // NOTE(review): when a+b == 2 but b != 1, f(x) = exp(-x*(b-1)) is not constant,
                // yet a uniform message is returned here -- confirm this is intended.
                return Gaussian.Uniform();
            }
            double linearShift = -(falseCount - 1);

            // Quadratic lower bound used for sigma(x), expanded around t:
            //   sigma(x) >= sigma(t) * exp((x - t)/2 - c/2 * (x^2 - t^2)),  c = tanh(t/2) / (2t).
            // The bound is evaluated at t = sqrt(E[x^2]) = sqrt(mean^2 + variance).
            x.GetMeanAndVariance(out double mean, out double variance);
            double t = Math.Sqrt(mean * mean + variance);
            double curvature;
            if (t == 0)
            {
                // Limit of tanh(t/2) / (2t) as t -> 0.
                curvature = 0.25;
            }
            else
            {
                curvature = Math.Tanh(t / 2) / (2 * t);
            }

            // The linear coefficient 1/2 and the curvature are both raised to the power 'exponent'.
            // The exp(-x*(b-1)) factor contributes 'linearShift' to mean-times-precision only.
            // The 'result' argument is not read by this implementation.
            return Gaussian.FromNatural(exponent * 0.5 + linearShift, exponent * curvature);
        }
Ejemplo n.º 2
0
#pragma warning disable 429
#endif

        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="LogisticOp_SJ99"]/message_doc[@name="XAverageLogarithm(Beta, Gaussian, Gaussian, double)"]/*'/>
        public static Gaussian XAverageLogarithm([SkipIfUniform] Beta logistic, /*[Proper, SkipIfUniform]*/ Gaussian x, Gaussian to_x, double a)
        {
            // A point-mass Beta is delegated to the single-value overload.
            if (logistic.IsPointMass)
            {
                return LogisticOp.XAverageLogarithm(logistic.Point);
            }
            // With Beta counts a and b (not the parameter 'a' of this method):
            // f(x) = sigma(x)^(a-1) sigma(-x)^(b-1)
            //      = sigma(x)^(a+b-2) exp(-x(b-1))
            // since sigma(-x) = sigma(x) exp(-x)

            double scale = logistic.TrueCount + logistic.FalseCount - 2;

            if (scale == 0.0)
            {
                // NOTE(review): when a+b == 2 but b != 1, f(x) = exp(-x(b-1)) is not constant,
                // yet a uniform message is returned here -- confirm this is intended.
                return Gaussian.Uniform();
            }
            double shift = -(logistic.FalseCount - 1);
            double m, v;

            x.GetMeanAndVariance(out m, out v);

            // sa = sigma(m + (1 - 2a) * v / 2), where 'a' is the variational parameter passed in.
            double sa;

            if (double.IsPositiveInfinity(v))
            {
                // With infinite variance the shifted argument would be infinite (or NaN when a == 0.5).
                // Fall back to a = 0.5, which removes the variance term entirely.
                a  = 0.5;
                sa = MMath.Logistic(m);
            }
            else
            {
                sa = MMath.Logistic(m + (1 - 2 * a) * v * 0.5);
            }

            // Message per unit of 'scale':
            //   precision = a^2 + (1 - 2a) * sa
            //   mean      = m + (1 - sa) / precision
            // so meanTimesPrecision = m * precision + (1 - sa).
            double precision = a * a + (1 - 2 * a) * sa;
            double meanTimesPrecision = m * precision + 1 - sa;
            Gaussian result = Gaussian.FromNatural(scale * meanTimesPrecision + shift, scale * precision);

            // Random damping helps convergence, especially with parallel updates.
            // If global_step is nonzero, the new message is blended with the previous outgoing message
            // 'to_x' using a step of Rand.Double() * global_step. A global_step of zero means no damping.
            double   step   = (LogisticOp_SJ99.global_step == 0.0) ? 1.0 : (Rand.Double() * LogisticOp_SJ99.global_step);

            if (step != 1.0)
            {
                result.Precision          = step * result.Precision + (1 - step) * to_x.Precision;
                result.MeanTimesPrecision = step * result.MeanTimesPrecision + (1 - step) * to_x.MeanTimesPrecision;
            }
            return(result);
        }