/// <summary>
/// Set the parameters to match the moments of a mixture distribution.
/// </summary>
/// <param name="weight1">The first weight</param>
/// <param name="dist1">The first distribution</param>
/// <param name="weight2">The second weight</param>
/// <param name="dist2">The second distribution</param>
/// <remarks>
/// Degenerate cases are handled first:
/// <list type="bullet">
/// <item>If the weights sum to zero, the result is uniform.</item>
/// <item>If one weight is zero, the result is the other distribution.</item>
/// <item>If the two distributions are equal, the result is exactly <paramref name="dist1"/>.</item>
/// <item>If exactly one weight is +infinity, the result is that weight's distribution.</item>
/// </list>
/// Otherwise the two distributions must share the same truncation interval. In that case the
/// underlying (untruncated) Gaussians are combined with <c>Gaussian.SetToSum</c>, and the shared
/// bounds are copied into this distribution.
/// </remarks>
/// <exception cref="ArgumentException">The weights sum to a negative value, or both weights are +infinity.</exception>
/// <exception cref="NotImplementedException">The two distributions have different truncation bounds.</exception>
public void SetToSum(double weight1, TruncatedGaussian dist1, double weight2, TruncatedGaussian dist2)
{
    if (weight1 + weight2 == 0)
    {
        SetToUniform();
    }
    else if (weight1 + weight2 < 0)
    {
        throw new ArgumentException("weight1 (" + weight1 + ") + weight2 (" + weight2 + ") < 0");
    }
    else if (weight1 == 0)
    {
        SetTo(dist2);
    }
    else if (weight2 == 0)
    {
        SetTo(dist1);
    }
    // if dist1 == dist2 then we must return dist1, with no roundoff error
    else if (dist1.Equals(dist2))
    {
        SetTo(dist1);
    }
    else if (double.IsPositiveInfinity(weight1))
    {
        if (double.IsPositiveInfinity(weight2))
        {
            throw new ArgumentException("both weights are infinity");
        }
        SetTo(dist1);
    }
    else if (double.IsPositiveInfinity(weight2))
    {
        SetTo(dist2);
    }
    else if (dist1.LowerBound == dist2.LowerBound && dist1.UpperBound == dist2.UpperBound)
    {
        // Both components share one truncation interval, so only the underlying Gaussians
        // need mixing. The interval itself must also be copied. Otherwise this instance
        // would keep whatever bounds it had before the call, unlike every other branch,
        // which replaces the whole distribution via SetTo.
        Gaussian.SetToSum(weight1, dist1.Gaussian, weight2, dist2.Gaussian);
        LowerBound = dist1.LowerBound;
        UpperBound = dist1.UpperBound;
    }
    else
    {
        throw new NotImplementedException();
    }
}
/// <summary>
/// Set this distribution equal to the approximate product of a and b
/// </summary>
/// <param name="a">The first factor</param>
/// <param name="b">The second factor</param>
/// <remarks>
/// Since WrappedGaussians are not closed under multiplication, the result is approximate.
/// In the general case, the larger period must be an integer multiple of the smaller one,
/// within a relative tolerance of 1e-4. The result takes the larger period. Only one periodic
/// shift of the shorter-period factor is multiplied in: the one whose mean is nearest the
/// other factor's mean.
/// </remarks>
/// <exception cref="AllZeroException">Both arguments are point masses at different points.</exception>
/// <exception cref="ArgumentException">The larger period is not an integer multiple of the smaller one.</exception>
/// <exception cref="ApplicationException">The product has a NaN natural parameter.</exception>
public void SetToProduct(WrappedGaussian a, WrappedGaussian b)
{
    // Canonicalize the argument order so that a has the larger (or equal) period.
    if (a.Period < b.Period)
    {
        SetToProduct(b, a);
        return;
    }
    // a.Period >= b.Period
    if (a.IsUniform())
    {
        SetTo(b);
        return;
    }
    if (b.IsUniform())
    {
        SetTo(a);
        return;
    }
    // NOTE(review): the point-mass paths below set only Point and leave Period unchanged,
    // whereas the general path sets Period = a.Period. Confirm this is intended.
    if (a.IsPointMass)
    {
        if (b.IsPointMass && !a.Point.Equals(b.Point))
        {
            throw new AllZeroException();
        }
        Point = a.Point;
        return;
    }
    if (b.IsPointMass)
    {
        Point = b.Point;
        return;
    }
    // (a,b) are not uniform or point mass

    // a.Period must be k*b.Period for some integer k >= 1. The mismatch is measured in
    // period units and compared to a tolerance relative to a.Period, so the test does not
    // depend on the scale of the periods. The check is equivalent to
    // |ratio - intRatio| <= 1e-4 * ratio.
    double ratio = a.Period / b.Period;
    int intRatio = (int)Math.Round(ratio);
    if (Math.Abs(a.Period - intRatio * b.Period) > 1e-4 * a.Period)
    {
        throw new ArgumentException("a.Period (" + a.Period + ") is not a multiple of b.Period (" + b.Period + ")");
    }
    this.Period = a.Period;

    // Because one period is a multiple of the other, only one set of shifts (of b) would need
    // to be summed over; otherwise two sets of shifts would be needed.
    double ma, va, mb, vb;
    a.Gaussian.GetMeanAndVariance(out ma, out va);
    b.Gaussian.GetMeanAndVariance(out mb, out vb);

    // Approximate using only the one best shift. Move b by the whole number of b-periods
    // that brings its mean closest to a's mean, then multiply. A more exact alternative
    // would mix the products with all intRatio shifts of b, weighting each by
    // a.Gaussian.GetLogAverageOf(shifted b).
    int k = (int)Math.Round((ma - mb) / b.Period);
    Gaussian bShifted = new Gaussian(mb + k * b.Period, vb);
    Gaussian.SetToProduct(a.Gaussian, bShifted);

    if (double.IsNaN(Gaussian.MeanTimesPrecision))
    {
        throw new ApplicationException("result is nan");
    }
    Normalize();
}