예제 #1
0
 /// <summary>
 /// Set the parameters to match the moments of a mixture distribution.
 /// </summary>
 /// <param name="weight1">The first weight</param>
 /// <param name="dist1">The first distribution</param>
 /// <param name="weight2">The second weight</param>
 /// <param name="dist2">The second distribution</param>
 /// <exception cref="ArgumentException">
 /// The weights sum to a negative number, or both weights are positive infinity while the distributions differ.
 /// </exception>
 /// <exception cref="NotImplementedException">
 /// None of the special cases applies and the two distributions have different bounds.
 /// </exception>
 /// <remarks>
 /// Special cases are tried in order: zero total weight (result is uniform), a zero weight (result is the
 /// other distribution), equal distributions, and an infinite weight (result is that distribution).
 /// Otherwise the two distributions must share the same bounds, and the underlying Gaussians are combined
 /// with <c>Gaussian.SetToSum</c>.
 /// </remarks>
 public void SetToSum(double weight1, TruncatedGaussian dist1, double weight2, TruncatedGaussian dist2)
 {
     if (weight1 + weight2 == 0)
     {
         SetToUniform();
     }
     else if (weight1 + weight2 < 0)
     {
         throw new ArgumentException("weight1 (" + weight1 + ") + weight2 (" + weight2 + ") < 0");
     }
     else if (weight1 == 0)
     {
         SetTo(dist2);
     }
     else if (weight2 == 0)
     {
         SetTo(dist1);
     }
     // if dist1 == dist2 then we must return dist1, with no roundoff error
     else if (dist1.Equals(dist2))
     {
         SetTo(dist1);
     }
     else if (double.IsPositiveInfinity(weight1))
     {
         if (double.IsPositiveInfinity(weight2))
         {
             throw new ArgumentException("both weights are infinity");
         }
         else
         {
             // an infinite weight dominates any finite one
             SetTo(dist1);
         }
     }
     else if (double.IsPositiveInfinity(weight2))
     {
         SetTo(dist2);
     }
     else if (dist1.LowerBound == dist2.LowerBound && dist1.UpperBound == dist2.UpperBound)
     {
         // Both components share one truncation interval, so the result uses that interval as well.
         // Every other branch copies the bounds through SetTo/SetToUniform; without these assignments
         // this branch would leave the instance with whatever bounds it had before the call.
         // NOTE(review): assumes LowerBound/UpperBound are assignable members of this type - confirm.
         LowerBound = dist1.LowerBound;
         UpperBound = dist1.UpperBound;
         Gaussian.SetToSum(weight1, dist1.Gaussian, weight2, dist2.Gaussian);
     }
     else
     {
         // mixtures of truncated Gaussians with different bounds are not supported
         throw new NotImplementedException();
     }
 }
예제 #2
0
        /// <summary>
        /// Set this distribution equal to the approximate product of a and b
        /// </summary>
        /// <param name="a">The first factor</param>
        /// <param name="b">The second factor</param>
        /// <exception cref="AllZeroException">Both factors are point masses at different points.</exception>
        /// <exception cref="ArgumentException">
        /// The larger period is not an integer multiple of the smaller period (relative tolerance 1e-4).
        /// </exception>
        /// <exception cref="ApplicationException">The resulting Gaussian has a NaN MeanTimesPrecision.</exception>
        /// <remarks>
        /// Since WrappedGaussians are not closed under multiplication, the result is approximate.
        /// The arguments are swapped if necessary so that the factor with the larger period is treated as a.
        /// </remarks>
        public void SetToProduct(WrappedGaussian a, WrappedGaussian b)
        {
            if (a.Period < b.Period)
            {
                SetToProduct(b, a);
                return;
            }
            // a.Period >= b.Period
            if (a.IsUniform())
            {
                SetTo(b);
                return;
            }
            if (b.IsUniform())
            {
                SetTo(a);
                return;
            }
            // NOTE(review): the point-mass paths below assign only Point; Period is not assigned
            // before returning. Confirm that the Point setter / callers do not depend on it.
            if (a.IsPointMass)
            {
                if (b.IsPointMass && !a.Point.Equals(b.Point))
                {
                    throw new AllZeroException();
                }
                Point = a.Point;
                return;
            }
            if (b.IsPointMass)
            {
                Point = b.Point;
                return;
            }
            // (a,b) are not uniform or point mass
            double ratio    = a.Period / b.Period;
            int    intRatio = (int)Math.Round(ratio);

            // Check that a.Period is (nearly) intRatio * b.Period, using a tolerance relative to a.Period.
            // Both sides are measured in period units, so the check does not depend on the scale of the
            // periods. (Comparing the dimensionless |ratio - intRatio| against a.Period * 1e-4 would accept
            // e.g. periods 360 and 350 as multiples, and would be overly strict for very small periods.)
            if (Math.Abs(a.Period - intRatio * b.Period) > a.Period * 1e-4)
            {
                throw new ArgumentException("a.Period (" + a.Period + ") is not a multiple of b.Period (" + b.Period + ")");
            }
            this.Period = a.Period;
            // a.Period = k*b.Period, k >= 1
            // because one period is a multiple of the other, we only need to sum over one set of shifts.
            // otherwise, we would need to sum over two sets of shifts.
            double ma, va, mb, vb;

            a.Gaussian.GetMeanAndVariance(out ma, out va);
            b.Gaussian.GetMeanAndVariance(out mb, out vb);
            // distance between the two means, measured in periods of b
            double diff = (ma - mb) / b.Period;

#if true
            // approximate using only the one best shift:
            // move b's mean by the whole number of b-periods that brings it closest to a's mean
            int      k        = (int)Math.Round(diff);
            Gaussian bShifted = new Gaussian(mb + k * b.Period, vb);
            Gaussian.SetToProduct(a.Gaussian, bShifted);
#else
            // Disabled alternative: moment-match a mixture over several shifts of b.
            // we will sum over shifts from kMin to kMax, numbering intRatio in total
            int kMin, kMax;
            if (intRatio % 2 == 1)
            {
                // odd number of shifts
                int kMid      = (int)Math.Round(diff);
                int halfRatio = intRatio / 2;
                kMin = kMid - halfRatio;
                kMax = kMid + halfRatio;
            }
            else
            {
                // even number of shifts
                int kMid      = (int)Math.Floor(diff);
                int halfRatio = intRatio / 2;
                kMin = kMid - halfRatio + 1;
                kMax = kMid + halfRatio;
            }
            if (kMax - kMin != intRatio - 1)
            {
                throw new ApplicationException("kMax - kMin != intRatio-1");
            }
            // exclude shifts that are too far away
            double sa         = Math.Sqrt(va + vb);
            double lowerBound = ma - 5 * sa;
            double upperBound = ma + 5 * sa;
            // find the largest k such that (mb + k*Lb <= lowerBound)
            double kLower = Math.Floor((lowerBound - mb) / b.Period);
            if (kLower > kMin)
            {
                kMin = (int)kLower;
            }
            // find the smallest k such that (mb + k*Lb >= upperBound)
            double kUpper = Math.Ceiling((upperBound - mb) / b.Period);
            if (kUpper < kMax)
            {
                kMax = (int)kUpper;
            }
            if (kMax - kMin > 100)
            {
                throw new ApplicationException("kMax - kMin = " + (kMax - kMin));
            }
            double totalWeight = Double.NegativeInfinity;
            for (int k = kMin; k <= kMax; k++)
            {
                Gaussian bShifted = new Gaussian(mb + k * b.Period, vb);
                Gaussian product  = a.Gaussian * bShifted;
                double   weight   = a.Gaussian.GetLogAverageOf(bShifted);
                if (double.IsNegativeInfinity(totalWeight))
                {
                    Gaussian.SetTo(product);
                    totalWeight = weight;
                }
                else
                {
                    Gaussian.SetToSum(1.0, Gaussian, Math.Exp(weight - totalWeight), product);
                    totalWeight = MMath.LogSumExp(totalWeight, weight);
                }
            }
#endif
            if (double.IsNaN(Gaussian.MeanTimesPrecision))
            {
                throw new ApplicationException("result is nan");
            }
            // NOTE(review): Normalize() is defined outside this view; presumably it brings the
            // underlying Gaussian back into canonical form for this Period - confirm.
            Normalize();
        }