예제 #1
0
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="GammaProductOp"]/message_doc[@name="LogAverageFactor(double, GammaPower, double)"]/*'/>
        public static double LogAverageFactor(double product, GammaPower a, double b)
        {
            // Special case b == 0: the result is 0 when product is also exactly zero,
            // and negative infinity for any other product value.
            if (b == 0)
            {
                if (product == 0)
                {
                    return 0.0;
                }

                return double.NegativeInfinity;
            }

            // General case: build the message to product from (a, b) and evaluate
            // its log-probability at the given product value.
            GammaPower forwardMessage = GammaProductOp.ProductAverageConditional(a, b);
            return forwardMessage.GetLogProb(product);
        }
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="GammaPowerProductOp_Laplace"]/message_doc[@name="ProductAverageConditional(GammaPower, GammaPower, GammaPower, Gamma, GammaPower)"]/*'/>
        public static GammaPower ProductAverageConditional(GammaPower product, [Proper] GammaPower A, [SkipIfUniform] GammaPower B, Gamma q, GammaPower result)
        {
            // Computes the message to 'product' from the incoming messages on product, A and B.
            //
            // Parameters:
            //   product - incoming message from product; its Shape, Rate and Power are read below.
            //   A, B    - incoming messages from the two factors of the product.
            //   q       - Gamma distribution whose mean (qPoint) is the point at which the
            //             moment formulas below are evaluated; it is also handed to the
            //             moment helpers (GetIQMoments / LaplaceMoments).
            //   result  - supplies the Power of the returned message and is overwritten
            //             (via SetToRatio) on the general path.
            // Returns: the message to product. Several early-exit paths return a freshly
            //          built distribution (or A / B themselves) rather than 'result'.
            // Throws:  NotSupportedException when A.Power or B.Power differs from product.Power;
            //          InferRuntimeException when the computed result has a NaN shape or rate.

            // Re-enter with A and B exchanged so that, below, B is the argument with
            // the larger (or equal) shape.
            if (B.Shape < A.Shape)
            {
                return(ProductAverageConditional(product, B, A, q, result));
            }
            // Point-mass B: delegate to the GammaProductOp overload that takes a fixed double.
            if (B.IsPointMass)
            {
                return(GammaProductOp.ProductAverageConditional(A, B.Point));
            }
            // Uniform B: return a uniform message with the requested power.
            if (B.IsUniform())
            {
                return(GammaPower.Uniform(result.Power));
            }
            // Point-mass A: delegate to the GammaProductOp overload that takes a fixed double first.
            if (A.IsPointMass)
            {
                return(GammaProductOp.ProductAverageConditional(A.Point, B));
            }
            // Point-mass product is not handled yet; a uniform message is returned as a placeholder.
            if (product.IsPointMass)
            {
                return(GammaPower.Uniform(result.Power)); // TODO
            }
            // The formulas below require A, B and product to share the same power.
            if (A.Power != product.Power)
            {
                throw new NotSupportedException($"A.Power ({A.Power}) != product.Power ({product.Power})");
            }
            if (B.Power != product.Power)
            {
                throw new NotSupportedException($"B.Power ({B.Power}) != product.Power ({product.Power})");
            }
            // Zero-rate inputs are answered directly, before any division by
            // rate-dependent quantities takes place.
            if (A.Rate == 0)
            {
                if (B.Rate == 0)
                {
                    // Both rates are zero: keep rate 0 and take the smaller of the two shapes.
                    return(GammaPower.FromShapeAndRate(Math.Min(A.Shape, B.Shape), 0, result.Power));
                }
                else
                {
                    // Only A has zero rate: A is returned unchanged.
                    return(A);
                }
            }
            if (B.Rate == 0)
            {
                // Only B has zero rate: B is returned unchanged.
                return(B);
            }

            // Evaluation point for the approximations below.
            double     qPoint = q.GetMean();
            // r is the rate of the incoming product message (y_r in the comments below).
            double     r      = product.Rate;
            // NOTE(review): AddShapesMinus1 is defined elsewhere; from its name it presumably
            // returns product.Shape + A.Shape - 1 -- confirm against the helper.
            double     shape2 = GammaFromShapeAndRateOp_Slow.AddShapesMinus1(product.Shape, A.Shape) + (1 - A.Power);
            GammaPower productMarginal;
            // threshold ensures 6/qPoint^4 does not overflow
            double threshold = Math.Sqrt(Math.Sqrt(6 / double.MaxValue));

            // This branch divides by (shape2 - 1) and (shape2 - 2), hence shape2 > 2;
            // it is used only for negative result powers and for qPoint above the threshold.
            if (shape2 > 2 && result.Power < 0 && qPoint > threshold)
            {
                // Compute the moments of product^(-1/product.Power)
                // Here q = b^(1/b.Power)
                // E[a^(-1/a.Power) b^(-1/b.Power)] = E[(q r + a_r)/(shape2-1)/q]
                // var(a^(-1/a.Power) b^(-1/b.Power)) = E[(q r + a_r)^2/(shape2-1)/(shape2-2)/q^2] - E[a^(-1/a.Power) b^(-1/b.Power)]^2
                //          = (var((q r + a_r)/q) + E[(q r + a_r)/q]^2)/(shape2-1)/(shape2-2) - E[(q r + a_r)/q]^2/(shape2-1)^2
                //          = var((q r + a_r)/q)/(shape2-1)/(shape2-2) + E[(q r + a_r)/(shape2-1)/q]^2/(shape2-2)
                // NOTE(review): GetIQMoments is defined elsewhere; iqMean / iqVariance appear to be
                // the mean and variance of 1/q (the disabled check below compares them with
                // sampled 1/q moments) -- confirm against the helper.
                double iqMean, iqVariance;
                GetIQMoments(product, A, q, qPoint, out iqMean, out iqVariance);
                // Mean and variance of the inverse product, following the derivation above.
                double ipMean     = (r + A.Rate * iqMean) / (shape2 - 1);
                double ipVariance = (iqVariance * A.Rate * A.Rate / (shape2 - 1) + ipMean * ipMean) / (shape2 - 2);
                // TODO: use ipVarianceOverMeanSquared
                // Moment-match a power -1 distribution to the inverse-product moments, then
                // reuse its shape and rate with the power requested by result.
                GammaPower ipMarginal = GammaPower.FromMeanAndVariance(ipMean, ipVariance, -1);
                if (ipMarginal.IsUniform())
                {
                    return(GammaPower.Uniform(result.Power));
                }
                else
                {
                    productMarginal = GammaPower.FromShapeAndRate(ipMarginal.Shape, ipMarginal.Rate, result.Power);
                }
                // Debug-only cross-check, disabled by default. Flip 'check' to true to estimate
                // the same moments from 1,000,000 weighted samples of q and write both sets of
                // numbers to Trace for comparison. It does not modify productMarginal.
                bool check = false;
                if (check)
                {
                    // Importance sampling
                    MeanVarianceAccumulator mvaInvQ       = new MeanVarianceAccumulator();
                    MeanVarianceAccumulator mvaInvProduct = new MeanVarianceAccumulator();
                    Gamma  qPrior = Gamma.FromShapeAndRate(B.Shape, B.Rate);
                    // shift is the log-weight evaluated at qPoint; subtracting it keeps the
                    // exponent in Math.Exp below relative to that reference value.
                    double shift  = (product.Shape - product.Power) * Math.Log(qPoint) - shape2 * Math.Log(A.Rate + qPoint * r) + qPrior.GetLogProb(qPoint) - q.GetLogProb(qPoint);
                    for (int i = 0; i < 1000000; i++)
                    {
                        double qSample = q.Sample();
                        // logf = (y_s-y_p)*log(b) - (s+y_s-pa)*log(r + b*y_r)
                        double logf   = (product.Shape - product.Power) * Math.Log(qSample) - shape2 * Math.Log(A.Rate + qSample * r) + qPrior.GetLogProb(qSample) - q.GetLogProb(qSample);
                        double weight = Math.Exp(logf - shift);
                        mvaInvQ.Add(1 / qSample, weight);
                        double invProduct = (r + A.Rate / qSample) / (shape2 - 1);
                        mvaInvProduct.Add(invProduct, weight);
                    }
                    Trace.WriteLine($"invQ = {mvaInvQ}, {iqMean}, {iqVariance}");
                    Trace.WriteLine($"invProduct = {mvaInvProduct}");
                    Trace.WriteLine($"invA = {mvaInvProduct.Variance * (shape2 - 1) / (shape2 - 2) + mvaInvProduct.Mean * mvaInvProduct.Mean / (shape2 - 2)}, {ipMean}, {ipVariance}");
                    Trace.WriteLine($"productMarginal = {productMarginal}");
                }
            }
            else
            {
                // Compute the moments of y = product^(1/product.Power)
                // yMean = E[shape2*b/(b y_r + a_r)]
                // yVariance = E[shape2*(shape2+1)*b^2/(b y_r + a_r)^2] - yMean^2
                //           = var(shape2*b/(b y_r + a_r)) + E[shape2*b^2/(b y_r + a_r)^2]
                //           = shape2^2*var(b/(b y_r + a_r)) + shape2*(var(b/(b y_r + a_r)) + (yMean/shape2)^2)
                // Let g = b/(b y_r + a_r)
                double   denom        = qPoint * r + A.Rate;
                double   denom2       = denom * denom;
                double   rOverDenom   = r / denom;
                // gDerivatives holds g and its first three derivatives with respect to b,
                // evaluated at b = qPoint:
                //   g    = b / denom
                //   g'   = a_r / denom^2
                //   g''  = -2 a_r r / denom^3
                //   g''' = 6 a_r r^2 / denom^4
                // All four are set to zero when denom == 0 so that no division by zero is used.
                double[] gDerivatives = (denom == 0)
                    ? new double[] { 0, 0, 0, 0 }
                    : new double[] { qPoint / denom, A.Rate / denom2, -2 * A.Rate / denom2 * rOverDenom, 6 * A.Rate / denom2 * rOverDenom * rOverDenom };
                // NOTE(review): LaplaceMoments and dlogfs are defined elsewhere; they are used here
                // to obtain the mean and variance of g under q -- confirm the exact contract
                // (expected array lengths, meaning of dlogfs) against those helpers.
                double gMean, gVariance;
                GaussianOp_Laplace.LaplaceMoments(q, gDerivatives, dlogfs(qPoint, product, A), out gMean, out gVariance);
                // Moments of y from the moments of g, per the derivation above.
                double yMean     = shape2 * gMean;
                double yVariance = shape2 * shape2 * gVariance + shape2 * (gVariance + gMean * gMean);
                // Moment-match a Gamma to y and re-express it with the power requested by result.
                productMarginal = GammaPower.FromGamma(Gamma.FromMeanAndVariance(yMean, yVariance), result.Power);
            }

            // Turn the approximate marginal into an outgoing message: SetToRatio is given the
            // marginal and the incoming product message, plus the class-level ForceProper flag.
            result.SetToRatio(productMarginal, product, GammaProductOp_Laplace.ForceProper);
            // Fail loudly rather than propagate a NaN message.
            if (double.IsNaN(result.Shape) || double.IsNaN(result.Rate))
            {
                throw new InferRuntimeException("result is nan");
            }
            return(result);
        }
예제 #3
0
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="GammaProductOp"]/message_doc[@name="LogAverageFactor(GammaPower, GammaPower, double)"]/*'/>
        public static double LogAverageFactor(GammaPower product, GammaPower a, double b)
        {
            // Build the message to product from (a, b), then return its
            // log-average against the incoming product distribution.
            return GammaProductOp.ProductAverageConditional(a, b).GetLogAverageOf(product);
        }