public void ExpOpGammaPower_PointExp()
{
    // Checks that, as the variance of the 'exp' message shrinks (1, 0.1, 0.01, ...), the messages
    // returned by ExpOp never move further away from the messages obtained when 'exp' is an
    // exact point mass with the same mean (2).
    double power = -1;
    // Variance of the Gaussian message on d.
    double vd = 1e-3;
    Gaussian d = new Gaussian(0, vd);
    Gaussian uniform = Gaussian.Uniform();

    // Reference messages computed from the point-mass 'exp'.
    GammaPower pointMassExp = GammaPower.PointMass(2, power);
    GammaPower to_exp_point = ExpOp.ExpAverageConditional(pointMassExp, d, uniform);
    Gaussian to_d_point = ExpOp.DAverageConditional(pointMassExp, d, uniform);

    // Errors observed at the previous (larger) variance; each new error must not exceed them.
    double previousExpError = double.PositiveInfinity;
    double previousDError = double.PositiveInfinity;

    int exponent = 0;
    while (exponent < 100)
    {
        double ve = System.Math.Pow(10, -exponent);
        GammaPower narrowExp = GammaPower.FromMeanAndVariance(2, ve, power);
        GammaPower to_exp = ExpOp.ExpAverageConditional(narrowExp, d, uniform);
        Gaussian to_d = ExpOp.DAverageConditional(narrowExp, d, uniform);

        double to_exp_error = to_exp.MaxDiff(to_exp_point);
        double to_d_error = System.Math.Abs(to_d.GetMean() - to_d_point.GetMean());
        Trace.WriteLine($"ve={ve}: to_exp={to_exp} error={to_exp_error} to_d={to_d} error={to_d_error}");

        Assert.True(to_exp_error <= previousExpError);
        previousExpError = to_exp_error;
        Assert.True(to_d_error <= previousDError);
        previousDError = to_d_error;

        exponent++;
    }
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="SumAverageConditional(GammaPower, GammaPower)"]/*'/>
public static GammaPower SumAverageConditional([SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
{
    // Moment matching: the returned distribution has mean = mean(a) + mean(b) and
    // variance = var(a) + var(b), expressed with the power of 'result'.
    a.GetMeanAndVariance(out double meanOfA, out double varianceOfA);
    b.GetMeanAndVariance(out double meanOfB, out double varianceOfB);

    double meanOfSum = meanOfA + meanOfB;
    double varianceOfSum = varianceOfA + varianceOfB;
    return GammaPower.FromMeanAndVariance(meanOfSum, varianceOfSum, result.Power);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="GammaPowerProductOp_Laplace"]/message_doc[@name="BAverageConditional(GammaPower, GammaPower, GammaPower, Gamma, GammaPower)"]/*'/>
public static GammaPower BAverageConditional([SkipIfUniform] GammaPower product, [Proper] GammaPower A, [Proper] GammaPower B, Gamma q, GammaPower result)
{
    // When A has the larger shape, the roles of A and B are exchanged and the
    // A-message routine produces the answer.
    if (B.Shape < A.Shape)
    {
        return AAverageConditional(product, B, A, q, result);
    }

    // Degenerate inputs.
    if (A.IsPointMass)
    {
        return GammaProductOp.BAverageConditional(product, A.Point, result);
    }
    if (B.IsPointMass)
    {
        // TODO (carried over from the original): a point-mass B currently yields a uniform message.
        return GammaPower.Uniform(result.Power);
    }
    if (product.IsUniform())
    {
        return product;
    }

    // A uniform q is replaced by a freshly computed one.
    if (q.IsUniform())
    {
        q = Q(product, A, B);
    }

    double qCenter = q.GetMean();
    // The guard ensures 6/qCenter^4 does not overflow.
    double overflowGuard = Math.Sqrt(Math.Sqrt(6 / double.MaxValue));
    bool matchInverseMoments = result.Power < 0 && qCenter > overflowGuard;

    GammaPower marginalOfB;
    if (matchInverseMoments)
    {
        // Match the moments returned by GetIQMoments with a power -1 distribution,
        // then reuse its shape and rate under result.Power.
        GetIQMoments(product, A, q, qCenter, out double inverseMean, out double inverseVariance);
        GammaPower inverseMarginal = GammaPower.FromMeanAndVariance(inverseMean, inverseVariance, -1);
        marginalOfB = GammaPower.FromShapeAndRate(inverseMarginal.Shape, inverseMarginal.Rate, result.Power);
    }
    else
    {
        // B.Shape >= A.Shape, therefore Q is the approximate distribution of B^(1/B.Power).
        // The approximate moments of q = b^(1/b.Power) give a Gamma, which is then raised to result.Power.
        GetQMoments(product, A, q, qCenter, out double directMean, out double directVariance);
        marginalOfB = GammaPower.FromGamma(Gamma.FromMeanAndVariance(directMean, directVariance), result.Power);
    }

    // Message = approximate marginal divided by the incoming B message.
    result.SetToRatio(marginalOfB, B, GammaProductOp_Laplace.ForceProper);
    return result;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="AAverageConditional(GammaPower, GammaPower)"]/*'/>
public static GammaPower AAverageConditional([SkipIfUniform] GammaPower sum, [SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
{
    // A uniform sum is handed back unchanged.
    if (sum.IsUniform())
    {
        return sum;
    }

    // Moments of (sum - b); the mean is clamped at zero and the variances add.
    sum.GetMeanAndVariance(out double meanOfSum, out double varianceOfSum);
    b.GetMeanAndVariance(out double meanOfB, out double varianceOfB);
    double residualMean = Math.Max(0, meanOfSum - meanOfB);
    double residualVariance = varianceOfSum + varianceOfB;
    double varianceOfA = a.GetVariance();

    // The derivative-based message is only attempted when the residual is wider than a
    // and sum.Power is exactly 1 or -1; every other case uses the moment-matched residual.
    bool residualIsWider = residualVariance > varianceOfA;

    if (residualIsWider && sum.Power == 1)
    {
        GetGammaMomentDerivs(a, out double totalMean, out double dMean, out double ddMean, out double totalVariance, out double dVariance, out double ddVariance);
        totalMean += b.GetMean();
        totalVariance += b.GetVariance();
        GetGammaDerivs(totalMean, dMean, ddMean, totalVariance, dVariance, ddVariance, out double ds, out double dds, out double dr, out double ddr);
        GammaPower toSum = GammaPower.FromMeanAndVariance(totalMean, totalVariance, sum.Power);
        GetDerivLogZ(sum, toSum, ds, dds, dr, ddr, out double dlogZ, out double ddlogZ);
        return GammaPowerFromDerivLogZ(a, dlogZ, ddlogZ);
    }

    if (residualIsWider && sum.Power == -1)
    {
        GetInverseGammaMomentDerivs(a, out double totalMean, out double dMean, out double ddMean, out double totalVariance, out double dVariance, out double ddVariance);
        totalMean += b.GetMean();
        totalVariance += b.GetVariance();
        if (totalVariance > double.MaxValue)
        {
            // Variance is not representable: send a uniform message instead of failing.
            return GammaPower.Uniform(a.Power);
        }
        GetInverseGammaDerivs(totalMean, dMean, ddMean, totalVariance, dVariance, ddVariance, out double ds, out double dds, out double dr, out double ddr);
        if (sum.IsPointMass && sum.Point == 0)
        {
            return GammaPower.PointMass(0, a.Power);
        }
        GammaPower toSum = GammaPower.FromMeanAndVariance(totalMean, totalVariance, sum.Power);
        GetDerivLogZ(sum, toSum, ds, dds, dr, ddr, out double dlogZ, out double ddlogZ);
        return GammaPowerFromDerivLogZ(a, dlogZ, ddlogZ);
    }

    return GammaPower.FromMeanAndVariance(residualMean, residualVariance, result.Power);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaVmpOp"]/message_doc[@name="SumAverageLogarithm(GammaPower, GammaPower, GammaPower)"]/*'/>
public static GammaPower SumAverageLogarithm([SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
{
    // Either input being uniform makes the message uniform.
    if (a.IsUniform() || b.IsUniform())
    {
        result.SetToUniform();
        return result;
    }

    bool powerIsMinusOne = result.Power == -1;
    if (powerIsMinusOne)
    {
        // A power -1 input with shape <= 1 is treated as having an infinite mean, so moment
        // matching is replaced by a construction from a shape and the mean inverse of the sum.
        bool aHasInfiniteMean = a.Power == -1 && a.Shape <= 1;
        bool bHasInfiniteMean = b.Power == -1 && b.Shape <= 1;
        if (aHasInfiniteMean || bHasInfiniteMean)
        {
            double shape;
            if (aHasInfiniteMean && bHasInfiniteMean)
            {
                shape = Math.Min(a.Shape, b.Shape);
            }
            else if (aHasInfiniteMean)
            {
                shape = a.Shape;
            }
            else
            {
                shape = b.Shape;
            }
            return InvGammaFromShapeAndMeanInverse(shape, MeanInverseOfSum(a, b));
        }
    }

    // Regular case: add the means and the variances.
    a.GetMeanAndVariance(out double meanOfA, out double varianceOfA);
    b.GetMeanAndVariance(out double meanOfB, out double varianceOfB);
    double meanOfSum = meanOfA + meanOfB;
    double varianceOfSum = varianceOfA + varianceOfB;

    // Disabled alternative for an overflowing variance (the mean is finite in that case).
    // The flag is hard-wired to false, exactly as in the original condition.
    bool useMeanInverseFallback = false;
    if (useMeanInverseFallback && powerIsMinusOne && varianceOfSum > double.MaxValue)
    {
        return InvGammaFromMeanAndMeanInverse(meanOfSum, MeanInverseOfSum(a, b));
    }

    return GammaPower.FromMeanAndVariance(meanOfSum, varianceOfSum, result.Power);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="GammaPowerProductOp_Laplace"]/message_doc[@name="ProductAverageConditional(GammaPower, GammaPower, GammaPower, Gamma, GammaPower)"]/*'/>
public static GammaPower ProductAverageConditional(GammaPower product, [Proper] GammaPower A, [SkipIfUniform] GammaPower B, Gamma q, GammaPower result)
{
    // Overview: handles degenerate inputs first, then builds an approximate marginal of the product
    // (productMarginal) from moments computed around qPoint = q.GetMean(), and finally returns
    // productMarginal divided by the incoming product message (result.SetToRatio).
    // Throws NotSupportedException when A.Power or B.Power differs from product.Power, and
    // InferRuntimeException when the resulting shape or rate is NaN.

    // Swap the roles of A and B so that the remainder runs with B.Shape >= A.Shape.
    if (B.Shape < A.Shape)
    {
        return(ProductAverageConditional(product, B, A, q, result));
    }
    // Point-mass / uniform inputs are delegated or answered directly.
    if (B.IsPointMass)
    {
        return(GammaProductOp.ProductAverageConditional(A, B.Point));
    }
    if (B.IsUniform())
    {
        return(GammaPower.Uniform(result.Power));
    }
    if (A.IsPointMass)
    {
        return(GammaProductOp.ProductAverageConditional(A.Point, B));
    }
    if (product.IsPointMass)
    {
        return(GammaPower.Uniform(result.Power)); // TODO (original): point-mass product currently yields a uniform message
    }
    // All three distributions must share the same power.
    if (A.Power != product.Power)
    {
        throw new NotSupportedException($"A.Power ({A.Power}) != product.Power ({product.Power})");
    }
    if (B.Power != product.Power)
    {
        throw new NotSupportedException($"B.Power ({B.Power}) != product.Power ({product.Power})");
    }
    // Zero-rate inputs: return the zero-rate factor itself, or a zero-rate result with the
    // smaller shape when both rates are zero.
    if (A.Rate == 0)
    {
        if (B.Rate == 0)
        {
            return(GammaPower.FromShapeAndRate(Math.Min(A.Shape, B.Shape), 0, result.Power));
        }
        else
        {
            return(A);
        }
    }
    if (B.Rate == 0)
    {
        return(B);
    }
    // Expansion point for the moment computations below.
    double qPoint = q.GetMean();
    double r = product.Rate;
    double shape2 = GammaFromShapeAndRateOp_Slow.AddShapesMinus1(product.Shape, A.Shape) + (1 - A.Power);
    GammaPower productMarginal;
    // threshold ensures 6/qPoint^4 does not overflow
    double threshold = Math.Sqrt(Math.Sqrt(6 / double.MaxValue));
    if (shape2 > 2 && result.Power < 0 && qPoint > threshold)
    {
        // Compute the moments of product^(-1/product.Power)
        // Here q = b^(1/b.Power)
        // E[a^(-1/a.Power) b^(-1/b.Power)] = E[(q r + a_r)/(shape2-1)/q]
        // var(a^(-1/a.Power) b^(-1/b.Power)) = E[(q r + a_r)^2/(shape2-1)/(shape2-2)/q^2] - E[a^(-1/a.Power) b^(-1/b.Power)]^2
        //   = (var((q r + a_r)/q) + E[(q r + a_r)/q]^2)/(shape2-1)/(shape2-2) - E[(q r + a_r)/q]^2/(shape2-1)^2
        //   = var((q r + a_r)/q)/(shape2-1)/(shape2-2) + E[(q r + a_r)/(shape2-1)/q]^2/(shape2-2)
        // (shape2 > 2 keeps both divisors, shape2-1 and shape2-2, positive.)
        double iqMean, iqVariance;
        // iqMean/iqVariance: presumably the mean and variance of 1/q (see GetIQMoments) - confirm.
        GetIQMoments(product, A, q, qPoint, out iqMean, out iqVariance);
        double ipMean = (r + A.Rate * iqMean) / (shape2 - 1);
        double ipVariance = (iqVariance * A.Rate * A.Rate / (shape2 - 1) + ipMean * ipMean) / (shape2 - 2);
        // TODO: use ipVarianceOverMeanSquared
        // Fit a power -1 distribution to these moments, then reuse its shape/rate under result.Power.
        GammaPower ipMarginal = GammaPower.FromMeanAndVariance(ipMean, ipVariance, -1);
        if (ipMarginal.IsUniform())
        {
            return(GammaPower.Uniform(result.Power));
        }
        else
        {
            productMarginal = GammaPower.FromShapeAndRate(ipMarginal.Shape, ipMarginal.Rate, result.Power);
        }
        // Disabled diagnostic: 'check' is hard-coded to false, so the block below never runs.
        // When enabled it estimates the same moments by importance sampling from q and traces them
        // next to the analytic values for comparison.
        bool check = false;
        if (check)
        {
            // Importance sampling
            MeanVarianceAccumulator mvaInvQ = new MeanVarianceAccumulator();
            MeanVarianceAccumulator mvaInvProduct = new MeanVarianceAccumulator();
            Gamma qPrior = Gamma.FromShapeAndRate(B.Shape, B.Rate);
            // shift is the log-weight at qPoint; subtracting it gives the sample at qPoint a weight of 1.
            double shift = (product.Shape - product.Power) * Math.Log(qPoint) - shape2 * Math.Log(A.Rate + qPoint * r) + qPrior.GetLogProb(qPoint) - q.GetLogProb(qPoint);
            for (int i = 0; i < 1000000; i++)
            {
                double qSample = q.Sample();
                // logf = (y_s-y_p)*log(b) - (s+y_s-pa)*log(r + b*y_r)
                double logf = (product.Shape - product.Power) * Math.Log(qSample) - shape2 * Math.Log(A.Rate + qSample * r) + qPrior.GetLogProb(qSample) - q.GetLogProb(qSample);
                double weight = Math.Exp(logf - shift);
                mvaInvQ.Add(1 / qSample, weight);
                double invProduct = (r + A.Rate / qSample) / (shape2 - 1);
                mvaInvProduct.Add(invProduct, weight);
            }
            Trace.WriteLine($"invQ = {mvaInvQ}, {iqMean}, {iqVariance}");
            Trace.WriteLine($"invProduct = {mvaInvProduct}");
            Trace.WriteLine($"invA = {mvaInvProduct.Variance * (shape2 - 1) / (shape2 - 2) + mvaInvProduct.Mean * mvaInvProduct.Mean / (shape2 - 2)}, {ipMean}, {ipVariance}");
            Trace.WriteLine($"productMarginal = {productMarginal}");
        }
    }
    else
    {
        // Compute the moments of y = product^(1/product.Power)
        // yMean = E[shape2*b/(b y_r + a_r)]
        // yVariance = E[shape2*(shape2+1)*b^2/(b y_r + a_r)^2] - yMean^2
        //   = var(shape2*b/(b y_r + a_r)) + E[shape2*b^2/(b y_r + a_r)^2]
        //   = shape2^2*var(b/(b y_r + a_r)) + shape2*(var(b/(b y_r + a_r)) + (yMean/shape2)^2)
        // Let g = b/(b y_r + a_r)
        double denom = qPoint * r + A.Rate;
        double denom2 = denom * denom;
        double rOverDenom = r / denom;
        // g and its first three derivatives with respect to q, evaluated at qPoint:
        //   g = q/denom, g' = a_r/denom^2, g'' = -2 a_r r/denom^3, g''' = 6 a_r r^2/denom^4.
        // All four are set to zero when denom == 0 (rOverDenom is not used in that case).
        double[] gDerivatives = (denom == 0) ? new double[] { 0, 0, 0, 0 } : new double[] { qPoint / denom, A.Rate / denom2, -2 * A.Rate / denom2 * rOverDenom, 6 * A.Rate / denom2 * rOverDenom * rOverDenom };
        double gMean, gVariance;
        GaussianOp_Laplace.LaplaceMoments(q, gDerivatives, dlogfs(qPoint, product, A), out gMean, out gVariance);
        double yMean = shape2 * gMean;
        double yVariance = shape2 * shape2 * gVariance + shape2 * (gVariance + gMean * gMean);
        productMarginal = GammaPower.FromGamma(Gamma.FromMeanAndVariance(yMean, yVariance), result.Power);
    }
    // Message = approximate marginal divided by the incoming product message.
    result.SetToRatio(productMarginal, product, GammaProductOp_Laplace.ForceProper);
    // Fail loudly rather than propagate NaN parameters.
    if (double.IsNaN(result.Shape) || double.IsNaN(result.Rate))
    {
        throw new InferRuntimeException("result is nan");
    }
    return(result);
}
/// <summary>
/// Projects a GammaPower message onto a GammaPower distribution with a different power.
/// </summary>
/// <param name="message">The distribution to convert.</param>
/// <param name="newPower">The power of the returned distribution.</param>
/// <returns>
/// <paramref name="message"/> itself when its power already equals <paramref name="newPower"/>,
/// a uniform distribution when the message is uniform, otherwise a moment-matched approximation.
/// </returns>
public static GammaPower GammaPowerFromDifferentPower(GammaPower message, double newPower)
{
    double oldPower = message.Power;

    // Nothing to convert (same outcome as the general path, but faster).
    if (oldPower == newPower)
    {
        return message;
    }
    if (message.IsUniform())
    {
        return GammaPower.Uniform(newPower);
    }

    // Making two hops ensures that the desired mean powers are finite:
    // a sign change of the power goes through -1 (or +1) first.
    bool positiveToNegative = oldPower > 0 && newPower < 0 && newPower != -1;
    if (positiveToNegative)
    {
        GammaPower viaMinusOne = GammaPowerFromDifferentPower(message, -1);
        return GammaPowerFromDifferentPower(viaMinusOne, newPower);
    }
    bool negativeToPositive = oldPower < 0 && newPower > 0 && newPower != 1;
    if (negativeToPositive)
    {
        GammaPower viaPlusOne = GammaPowerFromDifferentPower(message, 1);
        return GammaPowerFromDifferentPower(viaPlusOne, newPower);
    }

    // Project the message onto the desired power by matching mean and variance directly.
    bool canMatchMomentsDirectly = newPower == 1 || newPower == -1 || newPower == 2;
    if (canMatchMomentsDirectly)
    {
        message.GetMeanAndVariance(out double directMean, out double directVariance);
        if (!double.IsPositiveInfinity(directMean))
        {
            return GammaPower.FromMeanAndVariance(directMean, directVariance, newPower);
        }
        // An infinite mean falls through to the general path.
    }

    // The mean-based alternative is kept for reference but is switched off.
    bool useMean = false;
    if (!useMean)
    {
        // Compute the mean and variance of x^(1/newPower), fit a Gamma, and raise it to newPower.
        double rootMean = message.GetMeanPower(1 / newPower);
        double rootSecondMoment = message.GetMeanPower(2 / newPower);
        double rootVariance = System.Math.Max(0, rootSecondMoment - rootMean * rootMean);
        if (double.IsPositiveInfinity(rootMean * rootMean))
        {
            // The squared mean overflowed; the variance is replaced by the mean.
            rootVariance = rootMean;
        }
        return GammaPower.FromGamma(Gamma.FromMeanAndVariance(rootMean, rootVariance), newPower);
    }
    else
    {
        // Constraints:
        // mean = Gamma(Shape + newPower)/Gamma(Shape)/Rate^newPower =approx (Shape/Rate)^newPower
        // mean2 = Gamma(Shape + 2*newPower)/Gamma(Shape)/Rate^(2*newPower) =approx ((Shape + newPower)/Rate)^newPower * (Shape/Rate)^newPower
        // mean2/mean^2 = Gamma(Shape + 2*newPower)*Gamma(Shape)/Gamma(Shape + newPower)^2 =approx ((Shape + newPower)/Shape)^newPower
        // Shape =approx newPower/((mean2/mean^2)^(1/newPower) - 1)
        // Rate = Shape/mean^(1/newPower)
        message.GetMeanAndVariance(out double mean, out double variance);
        double meanp = System.Math.Pow(mean, 1 / newPower);
        double mean2p = System.Math.Pow(variance + mean * mean, 1 / newPower);
        double shape = newPower / (mean2p / meanp / meanp - 1);
        if (double.IsInfinity(shape))
        {
            return GammaPower.PointMass(mean, newPower);
        }
        double rate = shape / meanp;
        return GammaPower.FromShapeAndRate(shape, rate, newPower);
    }
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="GammaPowerProductOp_Laplace"]/message_doc[@name="AAverageConditional(GammaPower, GammaPower, GammaPower, Gamma, GammaPower)"]/*'/>
public static GammaPower AAverageConditional([SkipIfUniform] GammaPower product, GammaPower A, [SkipIfUniform] GammaPower B, Gamma q, GammaPower result)
{
    // Overview: handles degenerate inputs first, then builds an approximate marginal of A (aMarginal)
    // from moments computed around qPoint = q.GetMean(), and finally returns aMarginal divided by
    // the incoming A message (result.SetToRatio).
    // Throws NotSupportedException when A.Power or B.Power differs from product.Power (non-point-mass
    // product only), and InferRuntimeException when the resulting shape or rate is NaN.

    // Swap the roles of A and B so that the remainder runs with B.Shape >= A.Shape.
    if (B.Shape < A.Shape)
    {
        return(BAverageConditional(product, B, A, q, result));
    }
    if (B.IsPointMass)
    {
        return(GammaProductOp.AAverageConditional(product, B.Point, result));
    }
    if (A.IsPointMass)
    {
        return(GammaPower.Uniform(A.Power)); // TODO (original): point-mass A currently yields a uniform message
    }
    // A uniform product message is passed back unchanged.
    if (product.IsUniform())
    {
        return(product);
    }
    // Expansion point for the moment computations below.
    double qPoint = q.GetMean();
    GammaPower aMarginal;
    if (product.IsPointMass)
    {
        // Z = int Ga(y/q; s, r)/q Ga(q; q_s, q_r) dq
        // E[a] = E[product/q]
        // E[a^2] = E[product^2/q^2]
        // aVariance = E[a^2] - aMean^2
        double productPoint = product.Point;
        if (productPoint == 0)
        {
            aMarginal = GammaPower.PointMass(0, result.Power);
        }
        else
        {
            double iq = 1 / qPoint;
            double iq2 = iq * iq;
            // 1/q and its first three derivatives at qPoint: 1/q, -1/q^2, 2/q^3, -6/q^4.
            double[] iqDerivatives = new double[] { iq, -iq2, 2 * iq2 * iq, -6 * iq2 * iq2 };
            double iqMean, iqVariance;
            GaussianOp_Laplace.LaplaceMoments(q, iqDerivatives, dlogfs(qPoint, product, A), out iqMean, out iqVariance);
            // a = product/q with product fixed, so the moments of 1/q are scaled by productPoint.
            double aMean = productPoint * iqMean;
            double aVariance = productPoint * productPoint * iqVariance;
            aMarginal = GammaPower.FromGamma(Gamma.FromMeanAndVariance(aMean, aVariance), result.Power);
        }
    }
    else
    {
        // An infinite product rate yields a point mass at zero.
        if (double.IsPositiveInfinity(product.Rate))
        {
            return(GammaPower.PointMass(0, result.Power));
        }
        // All three distributions must share the same power.
        if (A.Power != product.Power)
        {
            throw new NotSupportedException($"A.Power ({A.Power}) != product.Power ({product.Power})");
        }
        if (B.Power != product.Power)
        {
            throw new NotSupportedException($"B.Power ({B.Power}) != product.Power ({product.Power})");
        }
        double r = product.Rate;
        double r2 = r * r;
        // g = 1/(q r + a_r) evaluated at qPoint.
        double g = 1 / (qPoint * r + A.Rate);
        double g2 = g * g;
        double shape2 = GammaFromShapeAndRateOp_Slow.AddShapesMinus1(product.Shape, A.Shape) + (1 - A.Power);
        // From above:
        // a^(y_s-pa + a_s-1) exp(-(y_r b + a_r)*a)
        if (shape2 > 2)
        {
            // Compute the moments of a^(-1/a.Power)
            // Here q = b^(1/b.Power)
            // E[a^(-1/a.Power)] = E[(q r + a_r)/(shape2-1)]
            // var(a^(-1/a.Power)) = E[(q r + a_r)^2/(shape2-1)/(shape2-2)] - E[a^(-1/a.Power)]^2
            //   = (var(q r + a_r) + E[(q r + a_r)]^2)/(shape2-1)/(shape2-2) - E[(q r + a_r)]^2/(shape2-1)^2
            //   = var(q r + a_r)/(shape2-1)/(shape2-2) + E[(q r + a_r)/(shape2-1)]^2/(shape2-2)
            // (shape2 > 2 keeps both divisors, shape2-1 and shape2-2, positive.)
            // The identity function of q and its derivatives: q, 1, 0, 0.
            double[] qDerivatives = new double[] { qPoint, 1, 0, 0 };
            double qMean, qVariance;
            GaussianOp_Laplace.LaplaceMoments(q, qDerivatives, dlogfs(qPoint, product, A), out qMean, out qVariance);
            double iaMean = (qMean * r + A.Rate) / (shape2 - 1);
            double iaVariance = (qVariance * r2 / (shape2 - 1) + iaMean * iaMean) / (shape2 - 2);
            // Fit a power -1 distribution to these moments, then reuse its shape/rate under result.Power.
            GammaPower iaMarginal = GammaPower.FromMeanAndVariance(iaMean, iaVariance, -1);
            if (iaMarginal.IsUniform())
            {
                // A uniform fit is reported as a point mass at zero for positive powers
                // and as a uniform message otherwise.
                if (result.Power > 0)
                {
                    return(GammaPower.PointMass(0, result.Power));
                }
                else
                {
                    return(GammaPower.Uniform(result.Power));
                }
            }
            else
            {
                aMarginal = GammaPower.FromShapeAndRate(iaMarginal.Shape, iaMarginal.Rate, result.Power);
            }
            // Disabled diagnostic: 'check' is hard-coded to false, so the block below never runs.
            // When enabled it estimates the same moments by importance sampling and traces them
            // next to the analytic values for comparison.
            bool check = false;
            if (check)
            {
                // Importance sampling
                MeanVarianceAccumulator mvaB = new MeanVarianceAccumulator();
                MeanVarianceAccumulator mvaInvA = new MeanVarianceAccumulator();
                Gamma bPrior = Gamma.FromShapeAndRate(B.Shape, B.Rate);
                // q is overwritten with bPrior here, so the bPrior/q log-probability terms below cancel
                // and samples are drawn from the prior on b.
                q = bPrior;
                double shift = (product.Shape - product.Power) * Math.Log(qPoint) - shape2 * Math.Log(A.Rate + qPoint * r) + bPrior.GetLogProb(qPoint) - q.GetLogProb(qPoint);
                for (int i = 0; i < 1000000; i++)
                {
                    double bSample = q.Sample();
                    // logf = (y_s-y_p)*log(b) - (s+y_s-pa)*log(r + b*y_r)
                    double logf = (product.Shape - product.Power) * Math.Log(bSample) - shape2 * Math.Log(A.Rate + bSample * r) + bPrior.GetLogProb(bSample) - q.GetLogProb(bSample);
                    double weight = Math.Exp(logf - shift);
                    mvaB.Add(bSample, weight);
                    double invA = (bSample * r + A.Rate) / (shape2 - 1);
                    mvaInvA.Add(invA, weight);
                }
                Trace.WriteLine($"b = {mvaB}, {qMean}, {qVariance}");
                Trace.WriteLine($"invA = {mvaInvA} {mvaInvA.Variance * (shape2 - 1) / (shape2 - 2) + mvaInvA.Mean * mvaInvA.Mean / (shape2 - 2)}, {iaMean}, {iaVariance}");
                Trace.WriteLine($"aMarginal = {aMarginal}");
            }
        }
        else
        {
            // Compute the moments of a^(1/a.Power)
            // aMean = shape2/(b y_r + a_r)
            // aVariance = E[shape2*(shape2+1)/(b y_r + a_r)^2] - aMean^2 = var(shape2/(b y_r + a_r)) + E[shape2/(b y_r + a_r)^2]
            //   = shape2^2*var(1/(b y_r + a_r)) + shape2*(var(1/(b y_r + a_r)) + (aMean/shape2)^2)
            // g and its first three derivatives with respect to q at qPoint:
            //   g, -r g^2, 2 r^2 g^3, -6 r^3 g^4.
            double[] gDerivatives = new double[] { g, -r * g2, 2 * g2 * g * r2, -6 * g2 * g2 * r2 * r };
            double gMean, gVariance;
            GaussianOp_Laplace.LaplaceMoments(q, gDerivatives, dlogfs(qPoint, product, A), out gMean, out gVariance);
            double aMean = shape2 * gMean;
            double aVariance = shape2 * shape2 * gVariance + shape2 * (gVariance + gMean * gMean);
            aMarginal = GammaPower.FromGamma(Gamma.FromMeanAndVariance(aMean, aVariance), result.Power);
        }
    }
    // Message = approximate marginal divided by the incoming A message.
    result.SetToRatio(aMarginal, A, GammaProductOp_Laplace.ForceProper);
    // Fail loudly rather than propagate NaN parameters.
    if (double.IsNaN(result.Shape) || double.IsNaN(result.Rate))
    {
        throw new InferRuntimeException("result is nan");
    }
    return(result);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="AAverageConditional(GammaPower, GammaPower)"]/*'/>
public static GammaPower AAverageConditional2([SkipIfUniform] GammaPower sum, [SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
{
    // Overview: computes a message to 'a' for the factor sum = a + b in three stages:
    //   1. trivial cases (uniform inputs, sum fixed at zero);
    //   2. moment matching of (sum - b) when a's variance is at least as large as the residual variance;
    //   3. otherwise a message built from derivatives of logZ, implemented only for sum.Power of 1 or -1
    //      (any other power throws NotImplementedException).
    // NOTE(review): the include tag above references the doc entry of AAverageConditional,
    // not AAverageConditional2 - confirm this is intended.

    // Stage 1a: a uniform sum or b gives a uniform message.
    if (sum.IsUniform() || b.IsUniform())
    {
        result.SetToUniform();
        return(result);
    }
    // Stage 1b: sum is exactly zero.
    if (sum.IsPointMass && sum.Point == 0)
    {
        // Point mass at zero when b is a point mass or a's shape/power ratio does not exceed b's;
        // uniform otherwise.
        if (b.IsPointMass || a.Shape / a.Power <= b.Shape / b.Power)
        {
            result.Point = 0;
            return(result);
        }
        else
        {
            result.SetToUniform();
            return(result);
        }
    }
    // Stage 2: moment matching of the residual sum - b.
    if (sum.IsProper() && b.IsProper())
    {
        sum.GetMeanAndVariance(out double sumMean, out double sumVariance);
        b.GetMeanAndVariance(out double bMean, out double bVariance);
        double rMean = sumMean - bMean;
        double rVariance = sumVariance + bVariance;
        double aVariance = a.GetVariance();
        if (aVariance >= rVariance)
        {
            if (rMean < 0)
            {
                // A negative residual mean is repaired by truncating both distributions
                // at quantiles of each other before recomputing the means.
                // If b cannot be less than x, then sum cannot be less than x.
                double tailProbability = 0.1;
                double bLowerBound = b.GetQuantile(tailProbability);
                // If sum cannot be more than x, then b cannot be more than x.
                double sumUpperBound = sum.GetQuantile(1 - tailProbability);
                if (sumUpperBound > bLowerBound)
                {
                    // Compute the mean of the truncated distributions.
                    // sum is truncated to [bLowerBound, infinity]
                    sumMean = TruncatedGammaPowerGetMean(sum, bLowerBound, double.PositiveInfinity);
                    // b is truncated to [0, sumUpperBound]
                    bMean = TruncatedGammaPowerGetMean(b, 0, sumUpperBound);
                    rMean = sumMean - bMean;
                }
                else
                {
                    // Intentionally empty: rMean stays negative and control falls through to stage 3.
                    //result.Point = 0;
                    //return result;
                }
            }
            if (rMean > 0)
            {
                result = GammaPower.FromMeanAndVariance(rMean, rVariance, result.Power);
                // ForceProper is declared outside this block (presumably a class-level flag - confirm).
                // When set, a power-1 result with shape below 1 is raised to shape 1.
                if (ForceProper && result.Power == 1 && result.Shape < 1)
                {
                    result.Shape = 1;
                    // Set rate such that shape/rate = rMean
                    result.Rate = result.Shape / rMean;
                }
                return(result);
            }
        }
    }
    // Stage 3:
    // logZ = sum.GetLogAverageOf(toSum)
    // where toSum = SumAverageConditional(a, b)
    // Compute the derivatives wrt a.Rate to get the message to A.
    if (sum.Power == 1)
    {
        GetGammaMomentDerivs(a, out double mean, out double dmean, out double ddmean, out double variance, out double dvariance, out double ddvariance);
        // Moments of a + b: b's mean and variance are added to a's.
        mean += b.GetMean();
        variance += b.GetVariance();
        GetGammaDerivs(mean, dmean, ddmean, variance, dvariance, ddvariance, out double ds, out double dds, out double dr, out double ddr);
        GetDerivLogZ(sum, GammaPower.FromMeanAndVariance(mean, variance, sum.Power), ds, dds, dr, ddr, out double dlogZ, out double ddlogZ);
        return(GammaPowerFromDerivLogZ(a, dlogZ, ddlogZ));
    }
    else if (sum.Power == -1)
    {
        // A power -1 distribution with shape <= 1 is treated as having an infinite mean.
        bool aHasInfiniteMean = (a.Power == -1 && a.Shape <= 1);
        bool bHasInfiniteMean = (b.Power == -1 && b.Shape <= 1);
        if (aHasInfiniteMean)
        {
            if (bHasInfiniteMean)
            {
                // Both infinite: pass sum through when a's shape is not larger than b's, else uniform.
                if (a.Shape <= b.Shape)
                {
                    return(sum);
                }
                else
                {
                    result.SetToUniform();
                    return(result);
                }
            }
            else
            {
                return(sum);
            }
        }
        else if (bHasInfiniteMean)
        {
            result.SetToUniform();
            return(result);
        }
        GetInverseGammaMomentDerivs(a, out double mean, out double dmean, out double ddmean, out double variance, out double dvariance, out double ddvariance);
        // Moments of a + b: b's mean and variance are added to a's.
        mean += b.GetMean();
        variance += b.GetVariance();
        // The variance is infinite or overflowed: return a uniform message instead of throwing.
        if (variance > double.MaxValue)
        {
            return(GammaPower.Uniform(a.Power)); //throw new NotSupportedException();
        }
        GetInverseGammaDerivs(mean, dmean, ddmean, variance, dvariance, ddvariance, out double ds, out double dds, out double dr, out double ddr);
        // NOTE(review): this condition already returned near the top of the method,
        // so this branch appears unreachable here.
        if (sum.IsPointMass && sum.Point == 0)
        {
            return(GammaPower.PointMass(0, a.Power));
        }
        GetDerivLogZ(sum, GammaPower.FromMeanAndVariance(mean, variance, sum.Power), ds, dds, dr, ddr, out double dlogZ, out double ddlogZ);
        return(GammaPowerFromDerivLogZ(a, dlogZ, ddlogZ));
    }
    else
    {
        throw new NotImplementedException($"sum.Power == {sum.Power}");
    }
}