/// <summary>
/// Sets this distribution to the Gamma whose mean and variance match those of the
/// two-component mixture <c>weight1*dist1 + weight2*dist2</c>.
/// </summary>
/// <param name="weight1">Weight of the first mixture component.</param>
/// <param name="dist1">First mixture component.</param>
/// <param name="weight2">Weight of the second mixture component.</param>
/// <param name="dist2">Second mixture component.</param>
/// <remarks>
/// The moment matching is delegated to <c>Gaussian.WeightedSum</c>, with this instance passed
/// as its first argument; the returned distribution is then copied into this instance via <c>SetTo</c>.
/// </remarks>
public void SetToSum(double weight1, Gamma dist1, double weight2, Gamma dist2)
{
    var matched = Gaussian.WeightedSum<Gamma>(this, weight1, dist1, weight2, dist2);
    SetTo(matched);
}
/// <summary>
/// Sets this distribution so that its mean and variance match those of the
/// two-component mixture <c>weight1*dist1 + weight2*dist2</c> of GammaPower distributions.
/// </summary>
/// <param name="weight1">Weight of the first mixture component.</param>
/// <param name="dist1">First mixture component; its <c>Power</c> is assigned to this instance before moment matching.</param>
/// <param name="weight2">Weight of the second mixture component.</param>
/// <param name="dist2">
/// Second mixture component.
/// NOTE(review): its <c>Power</c> is never compared with <c>dist1.Power</c> — confirm callers always pass equal powers.
/// </param>
public void SetToSum(double weight1, GammaPower dist1, double weight2, GammaPower dist2)
{
    // The power must be fixed before this instance is handed to WeightedSum as its first
    // argument, so that the value passed in already carries dist1's power.
    Power = dist1.Power;
    var matched = Gaussian.WeightedSum(this, weight1, dist1, weight2, dist2);
    SetTo(matched);
}
/// <summary>
/// Sets the parameters to match the moments of the mixture <c>weight1*dist1 + weight2*dist2</c>,
/// subject to bounds on the resulting pseudo-counts (see remarks).
/// </summary>
/// <param name="weight1">The first weight</param>
/// <param name="dist1">The first distribution</param>
/// <param name="weight2">The second weight</param>
/// <param name="dist2">The second distribution</param>
/// <remarks>
/// When <c>AllowImproperSum</c> is true, the result is the unconstrained moment match returned by
/// <c>Gaussian.WeightedSum</c>. Otherwise, degenerate cases are handled first: zero total weight gives
/// uniform; a zero weight or a positive-infinite weight selects the other/that component; identical
/// inputs return <paramref name="dist1"/> exactly; two equal point masses return that point mass.
/// In the general case the moment-matched mean is kept, and the total count is clamped:
/// raised to a lower bound when both weights are positive, or lowered to an upper bound when one
/// weight is negative. The bounds are derived from the components' TrueCount/FalseCount.
/// </remarks>
/// <exception cref="ArgumentException">
/// Thrown (only when <c>AllowImproperSum</c> is false) when <c>weight1 + weight2 &lt; 0</c>, or when
/// both weights are positive infinity and the two distributions are not equal.
/// </exception>
/// <exception cref="AllZeroException">
/// Thrown (only when <c>AllowImproperSum</c> is false) when both inputs are point masses at different points.
/// </exception>
public void SetToSum(double weight1, Beta dist1, double weight2, Beta dist2)
{
    if (AllowImproperSum)
    {
        // No pseudo-count bound is enforced: take the unconstrained moment match.
        SetTo(Gaussian.WeightedSum<Beta>(this, weight1, dist1, weight2, dist2));
        return;
    }
    if (weight1 + weight2 == 0)
    {
        SetToUniform();
    }
    else if (weight1 + weight2 < 0)
    {
        throw new ArgumentException("weight1 (" + weight1 + ") + weight2 (" + weight2 + ") < 0");
    }
    else if (weight1 == 0)
    {
        SetTo(dist2);
    }
    else if (weight2 == 0)
    {
        SetTo(dist1);
    }
    // if dist1 == dist2 then we must return dist1, with no roundoff error
    else if (dist1.Equals(dist2))
    {
        SetTo(dist1);
    }
    else if (double.IsPositiveInfinity(weight1))
    {
        if (double.IsPositiveInfinity(weight2))
        {
            throw new ArgumentException("both weights are infinity");
        }
        else
        {
            SetTo(dist1);
        }
    }
    else if (double.IsPositiveInfinity(weight2))
    {
        SetTo(dist2);
    }
    else
    {
        // minTrue/minFalse: the smallest TrueCount/FalseCount among the components that are
        // not point masses. They define the pseudo-count bounds used below.
        double minTrue, minFalse;
        if (dist1.IsPointMass)
        {
            if (dist2.IsPointMass)
            {
                // Two point masses can only be mixed if they sit at the same point.
                if (!dist1.Point.Equals(dist2.Point))
                {
                    throw new AllZeroException("dist1.Point = " + dist1.Point + Environment.NewLine + "dist2.Point = " + dist2.Point);
                }
                Point = dist1.Point;
                return;
            }
            else
            {
                minTrue = dist2.TrueCount;
                minFalse = dist2.FalseCount;
            }
        }
        else if (dist2.IsPointMass)
        {
            minTrue = dist1.TrueCount;
            minFalse = dist1.FalseCount;
        }
        else
        {
            minTrue = Math.Min(dist1.TrueCount, dist2.TrueCount);
            minFalse = Math.Min(dist1.FalseCount, dist2.FalseCount);
        }
        // Algorithm: we choose the result to have the same mean and variance as the mixture,
        // provided that all pseudo-counts are greater than the smallest pseudo-count in the mixture.
        // The result has the form (s*m, s*(1-m)) where the mean m is fixed and s satisfies
        //   s*m     >= min_i dist[i].TrueCount   i.e. s >= minTrue/m
        //   s*(1-m) >= min_i dist[i].FalseCount  i.e. s >= minFalse/(1-m)
        // If weight2 < 0 then we want dist1[k] >= min(dist2[k], s*m[k]),
        // i.e. s*m[k] <= dist1[k] when dist1[k] < dist2[k] (and symmetrically when weight1 < 0).
        if (minTrue == 0 && minFalse == 0)
        {
            TrueCount = 0;
            FalseCount = 0;
        }
        else
        {
            Beta momentMatch = Gaussian.WeightedSum<Beta>(this, weight1, dist1, weight2, dist2);
            double mean = momentMatch.GetMean();
            // minTrue > 0 and minFalse > 0 otherwise GetMean would have thrown an exception.
            // NOTE(review): the guard above only excludes the case where BOTH are 0; this relies on
            // GetMean throwing when exactly one of them is 0 — confirm.
            Assert.IsTrue(minTrue > 0);
            Assert.IsTrue(minFalse > 0);
            // Total count s at which s*mean == minTrue (resp. s*(1-mean) == minFalse).
            double boundTrue = minTrue / mean;
            double boundFalse = minFalse / (1 - mean);
            double bound;
            bool boundViolated;
            if (weight1 > 0)
            {
                if (weight2 > 0)
                {
                    // Both weights positive: the total count must be at least the larger bound.
                    bound = Math.Max(boundTrue, boundFalse);
                    boundViolated = (momentMatch.TotalCount < bound);
                }
                else
                {
                    // weight2 < 0: cap the total count using the bound for each pseudo-count
                    // where dist1's count is smaller than dist2's (no cap if neither is).
                    bound = Double.PositiveInfinity;
                    if (dist1.TrueCount < dist2.TrueCount)
                    {
                        bound = boundTrue;
                    }
                    if (dist1.FalseCount < dist2.FalseCount)
                    {
                        bound = Math.Min(bound, boundFalse);
                    }
                    boundViolated = (momentMatch.TotalCount > bound);
                }
            }
            else
            {
                // weight1 < 0 (so weight2 > 0, given weight1 + weight2 > 0 and no NaN weights):
                // mirror image of the weight2 < 0 case above, with the roles of dist1 and dist2 swapped.
                bound = Double.PositiveInfinity;
                if (dist2.TrueCount < dist1.TrueCount)
                {
                    bound = boundTrue;
                }
                if (dist2.FalseCount < dist1.FalseCount)
                {
                    bound = Math.Min(bound, boundFalse);
                }
                boundViolated = (momentMatch.TotalCount > bound);
            }
            if (boundViolated)
            {
                // Keep the moment-matched mean, but clamp the total count to the bound.
                TrueCount = bound * mean;
                FalseCount = bound * (1 - mean);
            }
            else
            {
                SetTo(momentMatch);
            }
        }
    }
}