コード例 #1
0
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="ArrayFromVectorOp"]/message_doc[@name="ArrayAverageConditional{GaussianList}(IList{Gaussian}, VectorGaussian, GaussianList)"]/*'/>
        /// <typeparam name="GaussianList">The type of the resulting array.</typeparam>
        public static GaussianList ArrayAverageConditional<GaussianList>(
            [NoInit] IList<Gaussian> array, [SkipIfUniform] VectorGaussian vector, GaussianList result)
            where GaussianList : IList<Gaussian>
        {
            // Overwrites result[i] for every i in [0, result.Count) and returns result.
            // Throws ArgumentException when result.Count differs from vector.Dimension.
            // The first matching case below decides how the entries are computed:
            //   1. every element of array is a point mass
            //   2. vector is a point mass
            //   3. vector is uniform
            //   4. at least one (but not every) element of array is a point mass
            //   5. no element of array is a point mass
            if (result.Count != vector.Dimension)
            {
                throw new ArgumentException("vector.Dimension (" + vector.Dimension + ") != result.Count (" + result.Count + ")");
            }
            int length = result.Count;
            bool allPointMass = array.All(g => g.IsPointMass);

            if (allPointMass)
            {
                // efficient special case: build each entry from the first and second
                // derivatives evaluated at the point values of array
                for (int i = 0; i < length; i++)
                {
                    double x = array[i].Point;
                    // -prec*(x-m) = -prec*x + prec*m
                    double dlogp = vector.MeanTimesPrecision[i];
                    for (int j = 0; j < length; j++)
                    {
                        dlogp -= vector.Precision[i, j] * array[j].Point;
                    }
                    double ddlogp = -vector.Precision[i, i];
                    result[i] = Gaussian.FromDerivatives(x, dlogp, ddlogp, false);
                }
            }
            else if (vector.IsPointMass)
            {
                // efficient special case: each entry is a point mass at the matching coordinate
                Vector mean = vector.Point;
                for (int i = 0; i < length; i++)
                {
                    result[i] = Gaussian.PointMass(mean[i]);
                }
            }
            else if (vector.IsUniform())
            {
                for (int i = 0; i < length; i++)
                {
                    result[i] = Gaussian.Uniform();
                }
            }
            else if (array.Any(g => g.IsPointMass))
            {
                // Z = N(m1; m2, V1+V2)
                // logZ = -0.5 (m1-m2)'inv(V1+V2)(m1-m2)
                // dlogZ = (m1-m2)'inv(V1+V2) dm2
                // ddlogZ = -dm2'inv(V1+V2) dm2
                Vector mean = Vector.Zero(length);
                PositiveDefiniteMatrix variance = new PositiveDefiniteMatrix(length, length);
                vector.GetMeanAndVariance(mean, variance);
                // Turn (mean, variance) into (m1-m2, V1+V2); uniform elements of array
                // are skipped and leave their row of mean/variance untouched.
                for (int i = 0; i < length; i++)
                {
                    if (array[i].IsUniform())
                    {
                        continue;
                    }
                    double m, v;
                    array[i].GetMeanAndVariance(out m, out v);
                    variance[i, i] += v;
                    mean[i] -= m;
                }
                PositiveDefiniteMatrix precision = variance.Inverse();
                Vector meanTimesPrecision = precision * mean;
                for (int i = 0; i < length; i++)
                {
                    if (array[i].IsUniform())
                    {
                        // mean[i] and variance[i, i] were not modified above for this index
                        result[i] = Gaussian.FromMeanAndVariance(mean[i], variance[i, i]);
                    }
                    else
                    {
                        double alpha = meanTimesPrecision[i];
                        double beta = precision[i, i];
                        result[i] = GaussianOp.GaussianFromAlphaBeta(array[i], alpha, beta, false);
                    }
                }
            }
            else
            {
                // Compute inv(V1+V2)*(m1-m2) as inv(V2)*inv(inv(V1) + inv(V2))*(inv(V1)*m1 + inv(V2)*m2) - inv(V2)*m2 = inv(V2)*(m - m2)
                // Compute inv(V1+V2) as inv(V2)*inv(inv(V1) + inv(V2))*inv(V2) - inv(V2)
                PositiveDefiniteMatrix precision = (PositiveDefiniteMatrix)vector.Precision.Clone();
                Vector meanTimesPrecision = vector.MeanTimesPrecision.Clone();
                // Work on clones so that vector itself is not modified by the in-place steps below.
                for (int i = 0; i < length; i++)
                {
                    Gaussian g = array[i];
                    precision[i, i] += g.Precision;
                    meanTimesPrecision[i] += g.MeanTimesPrecision;
                }

                // Cholesky-based solve. An equivalent but slower formulation is:
                //   variance = precision.Inverse(); mean = variance * meanTimesPrecision;
                //   result[i] = Gaussian.FromMeanAndVariance(mean[i], variance[i, i]) / array[i];
                bool isPosDef;
                // this destroys precision
                LowerTriangularMatrix precisionChol = precision.CholeskyInPlace(out isPosDef);
                if (!isPosDef)
                {
                    throw new PositiveDefiniteMatrixException();
                }
                // variance = inv(precisionChol*precisionChol') = inv(precisionChol)'*inv(precisionChol) = varianceChol*varianceChol'
                // this destroys meanTimesPrecision
                var mean = meanTimesPrecision.PredivideBy(precisionChol);
                mean = mean.PredivideByTranspose(precisionChol);
                // Same object as precisionChol: the inverse is written over the factor.
                var varianceCholTranspose = precisionChol;
                // this destroys precisionChol
                varianceCholTranspose.SetToInverse(precisionChol);
                for (int i = 0; i < length; i++)
                {
                    Gaussian g = array[i];
                    double variance_ii = GetSquaredLengthOfColumn(varianceCholTranspose, i);
                    // works when g is uniform, but not when g is point mass
                    result[i] = Gaussian.FromMeanAndVariance(mean[i], variance_ii) / g;
                }
            }
            return result;
        }