コード例 #1
0
ファイル: PlusGamma.cs プロジェクト: ScriptBox21/dotnet-infer
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaVmpOp"]/message_doc[@name="AAverageLogarithm(GammaPower, GammaPower, GammaPower)"]/*'/>
        public static GammaPower AAverageLogarithm([SkipIfUniform] GammaPower sum, [Proper] GammaPower a, [Proper] GammaPower b, GammaPower to_a, GammaPower to_b)
        {
            // Notation: c = Power, ss = sum.Shape, sr = sum.Rate.
            //
            // The factor, with sum integrated out against delta(sum = a+b), is
            //   f      = (a+b)^(ss/c - 1) * exp(-sr*(a+b)^(1/c))
            //   log(f) = (ss/c - 1)*log(a+b) - sr*(a+b)^(1/c)
            //
            // Term 1 is lower-bounded with a mixing weight p:
            //   log(a+b) >= p*log(a) + (1-p)*log(b) - p*log(p) - (1-p)*log(1-p)
            // The bound is tightest at
            //   p = exp(E[log a]) / (exp(E[log a]) + exp(E[log b]))
            // which reduces to a/(a+b) when a and b are fixed.
            // This generalizes the bound used by Cemgil (2008).
            //
            // Term 2 is upper-bounded with a second weight q (stated for c < 0):
            //   (a+b)^(1/c) = (q*(a/q) + (1-q)*(b/(1-q)))^(1/c)
            //              <= q^(1-1/c)*a^(1/c) + (1-q)^(1-1/c)*b^(1/c)
            // Setting d/dq = (1-1/c)*(q^(-1/c)*a^(1/c) - (1-q)^(-1/c)*b^(1/c)) to zero gives
            //   q = E[a^(1/c)]^c / (E[a^(1/c)]^c + E[b^(1/c)]^c)
            // which reduces to a/(a+b) when a and b are fixed.
            //
            // Resulting message to A: shape (ss-c)*p + c, rate sr*q^(1-1/c).
            // When sum is a point mass s, the message is pointmass(s * p^c * q^(1-c)).

            // NOTE(review): only a.Power is validated against sum.Power here; b.Power is used
            // below without a matching check - confirm callers guarantee it agrees.
            if (a.Power != sum.Power)
            {
                throw new NotSupportedException($"a.Power ({a.Power}) != sum.Power ({sum.Power})");
            }

            // With no information about A at all, pass the sum message straight through.
            GammaPower posteriorA = a * to_a;
            if (posteriorA.IsUniform())
            {
                return sum;
            }

            // With no information about B, the weights p and q are not computed:
            // a point-mass sum is returned as-is, otherwise shape = Power with the rate of sum.
            GammaPower posteriorB = b * to_b;
            if (posteriorB.IsUniform())
            {
                return sum.IsPointMass
                    ? sum
                    : GammaPower.FromShapeAndRate(sum.Power, sum.Rate, sum.Power);
            }

            // p: weight for the log(a+b) bound. Falls back to 0.5 when both terms are zero.
            double expMeanLogA = Math.Exp(posteriorA.GetMeanLog());
            double expMeanLogB = Math.Exp(posteriorB.GetMeanLog());
            double logWeightTotal = expMeanLogA + expMeanLogB;
            double p = 0.5;
            if (logWeightTotal != 0)
            {
                p = expMeanLogA / logWeightTotal;
            }

            // q: weight for the (a+b)^(1/c) bound. Same 0.5 fallback when both terms are zero.
            double powMeanA = Math.Exp(a.Power * posteriorA.GetLogMeanPower(1 / a.Power));
            double powMeanB = Math.Exp(b.Power * posteriorB.GetLogMeanPower(1 / b.Power));
            double powWeightTotal = powMeanA + powMeanB;
            double q = 0.5;
            if (powWeightTotal != 0)
            {
                q = powMeanA / powWeightTotal;
            }

            if (sum.IsPointMass)
            {
                // s * (p/q)^c * q  ==  s * p^c * q^(1-c)
                double ratioPower = Math.Pow(p / q, sum.Power);
                return GammaPower.PointMass(sum.Point * ratioPower * q, sum.Power);
            }

            double shape = (sum.Shape - sum.Power) * p + sum.Power;
            double rate = sum.Rate * Math.Pow(q, 1 - 1 / sum.Power);
            return GammaPower.FromShapeAndRate(shape, rate, sum.Power);
        }