예제 #1
0
        /// <summary>
        /// Returns the message-operator description for the method invoked by <paramref name="imie"/>.
        /// </summary>
        /// <param name="context">Transform context whose input attributes are consulted.</param>
        /// <param name="imie">The method invocation whose target method is described.</param>
        /// <returns>
        /// The <c>MessageFcnInfo</c> attached to the method if present; otherwise one derived from the
        /// <c>FactorManager.FactorInfo</c> attached to <paramref name="imie"/>; otherwise a new one built from
        /// the method's parameters and dependency information.
        /// </returns>
        internal static MessageFcnInfo GetMessageFcnInfo(BasicTransformContext context, IMethodInvokeExpression imie)
        {
            var method = (MethodInfo)imie.Method.Method.MethodInfo;

            // 1. An operator description explicitly attached to the method wins.
            var attached = context.InputAttributes.Get<MessageFcnInfo>(method);
            if (attached != null)
            {
                return attached;
            }

            // 2. Otherwise derive it from the factor information attached to this call site.
            var factorInfo = context.InputAttributes.Get<FactorManager.FactorInfo>(imie);
            if (factorInfo != null)
            {
                return factorInfo.GetMessageFcnInfoFromFactor();
            }

            // 3. Otherwise describe the method directly from its signature.
            // NOTE(review): the result is not stored back into InputAttributes (that Set call was commented out),
            // so this path rebuilds the description on every call.
            return new MessageFcnInfo(method, method.GetParameters())
            {
                DependencyInfo = FactorManager.GetDependencyInfo(method),
            };
        }
예제 #2
0
        /// <summary>
        /// Writes the XML documentation for a parameter of given message operator using a given XML writer.
        /// </summary>
        /// <param name="writer">The XML writer.</param>
        /// <param name="factorInfo">The factor the message operator is for.</param>
        /// <param name="messageFunctionInfo">The message operator.</param>
        /// <param name="parameter">The parameter.</param>
        /// <remarks>
        /// Always emits one <c>param</c> element named after <paramref name="parameter"/>. <c>result</c> and
        /// <c>resultIndex</c> get fixed descriptions. Any other parameter is described through the factor edge it maps
        /// to, as one of: an outgoing message, a constant value, an incoming message (plus any requirements on it), or a
        /// buffer. A parameter that maps to no factor edge gets an empty element.
        /// </remarks>
        private static void WriteMessageOperatorParameterDescription(
            XmlWriter writer, FactorManager.FactorInfo factorInfo, MessageFcnInfo messageFunctionInfo, ParameterInfo parameter)
        {
            writer.WriteStartElement("param");
            writer.WriteAttributeString("name", parameter.Name);

            FactorEdge edge;
            Type factorParameterType;
            if (parameter.Name == "result")
            {
                writer.WriteString("Modified to contain the outgoing message.");
            }
            else if (parameter.Name == "resultIndex")
            {
                writer.WriteString("Index of the ");
                writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                writer.WriteString(" for which a message is desired.");
            }
            else if (messageFunctionInfo.factorEdgeOfParameter.TryGetValue(parameter.Name, out edge))
            {
                string field = edge.ToString();
                if (edge.IsOutgoingMessage)
                {
                    bool isFresh = parameter.IsDefined(typeof(FreshAttribute), false);
                    writer.WriteFormatString("{0} message to ", isFresh ? "Outgoing" : "Previous outgoing");
                    // NOTE(review): drops a fixed 3-character prefix from the parameter name to name the edge;
                    // presumably a prefix such as "to_" — confirm against the operator naming convention.
                    writer.WriteElementFormatString("c", parameter.Name.Substring(3));
                    writer.WriteString(".");
                }
                else if (!factorInfo.ParameterTypes.TryGetValue(field, out factorParameterType))
                {
                    // The edge is not one of the factor's arguments.
                    writer.WriteString("Buffer ");
                    writer.WriteElementFormatString("c", parameter.Name);
                    writer.WriteString(".");
                }
                else if (parameter.ParameterType == factorParameterType)
                {
                    // Same type as the factor argument itself: the operator receives the value, not a message.
                    writer.WriteString("Constant value for ");
                    writer.WriteElementFormatString("c", field);
                    writer.WriteString(".");
                }
                else
                {
                    writer.WriteString("Incoming message from ");
                    writer.WriteElementFormatString("c", field);
                    writer.WriteString(".");

                    // One sentence per requirement declared on this argument.
                    foreach (FactorEdge requirement in messageFunctionInfo.Requirements)
                    {
                        if (requirement.ParameterName != field)
                        {
                            continue;
                        }

                        writer.WriteString(" Must be a proper distribution. If ");
                        if (Util.IsIList(factorParameterType))
                        {
                            // For list-valued arguments, intersect the requirement with every dependency on the same
                            // argument before describing which elements it concerns.
                            FactorEdge range = requirement;
                            foreach (FactorEdge dependency in messageFunctionInfo.Dependencies)
                            {
                                if (dependency.ParameterName == field)
                                {
                                    range = range.Intersect(dependency);
                                }
                            }

                            if (range.MinCount == 1)
                            {
                                writer.WriteString(range.ContainsIndex ? "the element at resultIndex is " : "all elements are ");
                            }
                            else if (range.ContainsAllOthers)
                            {
                                writer.WriteString(range.ContainsIndex ? "any element is " : "any element besides resultIndex is ");
                            }
                        }

                        writer.WriteString("uniform, the result will be uniform.");
                    }
                }
            }
            // else: TODO: skip methods which aren't message operators (such a parameter currently gets an empty element)

            writer.WriteEndElement();
        }
예제 #3
0
        /// <summary>
        /// Writes the remarks section of the XML documentation for a given message operator using a given XML writer.
        /// </summary>
        /// <param name="writer">The XML writer.</param>
        /// <param name="factorInfo">The factor the message operator is for.</param>
        /// <param name="messageFunctionInfo">The message operator.</param>
        /// <remarks>
        /// Emits a <c>remarks</c> element containing a single <c>para</c>. The paragraph is chosen first by the operator
        /// method name (<c>AverageLogFactor</c>, <c>LogAverageFactor</c>, <c>LogEvidenceRatio</c>), then by the operator
        /// suffix (<c>Conditional</c>, <c>AverageConditional</c>, <c>AverageLogarithm</c>). Any other operator gets an
        /// empty paragraph.
        /// </remarks>
        private static void WriteMessageOperatorRemarks(
            XmlWriter writer, FactorManager.FactorInfo factorInfo, MessageFcnInfo messageFunctionInfo)
        {
            // A factor method returning void is treated as a constraint, which has no child argument.
            bool isConstraint = factorInfo.Method.ReturnType == typeof(void);

            // must use the parameter names of the factor, not the operator method, because the operator method may not have parameters for all of the factor edges
            string argsString = StringUtil.CollectionToString(factorInfo.ParameterNames, ",");
            string childString = isConstraint ? string.Empty : factorInfo.ParameterNames[0];
            string targetParameter = messageFunctionInfo.TargetParameter;

            // Factor arguments the operator receives as messages (operator parameter type differs from the factor's
            // parameter type), in operator-parameter order...
            var randomFields = new System.Collections.Generic.List<string>();
            // ...the same, excluding the operator's target argument...
            var randomFieldsExceptTarget = new System.Collections.Generic.List<string>();
            // ...and additionally excluding the factor's first parameter.
            var randomFieldsExceptTargetAndChild = new System.Collections.Generic.List<string>();
            bool childIsRandom = false;

            foreach (ParameterInfo parameter in messageFunctionInfo.Method.GetParameters())
            {
                if (parameter.Name == "result" || parameter.Name == "resultIndex")
                {
                    continue;
                }

                FactorEdge edge;
                if (!messageFunctionInfo.factorEdgeOfParameter.TryGetValue(parameter.Name, out edge) || edge.IsOutgoingMessage)
                {
                    continue;
                }

                string field = edge.ToString();
                Type factorParameterType;
                if (!factorInfo.ParameterTypes.TryGetValue(field, out factorParameterType) || parameter.ParameterType == factorParameterType)
                {
                    // Not one of the factor's arguments, or passed as a constant value rather than a message.
                    continue;
                }

                if (!isConstraint && field == childString)
                {
                    childIsRandom = true;
                }

                randomFields.Add(field);
                if (field != targetParameter)
                {
                    randomFieldsExceptTarget.Add(field);
                    if (factorInfo.ParameterNames[0] != field)
                    {
                        randomFieldsExceptTargetAndChild.Add(field);
                    }
                }
            }

            // Join once instead of concatenating inside the loop.
            string randomFieldsString = string.Join(",", randomFields);
            string randomFieldsMinusTarget = string.Join(",", randomFieldsExceptTarget);
            string randomFieldsMinusTargetAndChild = string.Join(",", randomFieldsExceptTargetAndChild);

            writer.WriteStartElement("remarks");
            writer.WriteStartElement("para");

            switch (messageFunctionInfo.Method.Name)
            {
                case "AverageLogFactor":
                    if (randomFields.Count == 0)
                    {
                        writer.WriteString("The formula for the result is ");
                        writer.WriteElementFormatString("c", "log(factor({0}))", argsString);
                    }
                    else if (factorInfo.IsDeterministicFactor)
                    {
                        writer.WriteString("In Variational Message Passing, the evidence contribution of a deterministic factor is zero");
                    }
                    else
                    {
                        writer.WriteString("The formula for the result is ");
                        writer.WriteElementFormatString("c", "sum_({0}) p({0}) log(factor({1}))", randomFieldsString, argsString);
                    }

                    writer.WriteString(". Adding up these values across all factors and variables gives the log-evidence estimate for VMP.");
                    break;

                case "LogAverageFactor":
                    writer.WriteString("The formula for the result is ");
                    if (randomFields.Count == 0)
                    {
                        writer.WriteElementFormatString("c", "log(factor({0}))", argsString);
                    }
                    else
                    {
                        writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}))", randomFieldsString, argsString);
                    }

                    writer.WriteString(".");
                    break;

                case "LogEvidenceRatio":
                    writer.WriteString("The formula for the result is ");
                    if (randomFields.Count == 0)
                    {
                        writer.WriteElementFormatString("c", "log(factor({0}))", argsString);
                    }
                    else if (childIsRandom)
                    {
                        writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}) / sum_{2} p({2}) messageTo({2}))", randomFieldsString, argsString, childString);
                    }
                    else
                    {
                        writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}))", randomFieldsString, argsString);
                    }

                    writer.WriteString(". Adding up these values across all factors and variables gives the log-evidence estimate for EP.");
                    break;

                default:
                    if (messageFunctionInfo.Suffix == "Conditional")
                    {
                        WriteFactorAsFunctionOfTarget(writer, targetParameter);
                    }
                    else if (messageFunctionInfo.Suffix == "AverageConditional")
                    {
                        if (randomFieldsExceptTarget.Count == 0)
                        {
                            WriteFactorAsFunctionOfTarget(writer, targetParameter);
                        }
                        else
                        {
                            writer.WriteString("The outgoing message is a distribution matching the moments of ");
                            writer.WriteElementFormatString("c", targetParameter);
                            writer.WriteString(" as the random arguments are varied. The formula is ");
                            writer.WriteElementFormatString("c", "proj[p({2}) sum_({1}) p({1}) factor({0})]/p({2})", argsString, randomFieldsMinusTarget, targetParameter);
                            writer.WriteString(".");
                        }
                    }
                    else if (messageFunctionInfo.Suffix == "AverageLogarithm")
                    {
                        if (randomFieldsExceptTarget.Count == 0)
                        {
                            WriteFactorAsFunctionOfTarget(writer, targetParameter);
                        }
                        else
                        {
                            bool isDeterministic = factorInfo.IsDeterministicFactor;
                            if (isDeterministic && childString == targetParameter)
                            {
                                writer.WriteString("The outgoing message is a distribution matching the moments of ");
                                writer.WriteElementFormatString("c", targetParameter);
                                writer.WriteString(" as the random arguments are varied. The formula is ");
                                writer.WriteElementFormatString("c", "proj[sum_({1}) p({1}) factor({0})]", argsString, randomFieldsMinusTarget);
                            }
                            else if (isDeterministic && childIsRandom && randomFieldsExceptTargetAndChild.Count == 0)
                            {
                                writer.WriteString("The outgoing message is the factor viewed as a function of ");
                                writer.WriteElementFormatString("c", targetParameter);
                                writer.WriteString(" with ");
                                writer.WriteElementFormatString("c", childString);
                                writer.WriteString(" integrated out. The formula is ");
                                writer.WriteElementFormatString("c", "sum_{1} p({1}) factor({0})", argsString, childString);
                            }
                            else
                            {
                                writer.WriteString("The outgoing message is the exponential of the average log-factor value, where the average is over all arguments except ");
                                writer.WriteElementFormatString("c", targetParameter);
                                writer.WriteString(". ");

                                if (isDeterministic && childIsRandom)
                                {
                                    writer.WriteString("Because the factor is deterministic, ");
                                    writer.WriteElementFormatString("c", childString);
                                    writer.WriteString(" is integrated out before taking the logarithm. The formula is ");
                                    writer.WriteElementFormatString(
                                        "c", "exp(sum_({1}) p({1}) log(sum_{2} p({2}) factor({0})))", argsString, randomFieldsMinusTargetAndChild, childString);
                                }
                                else
                                {
                                    writer.WriteString("The formula is ");
                                    writer.WriteElementFormatString("c", "exp(sum_({1}) p({1}) log(factor({0})))", argsString, randomFieldsMinusTarget);
                                }
                            }

                            writer.WriteString(".");
                        }
                    }

                    break;
            }

            writer.WriteEndElement();
            writer.WriteEndElement();
        }

        /// <summary>
        /// Writes the sentence describing an outgoing message that is the factor viewed as a function of
        /// <paramref name="targetParameter"/>, conditioned on the given values.
        /// </summary>
        /// <param name="writer">The XML writer.</param>
        /// <param name="targetParameter">Name of the operator's target argument.</param>
        private static void WriteFactorAsFunctionOfTarget(XmlWriter writer, string targetParameter)
        {
            writer.WriteString("The outgoing message is the factor viewed as a function of ");
            writer.WriteElementFormatString("c", targetParameter);
            writer.WriteString(" conditioned on the given values.");
        }
예제 #4
0
        /// <summary>
        /// Determines the message types of a factor's targets in the given direction from the message types of the factor arguments.
        /// </summary>
        /// <param name="factor">The factor expression; its node info is obtained with <c>GetNodeInfo</c>.</param>
        /// <param name="direction">
        /// <c>Forwards</c>: for each stochastic child (return/out) argument, find an initialisation operator, declare the
        /// forward message variable if needed, and record its initialisation expression in <c>messageInitExprs</c>.
        /// <c>Backwards</c>: for each stochastic non-child argument, call <c>CreateBackwardMessageFromForward</c>.
        /// </param>
        /// <remarks>
        /// Returns early when an incoming message expression for a stochastic argument is unavailable. When no suitable
        /// forward operator or message type can be found, the problem is reported through <c>Error</c> and that target is skipped.
        /// </remarks>
        protected void ProcessFactor(IExpression factor, MessageDirection direction)
        {
            NodeInfo info = GetNodeInfo(factor);

            // For each factor parameter: the expression passed to operators (a constant's value, or an incoming
            // message) and its type.
            Dictionary<string, Type> argumentTypes = new Dictionary<string, Type>();
            Dictionary<string, IExpression> arguments = new Dictionary<string, IExpression>();

            for (int i = 0; i < info.info.ParameterNames.Count; i++)
            {
                string parameterName = info.info.ParameterNames[i];
                bool isChild = info.isReturnOrOut[i];
                IExpression arg = info.arguments[i];
                bool isConstant = !CodeRecognizer.IsStochastic(context, arg);
                if (isConstant)
                {
                    arguments[parameterName] = arg;
                    argumentTypes[parameterName] = arg.GetExpressionType();
                }
                else if (!isChild)
                {
                    // Non-child stochastic arguments contribute their forward message.
                    IExpression msgExpr = GetMessageExpression(arg, fwdMessageVars);
                    if (msgExpr == null)
                    {
                        return;
                    }
                    arguments[parameterName] = msgExpr;
                    Type inwardType = msgExpr.GetExpressionType();
                    if (inwardType == null)
                    {
                        Error("inferred an incorrect message type for " + arg);
                        return;
                    }
                    argumentTypes[parameterName] = inwardType;
                }
                else if (direction == MessageDirection.Backwards)
                {
                    // Child arguments contribute their backward message, created from the forward one if missing.
                    IExpression msgExpr = GetMessageExpression(arg, bckMessageVars);
                    if (msgExpr == null)
                    {
                        CreateBackwardMessageFromForward(arg, null);
                        msgExpr = GetMessageExpression(arg, bckMessageVars);
                        if (msgExpr == null)
                        {
                            return;
                        }
                    }
                    arguments[parameterName] = msgExpr;
                    Type inwardType = msgExpr.GetExpressionType();
                    if (inwardType == null)
                    {
                        Error("inferred an incorrect message type for " + arg);
                        return;
                    }
                    argumentTypes[parameterName] = inwardType;
                }
            }

            // A per-factor Algorithm attribute overrides the transform's default algorithm.
            IAlgorithm alg = algorithm;
            Algorithm algAttr = context.InputAttributes.Get<Algorithm>(info.imie);
            if (algAttr != null)
            {
                alg = algAttr.algorithm;
            }
            List<ICompilerAttribute> factorAttributes = context.InputAttributes.GetAll<ICompilerAttribute>(info.imie);
            string methodSuffix = alg.GetOperatorMethodSuffix(factorAttributes);

            // infer types of children (Forwards) or of the other arguments (Backwards)
            for (int i = 0; i < info.info.ParameterNames.Count; i++)
            {
                string parameterName = info.info.ParameterNames[i];
                bool isChild = info.isReturnOrOut[i];
                if (isChild != (direction == MessageDirection.Forwards))
                {
                    continue;
                }
                IExpression target = info.arguments[i];
                bool isConstant = !CodeRecognizer.IsStochastic(context, target);
                if (isConstant)
                {
                    continue;
                }
                IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);
                if (ivd == null)
                {
                    continue;
                }

                // fcninfo is only ever assigned in the Forwards branch, so a non-null fcninfo below implies Forwards.
                Type targetType = null;
                MessageFcnInfo fcninfo = null;
                if (direction == MessageDirection.Forwards)
                {
                    // Candidate initialisation operators, in order: "Init", "<suffix>Init", then the plain "<suffix>"
                    // operator (usable only if it takes neither a result nor a resultIndex parameter).
                    try
                    {
                        fcninfo = GetMessageFcnInfo(info.info, "Init", parameterName, argumentTypes);
                    }
                    catch (Exception)
                    {
                        // Deliberate fallback to the next candidate.
                        try
                        {
                            fcninfo = GetMessageFcnInfo(info.info, methodSuffix + "Init", parameterName, argumentTypes);
                        }
                        catch (Exception ex)
                        {
                            try
                            {
                                fcninfo = GetMessageFcnInfo(info.info, methodSuffix, parameterName, argumentTypes);
                                if (fcninfo.PassResult)
                                {
                                    throw new MissingMethodException(StringUtil.MethodFullNameToString(fcninfo.Method) +
                                                                     " is not suitable for initialization since it takes a result parameter.  Please provide a separate Init method.");
                                }
                                if (fcninfo.PassResultIndex)
                                {
                                    throw new MissingMethodException(StringUtil.MethodFullNameToString(fcninfo.Method) +
                                                                     " is not suitable for initialization since it takes a resultIndex parameter.  Please provide a separate Init method.");
                                }
                            }
                            catch (Exception ex2)
                            {
                                Error("could not determine " + direction + " message type of " + ivd.Name + ": " + ex.Message, ex2);
                                continue;
                            }
                        }
                    }
                    if (fcninfo != null)
                    {
                        targetType = fcninfo.Method.ReturnType;
                        if (targetType.IsGenericParameter)
                        {
                            Error("could not determine " + direction + " message type of " + ivd.Name + " in " + StringUtil.MethodFullNameToString(fcninfo.Method));
                            continue;
                        }
                        // The returned VariableInformation is unused; the call is kept in case it has side effects
                        // (NOTE(review): confirm whether it can be removed).
                        VariableInformation.GetVariableInformation(context, ivd);
                        try
                        {
                            targetType = MessageTransform.GetDistributionType(ivd.VariableType.DotNetType, target.GetExpressionType(), targetType, true);
                        }
                        catch (Exception ex)
                        {
                            Error("could not determine " + direction + " message type of " + ivd.Name, ex);
                            continue;
                        }
                    }
                }

                Dictionary<IVariableDeclaration, IVariableDeclaration> messageVars = (direction == MessageDirection.Forwards) ? fwdMessageVars : bckMessageVars;
                if (fcninfo != null)
                {
                    string name = ivd.Name + (direction == MessageDirection.Forwards ? "_F" : "_B");
                    IVariableDeclaration msgVar;
                    if (!messageVars.TryGetValue(ivd, out msgVar))
                    {
                        msgVar = Builder.VarDecl(name, targetType);
                    }

                    // construct the init expression: one argument per operator parameter
                    List<IExpression> args = new List<IExpression>();
                    ParameterInfo[] parameters = fcninfo.Method.GetParameters();
                    foreach (ParameterInfo parameter in parameters)
                    {
                        if (IsFactoryType(parameter.ParameterType))
                        {
                            IVariableDeclaration factoryVar = GetFactoryVariable(parameter.ParameterType);
                            args.Add(Builder.VarRefExpr(factoryVar));
                        }
                        else
                        {
                            FactorEdge factorEdge = fcninfo.factorEdgeOfParameter[parameter.Name];
                            string factorParameterName = factorEdge.ParameterName;
                            if (!arguments.ContainsKey(factorParameterName))
                            {
                                if (direction == MessageDirection.Forwards)
                                {
                                    Error(StringUtil.MethodFullNameToString(fcninfo.Method) + " is not suitable for initialization since it requires '" + parameter.Name +
                                          "'.  Please provide a separate Init method.");
                                }
                                fcninfo = null;
                                break;
                            }
                            args.Add(arguments[factorParameterName]);
                        }
                    }
                    if (fcninfo != null)
                    {
                        IMethodInvokeExpression imie = Builder.StaticMethod(fcninfo.Method, args.ToArray());
                        messageInitExprs[new KeyValuePair<IVariableDeclaration, IExpression>(msgVar, factor)] = imie;
                        messageVars[ivd] = msgVar;
                    }
                }
                if (fcninfo == null)
                {
                    if (direction == MessageDirection.Forwards)
                    {
                        continue;
                    }
                    CreateBackwardMessageFromForward(target, factor);
                }

                IExpression msgExpr = GetMessageExpression(target, messageVars);
                if (msgExpr == null)
                {
                    // No message is available for this target (the first loop guards the same call); skip it instead of
                    // dereferencing null below.
                    continue;
                }
                arguments[parameterName] = msgExpr;
                Type inwardType = msgExpr.GetExpressionType();
                argumentTypes[parameterName] = inwardType;
            }
        }