/// <summary>
/// Combines a distribution over position with a Wishart distribution over orientation
/// into a single <see cref="VectorGaussianWishart"/>, overwriting <paramref name="result"/>.
/// </summary>
/// <param name="position">Distribution over position; its mean supplies the vector argument of the combined distribution.</param>
/// <param name="orientation">Wishart distribution whose shape and rate are copied into the result.</param>
/// <param name="result">Receives the combined distribution.</param>
/// <returns><paramref name="result"/>.</returns>
/// <remarks>
/// The working dimension is taken from <c>orientation.Dimension</c> rather than being fixed,
/// so <paramref name="position"/> must have the same dimension as <paramref name="orientation"/>.
/// </remarks>
public static VectorGaussianWishart Combine(VectorGaussian position, Wishart orientation, VectorGaussianWishart result)
{
    if (orientation.IsUniform())
    {
        // A uniform orientation yields a uniform result regardless of position.
        result.SetToUniform();
        return result;
    }

    // Use the Wishart's dimension everywhere (previously hard-coded to 2, which was
    // inconsistent with the use of orientation.Dimension in the scale term below).
    int dimension = orientation.Dimension;

    if (position.IsUniform())
    {
        // No position information: zero vector and a zero final SetTo argument.
        result.SetTo(orientation.Shape, orientation.Rate, Vector.Zero(dimension), 0);
        return result;
    }

    // trace(inv(Rate * Precision)), used to derive the final SetTo argument.
    PositiveDefiniteMatrix rateTimesPrecision = new PositiveDefiniteMatrix(dimension, dimension);
    rateTimesPrecision.SetToProduct(orientation.Rate, position.Precision);
    double trace = MathHelpers.Invert(rateTimesPrecision).Trace();

    // Mean of the position distribution, recovered from its natural parameters.
    Vector positionMean = position.MeanTimesPrecision * MathHelpers.Invert(position.Precision);

    // NOTE(review): presumably the last SetTo argument is a precision scale for the mean;
    // confirm against VectorGaussianWishart.SetTo.
    result.SetTo(orientation.Shape, orientation.Rate, positionMean, dimension / (orientation.Shape * trace));
    return result;
}
/// <summary>
/// Determines whether this distribution is uniform.
/// </summary>
/// <returns>
/// <c>true</c> only when the prior is excluded (<c>IncludePrior</c> is false) and
/// <c>InducingDist</c> is itself uniform; otherwise <c>false</c>.
/// </returns>
public bool IsUniform()
{
    // Including the prior always makes the distribution non-uniform;
    // InducingDist is not consulted in that case.
    if (IncludePrior)
    {
        return false;
    }

    return InducingDist.IsUniform();
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="ArrayFromVectorOp"]/message_doc[@name="ArrayAverageConditional{GaussianList}(IList{Gaussian}, VectorGaussian, GaussianList)"]/*'/>
/// <typeparam name="GaussianList">The type of the resulting array.</typeparam>
public static GaussianList ArrayAverageConditional<GaussianList>(
    [NoInit] IList<Gaussian> array, [SkipIfUniform] VectorGaussian vector, GaussianList result)
    where GaussianList : IList<Gaussian>
{
    // Writes one Gaussian message per element into result[i]. The branch taken depends on
    // which inputs are point masses or uniform:
    //   1. every array[i] is a point mass        -> gradient/curvature of log N(x; vector) at those points
    //   2. vector is a point mass                -> point masses at vector.Point
    //   3. vector is uniform                     -> uniform messages
    //   4. some (but not all) array[i] are point masses -> moment-form computation per element
    //   5. otherwise                             -> marginals of the combined Gaussian, divided by array[i]
    // Throws ArgumentException when result.Count != vector.Dimension.
    // NOTE(review): array is indexed up to result.Count without checking array.Count;
    // presumably callers guarantee equal lengths — confirm.
    if (result.Count != vector.Dimension)
    {
        throw new ArgumentException("vector.Dimension (" + vector.Dimension + ") != result.Count (" + result.Count + ")");
    }
    int length = result.Count;
    bool allPointMass = array.All(g => g.IsPointMass);
    if (allPointMass)
    {
        // efficient special case
        // Build each message from the first and second derivatives of log N(x; vector)
        // with respect to x_i, evaluated at the observed points x = array[*].Point.
        for (int i = 0; i < length; i++)
        {
            double x = array[i].Point;
            // -prec*(x-m) = -prec*x + prec*m
            double dlogp = vector.MeanTimesPrecision[i];
            for (int j = 0; j < length; j++)
            {
                dlogp -= vector.Precision[i, j] * array[j].Point;
            }
            double ddlogp = -vector.Precision[i, i];
            result[i] = Gaussian.FromDerivatives(x, dlogp, ddlogp, false);
        }
    }
    else if (vector.IsPointMass)
    {
        // efficient special case
        // Each message is a point mass at the corresponding component of vector.Point.
        Vector mean = vector.Point;
        for (int i = 0; i < length; i++)
        {
            result[i] = Gaussian.PointMass(mean[i]);
        }
    }
    else if (vector.IsUniform())
    {
        for (int i = 0; i < length; i++)
        {
            result[i] = Gaussian.Uniform();
        }
    }
    else if (array.Any(g => g.IsPointMass))
    {
        // At least one (but not all) array[i] is a point mass, so the precision-form path in
        // the final branch cannot be used (dividing by a point mass fails there).
        // Work in moment form instead: V = Var(vector) + diag(Var(array)), d = E[vector] - E[array].
        // Z = N(m1; m2, V1+V2)
        // logZ = -0.5 (m1-m2)'inv(V1+V2)(m1-m2)
        // dlogZ = (m1-m2)'inv(V1+V2) dm2
        // ddlogZ = -dm2'inv(V1+V2) dm2
        Vector mean = Vector.Zero(length);
        PositiveDefiniteMatrix variance = new PositiveDefiniteMatrix(length, length);
        vector.GetMeanAndVariance(mean, variance);
        for (int i = 0; i < length; i++)
        {
            // Uniform elements contribute nothing here; their entries keep vector's own mean/variance.
            if (array[i].IsUniform())
            {
                continue;
            }
            double m, v;
            array[i].GetMeanAndVariance(out m, out v);
            variance[i, i] += v;
            mean[i] -= m;
        }
        PositiveDefiniteMatrix precision = variance.Inverse();
        Vector meanTimesPrecision = precision * mean;
        for (int i = 0; i < length; i++)
        {
            if (array[i].IsUniform())
            {
                // NOTE(review): this is vector's prior marginal at i, not conditioned on the other
                // array elements. The uniform index i also still enters `precision` and
                // `meanTimesPrecision` used by the other entries. Confirm this is intended.
                result[i] = Gaussian.FromMeanAndVariance(mean[i], variance[i, i]);
            }
            else
            {
                // alpha/beta are the first derivative and negative second derivative of logZ
                // with respect to the mean of array[i].
                double alpha = meanTimesPrecision[i];
                double beta = precision[i, i];
                result[i] = GaussianOp.GaussianFromAlphaBeta(array[i], alpha, beta, false);
            }
        }
    }
    else
    {
        // No point masses anywhere. Combine vector with the diagonal messages from array in
        // precision form, take the marginals of the result, then divide out array[i].
        // Compute inv(V1+V2)*(m1-m2) as inv(V2)*inv(inv(V1) + inv(V2))*(inv(V1)*m1 + inv(V2)*m2) - inv(V2)*m2 = inv(V2)*(m - m2)
        // Compute inv(V1+V2) as inv(V2) - inv(V2)*inv(inv(V1) + inv(V2))*inv(V2)
        PositiveDefiniteMatrix precision = (PositiveDefiniteMatrix)vector.Precision.Clone();
        Vector meanTimesPrecision = vector.MeanTimesPrecision.Clone();
        for (int i = 0; i < length; i++)
        {
            Gaussian g = array[i];
            precision[i, i] += g.Precision;
            meanTimesPrecision[i] += g.MeanTimesPrecision;
        }
        // Always true, so the `else` branch below is currently unreachable; it is a slower
        // reference implementation that inverts the full precision matrix.
        bool fastMethod = true;
        if (fastMethod)
        {
            bool isPosDef;
            // this destroys precision
            LowerTriangularMatrix precisionChol = precision.CholeskyInPlace(out isPosDef);
            if (!isPosDef)
            {
                // The combined precision is not positive definite, so no valid marginals exist.
                throw new PositiveDefiniteMatrixException();
            }
            // variance = inv(precisionChol*precisionChol') = inv(precisionChol)'*inv(precisionChol) = varianceChol*varianceChol'
            // this destroys meanTimesPrecision
            // Two triangular solves give mean = inv(precision) * meanTimesPrecision.
            var mean = meanTimesPrecision.PredivideBy(precisionChol);
            mean = mean.PredivideByTranspose(precisionChol);
            // Aliases precisionChol: the factor is inverted in place and the result is used as varianceCholTranspose.
            var varianceCholTranspose = precisionChol;
            // this destroys precisionChol
            varianceCholTranspose.SetToInverse(precisionChol);
            for (int i = 0; i < length; i++)
            {
                Gaussian g = array[i];
                // Presumably the sum of squares of column i of inv(precisionChol), i.e. variance[i, i].
                double variance_ii = GetSquaredLengthOfColumn(varianceCholTranspose, i);
                // works when g is uniform, but not when g is point mass
                result[i] = Gaussian.FromMeanAndVariance(mean[i], variance_ii) / g;
            }
        }
        else
        {
            // equivalent to above, but slower
            PositiveDefiniteMatrix variance = precision.Inverse();
            var mean = variance * meanTimesPrecision;
            for (int i = 0; i < length; i++)
            {
                Gaussian g = array[i];
                // works when g is uniform, but not when g is point mass
                result[i] = Gaussian.FromMeanAndVariance(mean[i], variance[i, i]) / g;
            }
        }
    }
    return(result);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="MatrixVectorProductOp"]/message_doc[@name="AAverageConditional(VectorGaussian, DistributionArray2D{Gaussian, double}, Vector, PositiveDefiniteMatrix, DistributionStructArray2D{Gaussian, double})"]/*'/>
public static DistributionStructArray2D<Gaussian, double> AAverageConditional([SkipIfUniform] VectorGaussian product, DistributionArray2D<Gaussian, double> A, Vector BMean, PositiveDefiniteMatrix BVariance, DistributionStructArray2D<Gaussian, double> result)
{
    // Computes a Gaussian message for each element A[i, j] from the gradient of logZ with
    // respect to A, evaluated at A's current point value. Only point-mass A is supported.
    // Returns a uniform result when product is uniform.
    if (product.IsUniform())
    {
        result.SetToUniform();
        return(result);
    }
    if (!A.IsPointMass)
    {
        throw new ArgumentException("A is not a point mass");
    }
    // Notation: pProduct = product.Precision, mProduct = mean of product.
    // logZ = log N(mProduct; A*BMean, vProduct + A*BVariance*A')
    // = -0.5 (mProduct - A*BMean)' inv(vProduct + A*BVariance*A') (mProduct - A*BMean) - 0.5 logdet(vProduct + A*BVariance*A')
    // = -0.5 (mProduct - A*BMean)' pProduct inv(pProduct + pProduct*A*BVariance*A'*pProduct) pProduct (mProduct - A*BMean)
    // - 0.5 logdet(pProduct + pProduct*A*BVariance*A'*pProduct) + logdet(pProduct)
    // dlogZ = 0.5 (dA*BMean)' pProduct inv(pProduct + pProduct*A*BVariance*A'*pProduct) pProduct (mProduct - A*BMean)
    // +0.5 (mProduct - A*BMean)' pProduct inv(pProduct + pProduct*A*BVariance*A'*pProduct) pProduct (dA*BMean)
    // +0.5 (mProduct - A*BMean)' pProduct inv(pProduct + pProduct*A*BVariance*A'*pProduct) (pProduct*dA*BVariance*A'*pProduct + pProduct*A*BVariance*dA'*pProduct) inv(pProduct + pProduct*A*BVariance*A'*pProduct) pProduct (mProduct - A*BMean)
    // - 0.5 tr(inv(pProduct + pProduct*A*BVariance*A'*pProduct) (pProduct*dA*BVariance*A'*pProduct + pProduct*A*BVariance*dA'*pProduct))
    // dlogZ/dA = pProduct inv(pProduct + pProduct*A*BVariance*A'*pProduct) pProduct (mProduct - A*BMean) BMean'
    // + pProduct inv(pProduct + pProduct*A*BVariance*A'*pProduct) pProduct (mProduct - A*BMean) (mProduct - A*BMean)' pProduct inv(pProduct + pProduct*A*BVariance*A'*pProduct) pProduct*A*BVariance
    // - pProduct inv(pProduct + pProduct*A*BVariance*A'*pProduct) pProduct A*BVariance
    var Amatrix = new Matrix(A.Point);
    var pProductA = product.Precision * Amatrix;
    var pProductABV = pProductA * BVariance;
    // prec = pProduct + pProduct*A*BVariance*A'*pProduct
    PositiveDefiniteMatrix prec = new PositiveDefiniteMatrix(product.Dimension, product.Dimension);
    prec.SetToSum(product.Precision, pProductABV * pProductA.Transpose());
    // pProductA is now free
    // NOTE(review): zero diagonal entries are replaced by 1, presumably so Inverse() does not
    // fail for dimensions carrying no precision — confirm intent.
    for (int i = 0; i < prec.Rows; i++)
    {
        if (prec[i, i] == 0)
        {
            prec[i, i] = 1;
        }
    }
    var v = prec.Inverse();
    var ABM = Amatrix * BMean;
    var pProductABM = product.Precision * ABM;
    // diff aliases pProductABM: diff = pProduct*mProduct - pProduct*A*BMean = pProduct (mProduct - A*BMean)
    var diff = pProductABM;
    diff.SetToDifference(product.MeanTimesPrecision, pProductABM);
    // ABM is now free
    var pProductV = product.Precision * v;
    // pProductVdiff reuses the ABM buffer: pProductVdiff = pProduct*v*diff
    var pProductVdiff = ABM;
    pProductVdiff.SetToProduct(pProductV, diff);
    var Vdiff = v * diff;
    // pProductV becomes -pProduct*v + pProductVdiff*Vdiff' (SetToSumWithOuter presumably adds the outer product)
    pProductV.Scale(-1);
    pProductV.SetToSumWithOuter(pProductV, 1, pProductVdiff, Vdiff);
    // dlogZ reuses the pProductA buffer: dlogZ = pProductV*pProductABV + pProductVdiff*BMean'
    Matrix dlogZ = pProductA;
    dlogZ.SetToProduct(pProductV, pProductABV);
    dlogZ.SetToSumWithOuter(dlogZ, 1, pProductVdiff, BMean);
    int rows = A.GetLength(0);
    int cols = A.GetLength(1);
    for (int i = 0; i < rows; i++)
    {
        for (int j = 0; j < cols; j++)
        {
            double dlogp = dlogZ[i, j];
            // for now, we don't compute the second derivative.
            // A fixed curvature of -1 is used in place of the true second derivative.
            double ddlogp = -1;
            result[i, j] = Gaussian.FromDerivatives(A[i, j].Point, dlogp, ddlogp, false);
        }
    }
    return(result);
}