예제 #1
0
        /// <summary>
        /// Checks that, as the variance of the GammaPower message on <c>exp</c> shrinks (10^0 down to 10^-99),
        /// the ExpOp messages to <c>exp</c> and to <c>d</c> approach the messages computed from a point mass
        /// at the same mean, with the error never increasing from one step to the next.
        /// </summary>
        public void ExpOpGammaPower_PointExp()
        {
            const double power = -1;
            // Variance of the Gaussian message on d.
            double vd = 1e-3;
            Gaussian   d               = new Gaussian(0, vd);
            Gaussian   uniform         = Gaussian.Uniform();
            // Reference messages computed from a point mass at 2.
            GammaPower expPoint        = GammaPower.PointMass(2, power);
            GammaPower to_exp_point    = ExpOp.ExpAverageConditional(expPoint, d, uniform);
            Gaussian   to_d_point      = ExpOp.DAverageConditional(expPoint, d, uniform);
            double     to_exp_oldError = double.PositiveInfinity;
            double     to_d_oldError   = double.PositiveInfinity;

            for (int i = 0; i < 100; i++)
            {
                double     ve           = System.Math.Pow(10, -i);
                GammaPower exp          = GammaPower.FromMeanAndVariance(2, ve, power);
                GammaPower to_exp       = ExpOp.ExpAverageConditional(exp, d, uniform);
                Gaussian   to_d         = ExpOp.DAverageConditional(exp, d, uniform);
                double     to_exp_error = to_exp.MaxDiff(to_exp_point);
                double     to_d_error   = System.Math.Abs(to_d.GetMean() - to_d_point.GetMean());
                Trace.WriteLine($"ve={ve}: to_exp={to_exp} error={to_exp_error} to_d={to_d} error={to_d_error}");
                // Errors must be non-increasing as ve shrinks.
                Assert.True(to_exp_error <= to_exp_oldError);
                to_exp_oldError = to_exp_error;
                Assert.True(to_d_error <= to_d_oldError);
                to_d_oldError = to_d_error;
            }
        }
예제 #2
0
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="SumAverageConditional(GammaPower, GammaPower)"]/*'/>
        public static GammaPower SumAverageConditional([SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
        {
            // Moment matching: the sum's mean and variance are the sums of the inputs' means and
            // variances. The matched distribution is expressed with result's power.
            a.GetMeanAndVariance(out double meanOfA, out double varianceOfA);
            b.GetMeanAndVariance(out double meanOfB, out double varianceOfB);
            return GammaPower.FromMeanAndVariance(meanOfA + meanOfB, varianceOfA + varianceOfB, result.Power);
        }
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="GammaPowerProductOp_Laplace"]/message_doc[@name="BAverageConditional(GammaPower, GammaPower, GammaPower, Gamma, GammaPower)"]/*'/>
        public static GammaPower BAverageConditional([SkipIfUniform] GammaPower product, [Proper] GammaPower A, [Proper] GammaPower B, Gamma q, GammaPower result)
        {
            // The approximation assumes B has the larger shape; otherwise compute the mirrored message.
            if (B.Shape < A.Shape)
            {
                return AAverageConditional(product, B, A, q, result);
            }
            if (A.IsPointMass)
            {
                // Scaling by a known constant: delegate to the scalar product operator.
                return GammaProductOp.BAverageConditional(product, A.Point, result);
            }
            if (B.IsPointMass)
            {
                return GammaPower.Uniform(result.Power); // TODO
            }
            if (product.IsUniform())
            {
                return product;
            }

            // Fall back to Q(...) when no approximation of q was supplied.
            Gamma qApprox = q.IsUniform() ? Q(product, A, B) : q;
            double qPoint = qApprox.GetMean();
            // Lower bound on qPoint that keeps 6/qPoint^4 from overflowing.
            double minQPoint = Math.Sqrt(Math.Sqrt(6 / double.MaxValue));
            bool useInverseMoments = result.Power < 0 && qPoint > minQPoint;

            GammaPower marginalOfB;
            if (useInverseMoments)
            {
                // Match the moments of 1/q with an inverse-gamma, then reuse its shape and rate at result.Power.
                GetIQMoments(product, A, qApprox, qPoint, out double invQMean, out double invQVariance);
                GammaPower invQMarginal = GammaPower.FromMeanAndVariance(invQMean, invQVariance, -1);
                marginalOfB = GammaPower.FromShapeAndRate(invQMarginal.Shape, invQMarginal.Rate, result.Power);
            }
            else
            {
                // Since B.Shape >= A.Shape, q approximates B^(1/B.Power). Match a Gamma to q's moments
                // and then raise it to result.Power.
                GetQMoments(product, A, qApprox, qPoint, out double qMean, out double qVariance);
                marginalOfB = GammaPower.FromGamma(Gamma.FromMeanAndVariance(qMean, qVariance), result.Power);
            }

            // Outgoing message = approximate marginal / incoming message on B.
            result.SetToRatio(marginalOfB, B, GammaProductOp_Laplace.ForceProper);
            return result;
        }
예제 #4
0
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="AAverageConditional(GammaPower, GammaPower)"]/*'/>
        public static GammaPower AAverageConditional([SkipIfUniform] GammaPower sum, [SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
        {
            if (sum.IsUniform())
            {
                return sum;
            }

            // Crude moment-matched estimate of a = sum - b. The mean is clipped at zero.
            sum.GetMeanAndVariance(out double meanOfSum, out double varianceOfSum);
            b.GetMeanAndVariance(out double meanOfB, out double varianceOfB);
            double residualMean = Math.Max(0, meanOfSum - meanOfB);
            double residualVariance = varianceOfSum + varianceOfB;
            double varianceOfA = a.GetVariance();

            // The derivative-based update applies only when the crude estimate is broader than a
            // and sum has power 1 or -1. In every other case, return the crude estimate.
            bool estimateIsBroader = residualVariance > varianceOfA;
            if (!estimateIsBroader || (sum.Power != 1 && sum.Power != -1))
            {
                return GammaPower.FromMeanAndVariance(residualMean, residualVariance, result.Power);
            }

            double mean, dmean, ddmean, variance, dvariance, ddvariance;
            double ds, dds, dr, ddr;
            if (sum.Power == 1)
            {
                GetGammaMomentDerivs(a, out mean, out dmean, out ddmean, out variance, out dvariance, out ddvariance);
                mean     += b.GetMean();
                variance += b.GetVariance();
                GetGammaDerivs(mean, dmean, ddmean, variance, dvariance, ddvariance, out ds, out dds, out dr, out ddr);
            }
            else
            {
                // sum.Power == -1
                GetInverseGammaMomentDerivs(a, out mean, out dmean, out ddmean, out variance, out dvariance, out ddvariance);
                mean     += b.GetMean();
                variance += b.GetVariance();
                if (variance > double.MaxValue)
                {
                    return GammaPower.Uniform(a.Power);                            //throw new NotSupportedException();
                }
                GetInverseGammaDerivs(mean, dmean, ddmean, variance, dvariance, ddvariance, out ds, out dds, out dr, out ddr);
                if (sum.IsPointMass && sum.Point == 0)
                {
                    return GammaPower.PointMass(0, a.Power);
                }
            }

            // Convert derivatives of log Z with respect to a's parameters into the message to a.
            GetDerivLogZ(sum, GammaPower.FromMeanAndVariance(mean, variance, sum.Power), ds, dds, dr, ddr, out double dlogZ, out double ddlogZ);
            return GammaPowerFromDerivLogZ(a, dlogZ, ddlogZ);
        }
예제 #5
0
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaVmpOp"]/message_doc[@name="SumAverageLogarithm(GammaPower, GammaPower, GammaPower)"]/*'/>
        public static GammaPower SumAverageLogarithm([SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
        {
            if (a.IsUniform() || b.IsUniform())
            {
                result.SetToUniform();
                return result;
            }
            if (result.Power == -1)
            {
                // An inverse-gamma input (power -1) with shape <= 1 has an infinite mean, so mean/variance
                // matching is impossible. Instead, match E[1/(a+b)] and use the shape of the infinite-mean
                // input. When both have infinite means, use the smaller of the two shapes.
                bool aHasInfiniteMean = (a.Power == -1 && a.Shape <= 1);
                bool bHasInfiniteMean = (b.Power == -1 && b.Shape <= 1);
                if (aHasInfiniteMean || bHasInfiniteMean)
                {
                    double shape = aHasInfiniteMean
                        ? (bHasInfiniteMean ? Math.Min(a.Shape, b.Shape) : a.Shape)
                        : b.Shape;
                    return InvGammaFromShapeAndMeanInverse(shape, MeanInverseOfSum(a, b));
                }
            }

            // Moment matching: the sum's mean and variance are the sums of the inputs' means and variances.
            a.GetMeanAndVariance(out double aMean, out double aVariance);
            b.GetMeanAndVariance(out double bMean, out double bVariance);
            double mean     = aMean + bMean;
            double variance = aVariance + bVariance;
            // NOTE: for result.Power == -1 with infinite variance but finite mean, the alternative
            // InvGammaFromMeanAndMeanInverse(mean, MeanInverseOfSum(a, b)) could be used. It was disabled
            // (unreachable) in earlier versions of this method and is intentionally not applied here.
            return GammaPower.FromMeanAndVariance(mean, variance, result.Power);
        }
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="GammaPowerProductOp_Laplace"]/message_doc[@name="ProductAverageConditional(GammaPower, GammaPower, GammaPower, Gamma, GammaPower)"]/*'/>
        public static GammaPower ProductAverageConditional(GammaPower product, [Proper] GammaPower A, [SkipIfUniform] GammaPower B, Gamma q, GammaPower result)
        {
            // Computes the message to 'product' (product = A*B) by approximating the marginal of 'product'
            // with moments taken from q via Laplace moments (see the comments below, where q = b^(1/b.Power)).
            // That marginal is then divided by the incoming 'product' message.

            // Swap arguments so that B is the factor with the larger (or equal) shape.
            if (B.Shape < A.Shape)
            {
                return(ProductAverageConditional(product, B, A, q, result));
            }
            // Degenerate inputs: delegate to GammaProductOp or answer directly.
            if (B.IsPointMass)
            {
                return(GammaProductOp.ProductAverageConditional(A, B.Point));
            }
            if (B.IsUniform())
            {
                return(GammaPower.Uniform(result.Power));
            }
            if (A.IsPointMass)
            {
                return(GammaProductOp.ProductAverageConditional(A.Point, B));
            }
            if (product.IsPointMass)
            {
                return(GammaPower.Uniform(result.Power)); // TODO
            }
            // The approximation below requires A, B and product to share the same power.
            if (A.Power != product.Power)
            {
                throw new NotSupportedException($"A.Power ({A.Power}) != product.Power ({product.Power})");
            }
            if (B.Power != product.Power)
            {
                throw new NotSupportedException($"B.Power ({B.Power}) != product.Power ({product.Power})");
            }
            // Zero-rate inputs: every branch here returns a distribution with rate 0, without further computation.
            if (A.Rate == 0)
            {
                if (B.Rate == 0)
                {
                    return(GammaPower.FromShapeAndRate(Math.Min(A.Shape, B.Shape), 0, result.Power));
                }
                else
                {
                    return(A);
                }
            }
            if (B.Rate == 0)
            {
                return(B);
            }

            double     qPoint = q.GetMean();
            double     r      = product.Rate;
            // Effective shape used by the moment formulas below. See GammaFromShapeAndRateOp_Slow.AddShapesMinus1
            // for the exact definition (presumably s1 + s2 - 1; TODO confirm).
            double     shape2 = GammaFromShapeAndRateOp_Slow.AddShapesMinus1(product.Shape, A.Shape) + (1 - A.Power);
            GammaPower productMarginal;
            // threshold ensures 6/qPoint^4 does not overflow
            double threshold = Math.Sqrt(Math.Sqrt(6 / double.MaxValue));

            // The inverse-moment formulas divide by (shape2 - 2), so they need shape2 > 2.
            if (shape2 > 2 && result.Power < 0 && qPoint > threshold)
            {
                // Compute the moments of product^(-1/product.Power)
                // Here q = b^(1/b.Power)
                // E[a^(-1/a.Power) b^(-1/b.Power)] = E[(q r + a_r)/(shape2-1)/q]
                // var(a^(-1/a.Power) b^(-1/b.Power)) = E[(q r + a_r)^2/(shape2-1)/(shape2-2)/q^2] - E[a^(-1/a.Power) b^(-1/b.Power)]^2
                //          = (var((q r + a_r)/q) + E[(q r + a_r)/q]^2)/(shape2-1)/(shape2-2) - E[(q r + a_r)/q]^2/(shape2-1)^2
                //          = var((q r + a_r)/q)/(shape2-1)/(shape2-2) + E[(q r + a_r)/(shape2-1)/q]^2/(shape2-2)
                double iqMean, iqVariance;
                GetIQMoments(product, A, q, qPoint, out iqMean, out iqVariance);
                double ipMean     = (r + A.Rate * iqMean) / (shape2 - 1);
                double ipVariance = (iqVariance * A.Rate * A.Rate / (shape2 - 1) + ipMean * ipMean) / (shape2 - 2);
                // TODO: use ipVarianceOverMeanSquared
                GammaPower ipMarginal = GammaPower.FromMeanAndVariance(ipMean, ipVariance, -1);
                if (ipMarginal.IsUniform())
                {
                    return(GammaPower.Uniform(result.Power));
                }
                else
                {
                    // Reuse the inverse-gamma shape and rate at the requested power.
                    productMarginal = GammaPower.FromShapeAndRate(ipMarginal.Shape, ipMarginal.Rate, result.Power);
                }
                // Debugging aid, disabled by default (check is false). When enabled, it estimates the same
                // moments by importance sampling (1,000,000 draws from q) and writes the comparison via Trace.
                bool check = false;
                if (check)
                {
                    // Importance sampling
                    MeanVarianceAccumulator mvaInvQ       = new MeanVarianceAccumulator();
                    MeanVarianceAccumulator mvaInvProduct = new MeanVarianceAccumulator();
                    Gamma  qPrior = Gamma.FromShapeAndRate(B.Shape, B.Rate);
                    double shift  = (product.Shape - product.Power) * Math.Log(qPoint) - shape2 * Math.Log(A.Rate + qPoint * r) + qPrior.GetLogProb(qPoint) - q.GetLogProb(qPoint);
                    for (int i = 0; i < 1000000; i++)
                    {
                        double qSample = q.Sample();
                        // logf = (y_s-y_p)*log(b) - (s+y_s-pa)*log(r + b*y_r)
                        double logf   = (product.Shape - product.Power) * Math.Log(qSample) - shape2 * Math.Log(A.Rate + qSample * r) + qPrior.GetLogProb(qSample) - q.GetLogProb(qSample);
                        double weight = Math.Exp(logf - shift);
                        mvaInvQ.Add(1 / qSample, weight);
                        double invProduct = (r + A.Rate / qSample) / (shape2 - 1);
                        mvaInvProduct.Add(invProduct, weight);
                    }
                    Trace.WriteLine($"invQ = {mvaInvQ}, {iqMean}, {iqVariance}");
                    Trace.WriteLine($"invProduct = {mvaInvProduct}");
                    Trace.WriteLine($"invA = {mvaInvProduct.Variance * (shape2 - 1) / (shape2 - 2) + mvaInvProduct.Mean * mvaInvProduct.Mean / (shape2 - 2)}, {ipMean}, {ipVariance}");
                    Trace.WriteLine($"productMarginal = {productMarginal}");
                }
            }
            else
            {
                // Compute the moments of y = product^(1/product.Power)
                // yMean = E[shape2*b/(b y_r + a_r)]
                // yVariance = E[shape2*(shape2+1)*b^2/(b y_r + a_r)^2] - yMean^2
                //           = var(shape2*b/(b y_r + a_r)) + E[shape2*b^2/(b y_r + a_r)^2]
                //           = shape2^2*var(b/(b y_r + a_r)) + shape2*(var(b/(b y_r + a_r)) + (yMean/shape2)^2)
                // Let g = b/(b y_r + a_r)
                double   denom        = qPoint * r + A.Rate;
                double   denom2       = denom * denom;
                double   rOverDenom   = r / denom;
                // Value and first three derivatives of g at qPoint. All are zero when denom == 0, to avoid dividing by zero.
                double[] gDerivatives = (denom == 0)
                    ? new double[] { 0, 0, 0, 0 }
                    : new double[] { qPoint / denom, A.Rate / denom2, -2 * A.Rate / denom2 * rOverDenom, 6 * A.Rate / denom2 * rOverDenom * rOverDenom };
                double gMean, gVariance;
                GaussianOp_Laplace.LaplaceMoments(q, gDerivatives, dlogfs(qPoint, product, A), out gMean, out gVariance);
                double yMean     = shape2 * gMean;
                double yVariance = shape2 * shape2 * gVariance + shape2 * (gVariance + gMean * gMean);
                productMarginal = GammaPower.FromGamma(Gamma.FromMeanAndVariance(yMean, yVariance), result.Power);
            }

            // Outgoing message = approximate marginal / incoming message on 'product'.
            result.SetToRatio(productMarginal, product, GammaProductOp_Laplace.ForceProper);
            if (double.IsNaN(result.Shape) || double.IsNaN(result.Rate))
            {
                throw new InferRuntimeException("result is nan");
            }
            return(result);
        }
예제 #7
0
        /// <summary>
        /// Projects a GammaPower message onto the GammaPower family with a different power.
        /// </summary>
        /// <param name="message">The message to convert.</param>
        /// <param name="newPower">The power of the returned distribution.</param>
        /// <returns>
        /// <paramref name="message"/> itself if its power already equals <paramref name="newPower"/>.
        /// Otherwise, a moment-matched GammaPower distribution with power <paramref name="newPower"/>.
        /// </returns>
        public static GammaPower GammaPowerFromDifferentPower(GammaPower message, double newPower)
        {
            if (message.Power == newPower)
            {
                return(message);                           // same as below, but faster
            }
            if (message.IsUniform())
            {
                return(GammaPower.Uniform(newPower));
            }
            // Changing sign of the power goes through power -1 or 1 first.
            // Making two hops ensures that the desired mean powers are finite.
            if (message.Power > 0 && newPower < 0 && newPower != -1)
            {
                return(GammaPowerFromDifferentPower(GammaPowerFromDifferentPower(message, -1), newPower));
            }
            if (message.Power < 0 && newPower > 0 && newPower != 1)
            {
                return(GammaPowerFromDifferentPower(GammaPowerFromDifferentPower(message, 1), newPower));
            }
            // For these target powers, match the mean and variance of x directly when the mean is finite.
            if (newPower == 1 || newPower == -1 || newPower == 2)
            {
                message.GetMeanAndVariance(out double mean, out double variance);
                if (!double.IsPositiveInfinity(mean))
                {
                    return(GammaPower.FromMeanAndVariance(mean, variance, newPower));
                }
                // Infinite mean: fall through to the general projection below.
            }

            // General projection: y = x^(1/newPower) is Gamma-distributed in the target family, so match a
            // Gamma to the mean and variance of y and raise it back to newPower.
            // (An approximate closed-form alternative based on E[x] and E[x^2] was previously kept here
            // behind a permanently-false flag and has been dropped as dead code.)
            double meanY     = message.GetMeanPower(1 / newPower);
            double meanY2    = message.GetMeanPower(2 / newPower);
            double varianceY = System.Math.Max(0, meanY2 - meanY * meanY);
            if (double.IsPositiveInfinity(meanY * meanY))
            {
                // meanY^2 overflowed, so the subtraction above is meaningless. Use meanY instead.
                varianceY = meanY;
            }
            return(GammaPower.FromGamma(Gamma.FromMeanAndVariance(meanY, varianceY), newPower));
        }
예제 #8
0
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="GammaPowerProductOp_Laplace"]/message_doc[@name="AAverageConditional(GammaPower, GammaPower, GammaPower, Gamma, GammaPower)"]/*'/>
        public static GammaPower AAverageConditional([SkipIfUniform] GammaPower product, GammaPower A, [SkipIfUniform] GammaPower B, Gamma q, GammaPower result)
        {
            // Computes the message to A (product = A*B) by approximating A's marginal with Laplace moments over q,
            // where q = b^(1/b.Power) per the comments below. The marginal is then divided by the incoming A message.

            // If A has the larger shape, compute the mirrored message with the roles of A and B swapped.
            if (B.Shape < A.Shape)
            {
                return(BAverageConditional(product, B, A, q, result));
            }
            if (B.IsPointMass)
            {
                return(GammaProductOp.AAverageConditional(product, B.Point, result));
            }
            if (A.IsPointMass)
            {
                return(GammaPower.Uniform(A.Power)); // TODO
            }
            if (product.IsUniform())
            {
                return(product);
            }
            double     qPoint = q.GetMean();
            GammaPower aMarginal;

            if (product.IsPointMass)
            {
                // Z = int Ga(y/q; s, r)/q Ga(q; q_s, q_r) dq
                // E[a] = E[product/q]
                // E[a^2] = E[product^2/q^2]
                // aVariance = E[a^2] - aMean^2
                double productPoint = product.Point;
                if (productPoint == 0)
                {
                    aMarginal = GammaPower.PointMass(0, result.Power);
                }
                else
                {
                    // Value and first three derivatives of 1/q at qPoint, for LaplaceMoments.
                    double   iq = 1 / qPoint;
                    double   iq2 = iq * iq;
                    double[] iqDerivatives = new double[] { iq, -iq2, 2 * iq2 * iq, -6 * iq2 * iq2 };
                    double   iqMean, iqVariance;
                    GaussianOp_Laplace.LaplaceMoments(q, iqDerivatives, dlogfs(qPoint, product, A), out iqMean, out iqVariance);
                    double aMean     = productPoint * iqMean;
                    double aVariance = productPoint * productPoint * iqVariance;
                    aMarginal = GammaPower.FromGamma(Gamma.FromMeanAndVariance(aMean, aVariance), result.Power);
                }
            }
            else
            {
                // An infinite rate on 'product' sends the message to A to a point mass at 0.
                if (double.IsPositiveInfinity(product.Rate))
                {
                    return(GammaPower.PointMass(0, result.Power));
                }
                // The formulas below require A, B and product to share the same power.
                if (A.Power != product.Power)
                {
                    throw new NotSupportedException($"A.Power ({A.Power}) != product.Power ({product.Power})");
                }
                if (B.Power != product.Power)
                {
                    throw new NotSupportedException($"B.Power ({B.Power}) != product.Power ({product.Power})");
                }
                double r      = product.Rate;
                double r2     = r * r;
                double g      = 1 / (qPoint * r + A.Rate);
                double g2     = g * g;
                // Effective shape used below. See GammaFromShapeAndRateOp_Slow.AddShapesMinus1 for the exact definition.
                double shape2 = GammaFromShapeAndRateOp_Slow.AddShapesMinus1(product.Shape, A.Shape) + (1 - A.Power);
                // From above:
                // a^(y_s-pa + a_s-1) exp(-(y_r b + a_r)*a)
                // The inverse-moment formulas divide by (shape2 - 2), so they need shape2 > 2.
                if (shape2 > 2)
                {
                    // Compute the moments of a^(-1/a.Power)
                    // Here q = b^(1/b.Power)
                    // E[a^(-1/a.Power)] = E[(q r + a_r)/(shape2-1)]
                    // var(a^(-1/a.Power)) = E[(q r + a_r)^2/(shape2-1)/(shape2-2)] - E[a^(-1/a.Power)]^2
                    //          = (var(q r + a_r) + E[(q r + a_r)]^2)/(shape2-1)/(shape2-2) - E[(q r + a_r)]^2/(shape2-1)^2
                    //          = var(q r + a_r)/(shape2-1)/(shape2-2) + E[(q r + a_r)/(shape2-1)]^2/(shape2-2)
                    double[] qDerivatives = new double[] { qPoint, 1, 0, 0 };
                    double   qMean, qVariance;
                    GaussianOp_Laplace.LaplaceMoments(q, qDerivatives, dlogfs(qPoint, product, A), out qMean, out qVariance);
                    double     iaMean     = (qMean * r + A.Rate) / (shape2 - 1);
                    double     iaVariance = (qVariance * r2 / (shape2 - 1) + iaMean * iaMean) / (shape2 - 2);
                    GammaPower iaMarginal = GammaPower.FromMeanAndVariance(iaMean, iaVariance, -1);
                    if (iaMarginal.IsUniform())
                    {
                        // Degenerate inverse moments: point mass at 0 for positive powers, otherwise uniform.
                        if (result.Power > 0)
                        {
                            return(GammaPower.PointMass(0, result.Power));
                        }
                        else
                        {
                            return(GammaPower.Uniform(result.Power));
                        }
                    }
                    else
                    {
                        // Reuse the inverse-gamma shape and rate at the requested power.
                        aMarginal = GammaPower.FromShapeAndRate(iaMarginal.Shape, iaMarginal.Rate, result.Power);
                    }
                    // Debugging aid, disabled by default (check is false). When enabled, it overwrites q with
                    // B's prior, estimates the same moments by importance sampling (1,000,000 draws) and writes
                    // them via Trace.
                    bool check = false;
                    if (check)
                    {
                        // Importance sampling
                        MeanVarianceAccumulator mvaB    = new MeanVarianceAccumulator();
                        MeanVarianceAccumulator mvaInvA = new MeanVarianceAccumulator();
                        Gamma bPrior = Gamma.FromShapeAndRate(B.Shape, B.Rate);
                        q = bPrior;
                        double shift = (product.Shape - product.Power) * Math.Log(qPoint) - shape2 * Math.Log(A.Rate + qPoint * r) + bPrior.GetLogProb(qPoint) - q.GetLogProb(qPoint);
                        for (int i = 0; i < 1000000; i++)
                        {
                            double bSample = q.Sample();
                            // logf = (y_s-y_p)*log(b) - (s+y_s-pa)*log(r + b*y_r)
                            double logf   = (product.Shape - product.Power) * Math.Log(bSample) - shape2 * Math.Log(A.Rate + bSample * r) + bPrior.GetLogProb(bSample) - q.GetLogProb(bSample);
                            double weight = Math.Exp(logf - shift);
                            mvaB.Add(bSample, weight);
                            double invA = (bSample * r + A.Rate) / (shape2 - 1);
                            mvaInvA.Add(invA, weight);
                        }
                        Trace.WriteLine($"b = {mvaB}, {qMean}, {qVariance}");
                        Trace.WriteLine($"invA = {mvaInvA} {mvaInvA.Variance * (shape2 - 1) / (shape2 - 2) + mvaInvA.Mean * mvaInvA.Mean / (shape2 - 2)}, {iaMean}, {iaVariance}");
                        Trace.WriteLine($"aMarginal = {aMarginal}");
                    }
                }
                else
                {
                    // Compute the moments of a^(1/a.Power)
                    // aMean = shape2/(b y_r + a_r)
                    // aVariance = E[shape2*(shape2+1)/(b y_r + a_r)^2] - aMean^2 = var(shape2/(b y_r + a_r)) + E[shape2/(b y_r + a_r)^2]
                    //           = shape2^2*var(1/(b y_r + a_r)) + shape2*(var(1/(b y_r + a_r)) + (aMean/shape2)^2)
                    // Value and first three derivatives of g = 1/(q r + a_r) at qPoint.
                    double[] gDerivatives = new double[] { g, -r * g2, 2 * g2 * g * r2, -6 * g2 * g2 * r2 * r };
                    double   gMean, gVariance;
                    GaussianOp_Laplace.LaplaceMoments(q, gDerivatives, dlogfs(qPoint, product, A), out gMean, out gVariance);
                    double aMean     = shape2 * gMean;
                    double aVariance = shape2 * shape2 * gVariance + shape2 * (gVariance + gMean * gMean);
                    aMarginal = GammaPower.FromGamma(Gamma.FromMeanAndVariance(aMean, aVariance), result.Power);
                }
            }
            // Outgoing message = approximate marginal / incoming message on A.
            result.SetToRatio(aMarginal, A, GammaProductOp_Laplace.ForceProper);
            if (double.IsNaN(result.Shape) || double.IsNaN(result.Rate))
            {
                throw new InferRuntimeException("result is nan");
            }
            return(result);
        }
예제 #9
0
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="AAverageConditional(GammaPower, GammaPower)"]/*'/>
        public static GammaPower AAverageConditional2([SkipIfUniform] GammaPower sum, [SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
        {
            if (sum.IsUniform() || b.IsUniform())
            {
                result.SetToUniform();
                return result;
            }
            // sum == 0 forces a == 0, unless b can absorb it. The shape/power ratio is compared to decide
            // which of a and b is responsible.
            if (sum.IsPointMass && sum.Point == 0)
            {
                if (b.IsPointMass || a.Shape / a.Power <= b.Shape / b.Power)
                {
                    result.Point = 0;
                    return result;
                }
                else
                {
                    result.SetToUniform();
                    return result;
                }
            }
            // Fast path: moment-match a = sum - b when that estimate is no broader than a itself.
            if (sum.IsProper() && b.IsProper())
            {
                sum.GetMeanAndVariance(out double sumMean, out double sumVariance);
                b.GetMeanAndVariance(out double bMean, out double bVariance);
                double rMean     = sumMean - bMean;
                double rVariance = sumVariance + bVariance;
                double aVariance = a.GetVariance();
                if (aVariance >= rVariance)
                {
                    if (rMean < 0)
                    {
                        // A negative mean difference is refined with truncated means. If the bounds below do
                        // not overlap, rMean stays negative and the derivative-based update below is used.
                        // If b cannot be less than x, then sum cannot be less than x.
                        double tailProbability = 0.1;
                        double bLowerBound     = b.GetQuantile(tailProbability);
                        // If sum cannot be more than x, then b cannot be more than x.
                        double sumUpperBound = sum.GetQuantile(1 - tailProbability);
                        if (sumUpperBound > bLowerBound)
                        {
                            // Compute the mean of the truncated distributions.
                            // sum is truncated to [bLowerBound, infinity]
                            sumMean = TruncatedGammaPowerGetMean(sum, bLowerBound, double.PositiveInfinity);
                            // b is truncated to [0, sumUpperBound]
                            bMean = TruncatedGammaPowerGetMean(b, 0, sumUpperBound);
                            rMean = sumMean - bMean;
                        }
                    }
                    if (rMean > 0)
                    {
                        result = GammaPower.FromMeanAndVariance(rMean, rVariance, result.Power);
                        if (ForceProper && result.Power == 1 && result.Shape < 1)
                        {
                            result.Shape = 1;
                            // Set rate such that shape/rate = rMean
                            result.Rate = result.Shape / rMean;
                        }
                        return result;
                    }
                }
            }
            // logZ = sum.GetLogAverageOf(toSum)
            // where toSum = SumAverageConditional(a, b)
            // Compute the derivatives wrt a.Rate to get the message to A.
            if (sum.Power == 1)
            {
                GetGammaMomentDerivs(a, out double mean, out double dmean, out double ddmean, out double variance, out double dvariance, out double ddvariance);
                mean     += b.GetMean();
                variance += b.GetVariance();
                GetGammaDerivs(mean, dmean, ddmean, variance, dvariance, ddvariance, out double ds, out double dds, out double dr, out double ddr);
                GetDerivLogZ(sum, GammaPower.FromMeanAndVariance(mean, variance, sum.Power), ds, dds, dr, ddr, out double dlogZ, out double ddlogZ);
                return GammaPowerFromDerivLogZ(a, dlogZ, ddlogZ);
            }
            else if (sum.Power == -1)
            {
                // Inverse-gamma inputs with shape <= 1 have infinite means, so the derivative approach does not apply.
                bool aHasInfiniteMean = (a.Power == -1 && a.Shape <= 1);
                bool bHasInfiniteMean = (b.Power == -1 && b.Shape <= 1);
                if (aHasInfiniteMean)
                {
                    if (bHasInfiniteMean)
                    {
                        if (a.Shape <= b.Shape)
                        {
                            return sum;
                        }
                        else
                        {
                            result.SetToUniform();
                            return result;
                        }
                    }
                    else
                    {
                        return sum;
                    }
                }
                else if (bHasInfiniteMean)
                {
                    result.SetToUniform();
                    return result;
                }

                GetInverseGammaMomentDerivs(a, out double mean, out double dmean, out double ddmean, out double variance, out double dvariance, out double ddvariance);
                mean     += b.GetMean();
                variance += b.GetVariance();
                if (variance > double.MaxValue)
                {
                    return GammaPower.Uniform(a.Power);                            //throw new NotSupportedException();
                }
                GetInverseGammaDerivs(mean, dmean, ddmean, variance, dvariance, ddvariance, out double ds, out double dds, out double dr, out double ddr);
                // No point-mass-at-zero check is needed here: that case already returned near the top of the method.
                GetDerivLogZ(sum, GammaPower.FromMeanAndVariance(mean, variance, sum.Power), ds, dds, dr, ddr, out double dlogZ, out double ddlogZ);
                return GammaPowerFromDerivLogZ(a, dlogZ, ddlogZ);
            }
            else
            {
                throw new NotImplementedException($"sum.Power == {sum.Power}");
            }
        }