예제 #1
0
파일: Weight.cs 프로젝트: tauheedul/infer
        /// <summary>
        /// Computes the sum of a geometric series <c>1 + w + w^2 + w^3 + ...</c>,
        /// where <c>w</c> is a given weight.
        /// </summary>
        /// <param name="weight">The weight.</param>
        /// <returns>
        /// The computed sum <c>1 / (1 - w)</c>, or <see cref="Infinity"/> if the sum diverges
        /// (<c>w &gt;= 1</c>) or <c>w</c> is so close to 1 (<c>log w &gt;= -1e-20</c>) that the
        /// sum would be at least about <c>1e20</c>.
        /// </returns>
        public static Weight Closure(Weight weight)
        {
            // Log-space margin below 0 (i.e. below w = 1) inside which the series is treated as divergent
            const double Eps = 1e-20;

            // Read the log-value once through the public property so the branch test and the
            // computation are guaranteed to use the same quantity.
            double logWeight = weight.LogValue;

            if (logWeight < -Eps)
            {
                // The series converges to 1 / (1 - w); in log space that is -log(1 - exp(log w)).
                return(new Weight(-MMath.Log1MinusExp(logWeight)));
            }

            // The series diverges (or is numerically indistinguishable from divergent)
            return(Weight.Infinity);
        }
예제 #2
0
파일: Weight.cs 프로젝트: tauheedul/infer
        /// <summary>
        /// Computes the sum of a geometric series <c>1 + w + w^2 + w^3 + ...</c>,
        /// where <c>w</c> is a given weight.
        /// If the sum diverges, replaces the infinite sum by a finite sum with a lot of terms.
        /// </summary>
        /// <param name="weight">The weight.</param>
        /// <returns>
        /// <c>1 / (1 - w)</c> when the series converges; otherwise the truncated sum
        /// <c>1 + w + ... + w^(N-1)</c> with <c>N = 10000</c> terms, which is <c>N</c> when <c>w</c> is
        /// (numerically) 1 and infinite when <c>w</c> itself is infinite.
        /// </returns>
        public static Weight ApproximateClosure(Weight weight)
        {
            // Log-space margin around 0 (i.e. around w = 1)
            const double Eps       = 1e-20;

            // Number of terms used when the infinite series diverges
            const double TermCount = 10000;

            double logWeight = weight.LogValue;

            if (logWeight < -Eps)
            {
                // The series converges to 1 / (1 - w)
                return(new Weight(-MMath.Log1MinusExp(logWeight)));
            }

            if (logWeight < Eps)
            {
                // w is (numerically) 1: the geometric progression formula degenerates to 0 / 0,
                // but every term equals 1, so the truncated sum is just the number of terms.
                return(new Weight(Math.Log(TermCount)));
            }

            if (double.IsPositiveInfinity(logWeight))
            {
                // An infinite weight makes the truncated sum infinite. The formula below would
                // otherwise presumably evaluate Infinity - Infinity = NaN.
                return(Weight.Infinity);
            }

            // Truncated geometric progression (w^N - 1) / (w - 1), evaluated in log space
            return(new Weight(MMath.LogExpMinus1(logWeight * TermCount) - MMath.LogExpMinus1(logWeight)));
        }
예제 #3
0
        /// <summary>EP message to <c>value</c>.</summary>
        /// <param name="enterOne">Incoming message from <c>enterOne</c>. Must be a proper distribution. If uniform, the result will be uniform.</param>
        /// <param name="selector">Incoming message from <c>selector</c>.</param>
        /// <param name="value">Incoming message from <c>value</c>.</param>
        /// <param name="index">Constant value for <c>index</c>.</param>
        /// <param name="result">Modified to contain the outgoing message.</param>
        /// <returns>
        ///   <paramref name="result" />
        /// </returns>
        /// <remarks>
        ///   <para>The outgoing message is a distribution matching the moments of <c>value</c> as the random arguments are varied. The formula is <c>proj[p(value) sum_(enterOne,selector) p(enterOne,selector) factor(enterOne,selector,value,index)]/p(value)</c>.</para>
        ///   <para>The result is a weighted sum of <paramref name="enterOne" /> (weight <c>p(selector = index)</c> divided by the average of
        ///   <paramref name="enterOne" /> under <paramref name="value" />) and a uniform distribution (weight <c>1 - p(selector = index)</c>
        ///   times the uniform's normalizer). If <paramref name="enterOne" /> and <paramref name="value" /> do not overlap
        ///   (log-average of negative infinity), the <paramref name="enterOne" /> component is dropped and the result is uniform.</para>
        /// </remarks>
        /// <exception cref="ImproperMessageException">
        ///   <paramref name="enterOne" /> is not a proper distribution.</exception>
        /// <typeparam name="TDist">The type of the distribution over the variable entering the gate.</typeparam>
        public static TDist ValueAverageConditional <TDist>([SkipIfAllUniform] TDist enterOne, Discrete selector, TDist value, int index, TDist result)
            where TDist : ICloneable, SettableToUniform, SettableToWeightedSum <TDist>, SettableTo <TDist>, CanGetLogAverageOf <TDist>, CanGetLogNormalizer
        {
            if (selector == null)
            {
                throw new ArgumentNullException(nameof(selector));
            }

            double logProbSum = selector.GetLogProb(index);

            if (logProbSum == 0.0)
            {
                // selector is certain to equal index
                result.SetTo(enterOne);
                return(result);
            }

            if (double.IsNegativeInfinity(logProbSum))
            {
                // selector can never equal index
                result.SetToUniform();
                return(result);
            }

            // Subtract the log-average to avoid double-counting since selector already contains this quantity.
            // See IntCasesOp.IAverageConditional
            double logAverage = enterOne.GetLogAverageOf(value);
            if (double.IsNegativeInfinity(logAverage))
            {
                // No overlap: enterOne would get an infinite weight, so drop it and keep only the uniform part
                result.SetToUniform();
                return(result);
            }

            double logProb = MMath.Log1MinusExp(logProbSum);
            TDist uniform = (TDist)result.Clone();
            uniform.SetToUniform();

            double logWeightEnter   = logProbSum - logAverage;
            double logWeightUniform = logProb + uniform.GetLogNormalizer();

            // Shift by the largest complete log-weight so neither Math.Exp overflows.
            // As before, SetToSum is relied on to depend only on the ratio of the two weights.
            double shift = Math.Max(logWeightEnter, logWeightUniform);

            // Avoid (-Infinity) - (-Infinity)
            if (double.IsNegativeInfinity(shift))
            {
                throw new AllZeroException();
            }

            result.SetToSum(Math.Exp(logWeightEnter - shift), enterOne, Math.Exp(logWeightUniform - shift), uniform);
            return(result);
        }
예제 #4
0
        /// <summary>EP message to <c>value</c>.</summary>
        /// <param name="enterPartial">Incoming message from <c>enterPartial</c>. Must be a proper distribution. If any element is uniform, the result will be uniform.</param>
        /// <param name="selector">Incoming message from <c>selector</c>. Must be a proper distribution. If uniform, the result will be uniform.</param>
        /// <param name="value">Incoming message from <c>value</c>.</param>
        /// <param name="indices">Constant value for <c>indices</c>.</param>
        /// <param name="result">Modified to contain the outgoing message.</param>
        /// <returns>
        ///   <paramref name="result" />
        /// </returns>
        /// <remarks>
        ///   <para>The outgoing message is a distribution matching the moments of <c>value</c> as the random arguments are varied. The formula is <c>proj[p(value) sum_(enterPartial,selector) p(enterPartial,selector) factor(enterPartial,selector,value,indices)]/p(value)</c>.</para>
        ///   <para>Case index 0 corresponds to <c>selector == true</c>; any other index to <c>selector == false</c>. Each listed case
        ///   contributes its <paramref name="enterPartial" /> element weighted by its selector probability divided by the element's
        ///   average under <paramref name="value" />. Cases with zero probability or zero average are skipped. When only one case
        ///   is listed, the other case is represented by a uniform distribution.</para>
        /// </remarks>
        /// <exception cref="ImproperMessageException">
        ///   <paramref name="enterPartial" /> is not a proper distribution.</exception>
        /// <exception cref="ImproperMessageException">
        ///   <paramref name="selector" /> is not a proper distribution.</exception>
        /// <typeparam name="TDist">The type of the distribution over the variable entering the gate.</typeparam>
        public static TDist ValueAverageConditional <TDist>(
            [SkipIfUniform] IList <TDist> enterPartial, [SkipIfUniform] Bernoulli selector, TDist value, int[] indices, TDist result)
            where TDist : ICloneable, SettableToUniform, SettableTo <TDist>, SettableToWeightedSum <TDist>, CanGetLogAverageOf <TDist>, CanGetLogNormalizer
        {
            if (enterPartial == null)
            {
                throw new ArgumentNullException(nameof(enterPartial));
            }

            if (indices == null)
            {
                throw new ArgumentNullException(nameof(indices));
            }

            if (indices.Length != enterPartial.Count)
            {
                throw new ArgumentException("indices.Length != enterPartial.Count");
            }

            if (2 < enterPartial.Count)
            {
                throw new ArgumentException("enterPartial.Count should be 2 or 1");
            }

            if (indices.Length == 0)
            {
                throw new ArgumentException("indices.Length == 0");
            }

            // TODO: use pre-allocated buffers
            double logProbSum = (indices[0] == 0) ? selector.GetLogProbTrue() : selector.GetLogProbFalse();

            // Log of the total weight accumulated into result; -Infinity means result has not been assigned yet,
            // so a skipped case can never leave the caller's stale result with a non-zero weight.
            double logWeightSum = double.NegativeInfinity;
            if (!double.IsNegativeInfinity(logProbSum))
            {
                // Subtract to avoid double-counting since selector already contains this quantity.
                // A case with no overlap (log-average of -Infinity) is skipped rather than given an infinite weight.
                double logAverage = enterPartial[0].GetLogAverageOf(value);
                if (!double.IsNegativeInfinity(logAverage))
                {
                    logWeightSum = logProbSum - logAverage;
                    result.SetTo(enterPartial[0]);
                }
            }

            for (int i = 1; i < indices.Length; i++)
            {
                double logProb = (indices[i] == 0) ? selector.GetLogProbTrue() : selector.GetLogProbFalse();

                // Accumulate probability mass in log space (a sum, not a product, of probabilities)
                logProbSum = MMath.LogSumExp(logProbSum, logProb);

                if (double.IsNegativeInfinity(logProb))
                {
                    // Zero-probability case: contributes nothing. If nothing was accumulated either, all weights are zero.
                    if (double.IsNegativeInfinity(logWeightSum) && i == 1)
                    {
                        throw new AllZeroException();
                    }

                    continue;
                }

                double logAverage = enterPartial[i].GetLogAverageOf(value);
                if (double.IsNegativeInfinity(logAverage))
                {
                    // No overlap with value: skip, as for the first case
                    continue;
                }

                double logWeight = logProb - logAverage;

                // Shift by the largest complete log-weight so neither Math.Exp overflows;
                // SetToSum is relied on to depend only on the ratio of the two weights.
                double shift = Math.Max(logWeightSum, logWeight);
                result.SetToSum(Math.Exp(logWeightSum - shift), result, Math.Exp(logWeight - shift), enterPartial[i]);
                logWeightSum = MMath.LogSumExp(logWeightSum, logWeight);
            }

            if (indices.Length < 2)
            {
                // The unlisted case is represented by a uniform distribution carrying the remaining probability
                double logProb = MMath.Log1MinusExp(logProbSum);

                // Avoid (-Infinity) - (-Infinity)
                if (double.IsNegativeInfinity(logWeightSum) && double.IsNegativeInfinity(logProb))
                {
                    throw new AllZeroException();
                }

                var uniform = (TDist)result.Clone();
                uniform.SetToUniform();
                double logWeight = logProb + uniform.GetLogNormalizer();
                double shift     = Math.Max(logWeightSum, logWeight);
                result.SetToSum(Math.Exp(logWeightSum - shift), result, Math.Exp(logWeight - shift), uniform);
            }
            else if (double.IsNegativeInfinity(logWeightSum))
            {
                // Every listed case was skipped and there is no remainder, so result was never assigned
                throw new AllZeroException();
            }

            return(result);
        }
예제 #5
0
        /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BeliefPropagationGateEnterPartialOp"]/message_doc[@name="ValueAverageConditional{TDist}(IList{TDist}, Discrete, TDist, int[], TDist)"]/*'/>
        /// <typeparam name="TDist">The type of the distribution over the variable entering the gate.</typeparam>
        public static TDist ValueAverageConditional <TDist>(
            [SkipIfUniform] IList <TDist> enterPartial, [SkipIfUniform] Discrete selector, TDist value, int[] indices, TDist result)
            where TDist : ICloneable, SettableToUniform, SettableTo <TDist>, SettableToWeightedSum <TDist>, CanGetLogAverageOf <TDist>, CanGetLogNormalizer
        {
            if (enterPartial == null)
            {
                throw new ArgumentNullException(nameof(enterPartial));
            }

            if (selector == null)
            {
                throw new ArgumentNullException(nameof(selector));
            }

            if (indices == null)
            {
                throw new ArgumentNullException(nameof(indices));
            }

            if (indices.Length != enterPartial.Count)
            {
                throw new ArgumentException("indices.Length != enterPartial.Count");
            }

            if (selector.Dimension < enterPartial.Count)
            {
                throw new ArgumentException("selector.Dimension < enterPartial.Count");
            }

            if (indices.Length == 0)
            {
                throw new ArgumentException("indices.Length == 0");
            }

            // The result is a weighted sum over the listed cases k of enterPartial[k], with weight
            // p(selector = indices[k]) / average of enterPartial[k] under value, plus (if some selector values
            // are not listed) a uniform distribution carrying their remaining probability.
            // Cases with zero probability or zero average are skipped.

            // TODO: use pre-allocated buffers
            double logProbSum = selector.GetLogProb(indices[0]);

            // Log of the total weight accumulated into result; -Infinity means result has not been assigned yet,
            // so a skipped case can never leave the caller's stale result with a non-zero weight.
            double logWeightSum = double.NegativeInfinity;
            if (!double.IsNegativeInfinity(logProbSum))
            {
                // Subtract to avoid double-counting since selector already contains this quantity
                // See IntCasesOp.IAverageConditional
                double logAverage = enterPartial[0].GetLogAverageOf(value);
                if (!double.IsNegativeInfinity(logAverage))
                {
                    logWeightSum = logProbSum - logAverage;
                    result.SetTo(enterPartial[0]);
                }
            }

            for (int i = 1; i < indices.Length; i++)
            {
                double logProb = selector.GetLogProb(indices[i]);
                logProbSum = MMath.LogSumExp(logProbSum, logProb);

                if (double.IsNegativeInfinity(logProb))
                {
                    // Zero-probability case: contributes nothing. If this is the last selector value and
                    // nothing was accumulated, every weight is zero.
                    if (double.IsNegativeInfinity(logWeightSum) && i == selector.Dimension - 1)
                    {
                        throw new AllZeroException();
                    }

                    continue;
                }

                double logAverage = enterPartial[i].GetLogAverageOf(value);
                if (double.IsNegativeInfinity(logAverage))
                {
                    // No overlap with value: skip, as for the first case
                    continue;
                }

                double logWeight = logProb - logAverage;

                // Shift by the largest complete log-weight so neither Math.Exp overflows;
                // SetToSum is relied on to depend only on the ratio of the two weights.
                double shift = Math.Max(logWeightSum, logWeight);
                result.SetToSum(Math.Exp(logWeightSum - shift), result, Math.Exp(logWeight - shift), enterPartial[i]);
                logWeightSum = MMath.LogSumExp(logWeightSum, logWeight);
            }

            if (indices.Length < selector.Dimension)
            {
                // Unlisted selector values are represented by a uniform distribution carrying probability
                // 1 - sum of the listed probabilities. Rounding in LogSumExp can push logProbSum slightly
                // above 0, where log(1 - exp(x)) is undefined, so clamp it first.
                double logProb = MMath.Log1MinusExp(Math.Min(logProbSum, 0.0));

                // Avoid (-Infinity) - (-Infinity)
                if (double.IsNegativeInfinity(logWeightSum) && double.IsNegativeInfinity(logProb))
                {
                    throw new AllZeroException();
                }

                var uniform = (TDist)result.Clone();
                uniform.SetToUniform();
                double logWeight = logProb + uniform.GetLogNormalizer();
                double shift     = Math.Max(logWeightSum, logWeight);
                result.SetToSum(Math.Exp(logWeightSum - shift), result, Math.Exp(logWeight - shift), uniform);
            }
            else if (double.IsNegativeInfinity(logWeightSum))
            {
                // Every listed case was skipped and there is no remainder, so result was never assigned
                throw new AllZeroException();
            }

            return(result);
        }