/// <summary>
/// Gets the message function info for the method called by a method invoke expression.
/// </summary>
/// <param name="context">The transform context whose input attributes are consulted.</param>
/// <param name="imie">The method invoke expression.</param>
/// <returns>
/// The <see cref="MessageFcnInfo"/> attached to the invoked method if there is one; otherwise the info
/// obtained from the <see cref="FactorManager.FactorInfo"/> attached to <paramref name="imie"/>; otherwise
/// a newly constructed info built from the method's parameters and dependency info.
/// </returns>
internal static MessageFcnInfo GetMessageFcnInfo(BasicTransformContext context, IMethodInvokeExpression imie)
{
    MethodInfo method = (MethodInfo)imie.Method.Method.MethodInfo;

    // First choice: an info object already attached to the method itself.
    MessageFcnInfo attachedInfo = context.InputAttributes.Get<MessageFcnInfo>(method);
    if (attachedInfo != null)
    {
        return attachedInfo;
    }

    // Second choice: derive it from the factor info attached to the invocation.
    FactorManager.FactorInfo factorInfo = context.InputAttributes.Get<FactorManager.FactorInfo>(imie);
    if (factorInfo != null)
    {
        return factorInfo.GetMessageFcnInfoFromFactor();
    }

    // Fallback: build the info directly from the method signature.
    // Note: the result is not written back to the input attributes; the caching call is disabled:
    //context.InputAttributes.Set(method, fcnInfo);
    return new MessageFcnInfo(method, method.GetParameters())
    {
        DependencyInfo = FactorManager.GetDependencyInfo(method)
    };
}
/// <summary>
/// Emits the <c>param</c> documentation element describing one parameter of a message operator.
/// </summary>
/// <param name="writer">The XML writer that receives the element.</param>
/// <param name="factorInfo">The factor the message operator is for.</param>
/// <param name="messageFunctionInfo">The message operator.</param>
/// <param name="parameter">The operator parameter being described.</param>
private static void WriteMessageOperatorParameterDescription(
    XmlWriter writer,
    FactorManager.FactorInfo factorInfo,
    MessageFcnInfo messageFunctionInfo,
    ParameterInfo parameter)
{
    writer.WriteStartElement("param");
    writer.WriteAttributeString("name", parameter.Name);

    switch (parameter.Name)
    {
        case "result":
        {
            writer.WriteString("Modified to contain the outgoing message.");
            break;
        }

        case "resultIndex":
        {
            writer.WriteString("Index of the ");
            writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
            writer.WriteString(" for which a message is desired.");
            break;
        }

        default:
        {
            FactorEdge edge;
            if (messageFunctionInfo.factorEdgeOfParameter.TryGetValue(parameter.Name, out edge))
            {
                string field = edge.ToString();

                if (edge.IsOutgoingMessage)
                {
                    string freshness = parameter.IsDefined(typeof(FreshAttribute), false)
                        ? "Outgoing"
                        : "Previous outgoing";
                    writer.WriteFormatString("{0} message to ", freshness);

                    // The first three characters of the parameter name are dropped
                    // (presumably a "to_" prefix - TODO confirm).
                    writer.WriteElementFormatString("c", parameter.Name.Substring(3));
                    writer.WriteString(".");
                }
                else if (!factorInfo.ParameterTypes.ContainsKey(field))
                {
                    // The edge does not name a factor parameter, so it is described as a buffer.
                    writer.WriteString("Buffer ");
                    writer.WriteElementFormatString("c", parameter.Name);
                    writer.WriteString(".");
                }
                else if (parameter.ParameterType == factorInfo.ParameterTypes[field])
                {
                    // Same type as the factor's own parameter: described as a constant value.
                    writer.WriteString("Constant value for ");
                    writer.WriteElementFormatString("c", field);
                    writer.WriteString(".");
                }
                else
                {
                    writer.WriteString("Incoming message from ");
                    writer.WriteElementFormatString("c", field);
                    writer.WriteString(".");

                    foreach (FactorEdge requirement in messageFunctionInfo.Requirements)
                    {
                        if (requirement.ParameterName != field)
                        {
                            continue;
                        }

                        FactorEdge range = requirement;
                        writer.WriteString(" Must be a proper distribution. If ");

                        if (Util.IsIList(factorInfo.ParameterTypes[field]))
                        {
                            // Narrow the required range by every dependency on the same field.
                            foreach (FactorEdge dependency in messageFunctionInfo.Dependencies)
                            {
                                if (dependency.ParameterName == field)
                                {
                                    range = range.Intersect(dependency);
                                }
                            }

                            if (range.MinCount == 1)
                            {
                                writer.WriteString(range.ContainsIndex ? "the element at resultIndex is " : "all elements are ");
                            }
                            else if (range.ContainsAllOthers)
                            {
                                writer.WriteString(range.ContainsIndex ? "any element is " : "any element besides resultIndex is ");
                            }
                        }

                        writer.WriteString("uniform, the result will be uniform.");
                    }
                }
            }

            // TODO: skip methods which aren't message operators
            // (when the parameter has no factor edge, no description text is written)
            break;
        }
    }

    writer.WriteEndElement();
}
/// <summary>
/// Emits the <c>remarks</c> section of the XML documentation for a message operator.
/// </summary>
/// <param name="writer">The XML writer that receives the section.</param>
/// <param name="factorInfo">The factor the message operator is for.</param>
/// <param name="messageFunctionInfo">The message operator.</param>
private static void WriteMessageOperatorRemarks(
    XmlWriter writer,
    FactorManager.FactorInfo factorInfo,
    MessageFcnInfo messageFunctionInfo)
{
    // A factor method returning void is treated as a constraint, i.e. it has no child argument.
    bool isConstraint = factorInfo.Method.ReturnType == typeof(void);

    // must use the parameter names of the factor, not the operator method,
    // because the operator method may not have parameters for all of the factor edges
    string argsString = StringUtil.CollectionToString(factorInfo.ParameterNames, ",");
    string childString = isConstraint ? string.Empty : factorInfo.ParameterNames[0];
    bool childIsRandom = false;

    // Comma-separated lists of the factor arguments whose operator parameter type differs
    // from the factor's own parameter type.
    string randomFields = string.Empty;
    string randomFieldsExceptTarget = string.Empty;
    string randomFieldsExceptTargetAndChild = string.Empty;

    foreach (ParameterInfo parameter in messageFunctionInfo.Method.GetParameters())
    {
        if (parameter.Name == "result" || parameter.Name == "resultIndex")
        {
            continue;
        }

        FactorEdge edge;
        if (!messageFunctionInfo.factorEdgeOfParameter.TryGetValue(parameter.Name, out edge))
        {
            continue;
        }

        string field = edge.ToString();
        bool isRandomArgument =
            !edge.IsOutgoingMessage
            && factorInfo.ParameterTypes.ContainsKey(field)
            && parameter.ParameterType != factorInfo.ParameterTypes[field];
        if (!isRandomArgument)
        {
            continue;
        }

        if (!isConstraint && factorInfo.ParameterNames[0] == field)
        {
            childIsRandom = true;
        }

        if (randomFields.Length > 0)
        {
            randomFields += ",";
        }

        randomFields += field;

        if (field == messageFunctionInfo.TargetParameter)
        {
            continue;
        }

        if (randomFieldsExceptTarget.Length > 0)
        {
            randomFieldsExceptTarget += ",";
        }

        randomFieldsExceptTarget += field;

        if (factorInfo.ParameterNames[0] != field)
        {
            if (randomFieldsExceptTargetAndChild.Length > 0)
            {
                randomFieldsExceptTargetAndChild += ",";
            }

            randomFieldsExceptTargetAndChild += field;
        }
    }

    bool noRandomFields = randomFields.Length == 0;

    writer.WriteStartElement("remarks");
    writer.WriteStartElement("para");

    switch (messageFunctionInfo.Method.Name)
    {
        case "AverageLogFactor":
        {
            if (noRandomFields)
            {
                writer.WriteString("The formula for the result is ");
                writer.WriteElementFormatString("c", "log(factor({0}))", argsString);
            }
            else if (factorInfo.IsDeterministicFactor)
            {
                writer.WriteString("In Variational Message Passing, the evidence contribution of a deterministic factor is zero");
            }
            else
            {
                writer.WriteString("The formula for the result is ");
                writer.WriteElementFormatString("c", "sum_({0}) p({0}) log(factor({1}))", randomFields, argsString);
            }

            writer.WriteString(". Adding up these values across all factors and variables gives the log-evidence estimate for VMP.");
            break;
        }

        case "LogAverageFactor":
        {
            writer.WriteString("The formula for the result is ");
            if (noRandomFields)
            {
                writer.WriteElementFormatString("c", "log(factor({0}))", argsString);
            }
            else
            {
                writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}))", randomFields, argsString);
            }

            writer.WriteString(".");
            break;
        }

        case "LogEvidenceRatio":
        {
            writer.WriteString("The formula for the result is ");
            if (noRandomFields)
            {
                writer.WriteElementFormatString("c", "log(factor({0}))", argsString);
            }
            else if (childIsRandom)
            {
                writer.WriteElementFormatString(
                    "c",
                    "log(sum_({0}) p({0}) factor({1}) / sum_{2} p({2}) messageTo({2}))",
                    randomFields,
                    argsString,
                    childString);
            }
            else
            {
                writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}))", randomFields, argsString);
            }

            writer.WriteString(". Adding up these values across all factors and variables gives the log-evidence estimate for EP.");
            break;
        }

        default:
        {
            // Not an evidence method: the remark is chosen by the operator's suffix.
            string suffix = messageFunctionInfo.Suffix;
            bool targetIsOnlyRandomField = randomFieldsExceptTarget.Length == 0;
            bool isAveraging = suffix == "AverageConditional" || suffix == "AverageLogarithm";

            if (suffix == "Conditional" || (isAveraging && targetIsOnlyRandomField))
            {
                writer.WriteString("The outgoing message is the factor viewed as a function of ");
                writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                writer.WriteString(" conditioned on the given values.");
            }
            else if (suffix == "AverageConditional")
            {
                writer.WriteString("The outgoing message is a distribution matching the moments of ");
                writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                writer.WriteString(" as the random arguments are varied. The formula is ");
                writer.WriteElementFormatString(
                    "c",
                    "proj[p({2}) sum_({1}) p({1}) factor({0})]/p({2})",
                    argsString,
                    randomFieldsExceptTarget,
                    messageFunctionInfo.TargetParameter);
                writer.WriteString(".");
            }
            else if (suffix == "AverageLogarithm")
            {
                if (!factorInfo.IsDeterministicFactor)
                {
                    writer.WriteString(
                        "The outgoing message is the exponential of the average log-factor value, where the average is over all arguments except ");
                    writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                    writer.WriteString(". The formula is ");
                    writer.WriteElementFormatString("c", "exp(sum_({1}) p({1}) log(factor({0})))", argsString, randomFieldsExceptTarget);
                }
                else if (childString == messageFunctionInfo.TargetParameter)
                {
                    writer.WriteString("The outgoing message is a distribution matching the moments of ");
                    writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                    writer.WriteString(" as the random arguments are varied. The formula is ");
                    writer.WriteElementFormatString("c", "proj[sum_({1}) p({1}) factor({0})]", argsString, randomFieldsExceptTarget);
                }
                else if (childIsRandom && randomFieldsExceptTargetAndChild.Length == 0)
                {
                    writer.WriteString("The outgoing message is the factor viewed as a function of ");
                    writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                    writer.WriteString(" with ");
                    writer.WriteElementFormatString("c", childString);
                    writer.WriteString(" integrated out. The formula is ");
                    writer.WriteElementFormatString("c", "sum_{1} p({1}) factor({0})", argsString, childString);
                }
                else
                {
                    writer.WriteString("The outgoing message is the exponential of the average log-factor value, where the average is over all arguments except ");
                    writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                    writer.WriteString(". ");
                    if (childIsRandom)
                    {
                        writer.WriteString("Because the factor is deterministic, ");
                        writer.WriteElementFormatString("c", childString);
                        writer.WriteString(" is integrated out before taking the logarithm. The formula is ");
                        writer.WriteElementFormatString(
                            "c",
                            "exp(sum_({1}) p({1}) log(sum_{2} p({2}) factor({0})))",
                            argsString,
                            randomFieldsExceptTargetAndChild,
                            childString);
                    }
                    else
                    {
                        writer.WriteString("The formula is ");
                        writer.WriteElementFormatString("c", "exp(sum_({1}) p({1}) log(factor({0})))", argsString, randomFieldsExceptTarget);
                    }
                }

                writer.WriteString(".");
            }

            break;
        }
    }

    writer.WriteEndElement(); // para
    writer.WriteEndElement(); // remarks
}
/// <summary>
/// Determine the message type of target from the message type of the factor arguments.
/// </summary>
/// <param name="factor">The factor expression to process.</param>
/// <param name="direction">The direction of the messages whose types are being determined.</param>
/// <remarks>
/// The first pass collects a message (or constant) expression and its type for each factor argument.
/// The second pass visits the stochastic arguments on the side selected by <paramref name="direction"/>
/// (return/out arguments when going forwards, the others when going backwards), records a message
/// variable and init expression for them where an operator method is found, and feeds the resulting
/// message expression back into the argument tables so later arguments can see it.
/// </remarks>
protected void ProcessFactor(IExpression factor, MessageDirection direction)
{
    NodeInfo info = GetNodeInfo(factor);
    bool forwards = direction == MessageDirection.Forwards;

    // fill in argumentTypes
    Dictionary<string, Type> argumentTypes = new Dictionary<string, Type>();
    Dictionary<string, IExpression> arguments = new Dictionary<string, IExpression>();
    for (int i = 0; i < info.info.ParameterNames.Count; i++)
    {
        string parameterName = info.info.ParameterNames[i];
        bool isChild = info.isReturnOrOut[i];
        IExpression arg = info.arguments[i];

        if (!CodeRecognizer.IsStochastic(context, arg))
        {
            // Constant argument: passed through as-is.
            arguments[parameterName] = arg;
            argumentTypes[parameterName] = arg.GetExpressionType();
            continue;
        }

        IExpression incoming;
        if (!isChild)
        {
            incoming = GetMessageExpression(arg, fwdMessageVars);
        }
        else if (direction == MessageDirection.Backwards)
        {
            incoming = GetMessageExpression(arg, bckMessageVars);
            if (incoming == null)
            {
                //Console.WriteLine("creating backward message for "+arg);
                CreateBackwardMessageFromForward(arg, null);
                incoming = GetMessageExpression(arg, bckMessageVars);
            }
        }
        else
        {
            // Stochastic child in a non-backwards pass: no incoming message is recorded.
            continue;
        }

        if (incoming == null)
        {
            return;
        }

        arguments[parameterName] = incoming;
        Type incomingType = incoming.GetExpressionType();
        if (incomingType == null)
        {
            Error("inferred an incorrect message type for " + arg);
            return;
        }

        argumentTypes[parameterName] = incomingType;
    }

    // An Algorithm attribute on the factor invocation overrides the default algorithm.
    IAlgorithm alg = algorithm;
    Algorithm algAttr = context.InputAttributes.Get<Algorithm>(info.imie);
    if (algAttr != null)
    {
        alg = algAttr.algorithm;
    }

    List<ICompilerAttribute> factorAttributes = context.InputAttributes.GetAll<ICompilerAttribute>(info.imie);
    string methodSuffix = alg.GetOperatorMethodSuffix(factorAttributes);

    // infer types of children
    for (int i = 0; i < info.info.ParameterNames.Count; i++)
    {
        string parameterName = info.info.ParameterNames[i];
        bool isChild = info.isReturnOrOut[i];
        if (isChild != forwards)
        {
            continue;
        }

        IExpression target = info.arguments[i];
        if (!CodeRecognizer.IsStochastic(context, target))
        {
            continue;
        }

        IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);
        if (ivd == null)
        {
            continue;
        }

        Type targetType = null;
        MessageFcnInfo fcninfo = null;
        if (forwards)
        {
            // Try, in order: "Init", "<suffix>Init", then the plain "<suffix>" operator.
            try
            {
                fcninfo = GetMessageFcnInfo(info.info, "Init", parameterName, argumentTypes);
            }
            catch (Exception)
            {
                try
                {
                    fcninfo = GetMessageFcnInfo(info.info, methodSuffix + "Init", parameterName, argumentTypes);
                }
                catch (Exception suffixInitFailure)
                {
                    //Error("could not determine message type of "+ivd.Name, ex);
                    try
                    {
                        fcninfo = GetMessageFcnInfo(info.info, methodSuffix, parameterName, argumentTypes);
                        if (fcninfo.PassResult)
                        {
                            throw new MissingMethodException(StringUtil.MethodFullNameToString(fcninfo.Method) +
                                                             " is not suitable for initialization since it takes a result parameter. Please provide a separate Init method.");
                        }

                        if (fcninfo.PassResultIndex)
                        {
                            throw new MissingMethodException(StringUtil.MethodFullNameToString(fcninfo.Method) +
                                                             " is not suitable for initialization since it takes a resultIndex parameter. Please provide a separate Init method.");
                        }
                    }
                    catch (Exception fallbackFailure)
                    {
                        // The reported text uses the "<suffix>Init" failure message, with the fallback failure attached.
                        Error("could not determine " + direction + " message type of " + ivd.Name + ": " + suffixInitFailure.Message, fallbackFailure);
                        continue;
                    }
                }
            }

            if (fcninfo != null)
            {
                targetType = fcninfo.Method.ReturnType;
                if (targetType.IsGenericParameter)
                {
                    Error("could not determine " + direction + " message type of " + ivd.Name + " in " + StringUtil.MethodFullNameToString(fcninfo.Method));
                    continue;
                }

                // The result is not used here; the call is kept because it may have side effects
                // (e.g. throwing when no information exists) - TODO confirm.
                VariableInformation.GetVariableInformation(context, ivd);
                try
                {
                    targetType = MessageTransform.GetDistributionType(ivd.VariableType.DotNetType, target.GetExpressionType(), targetType, true);
                }
                catch (Exception typeFailure)
                {
                    Error("could not determine " + direction + " message type of " + ivd.Name, typeFailure);
                    continue;
                }
            }
        }

        Dictionary<IVariableDeclaration, IVariableDeclaration> messageVars = forwards ? fwdMessageVars : bckMessageVars;
        if (fcninfo != null)
        {
            // Reuse the message variable already registered for this declaration, if any.
            IVariableDeclaration msgVar;
            if (!messageVars.TryGetValue(ivd, out msgVar))
            {
                string name = ivd.Name + (forwards ? "_F" : "_B");
                msgVar = Builder.VarDecl(name, targetType);
            }

            // construct the init expression
            List<IExpression> args = new List<IExpression>();
            foreach (ParameterInfo parameter in fcninfo.Method.GetParameters())
            {
                if (IsFactoryType(parameter.ParameterType))
                {
                    IVariableDeclaration factoryVar = GetFactoryVariable(parameter.ParameterType);
                    args.Add(Builder.VarRefExpr(factoryVar));
                    continue;
                }

                FactorEdge factorEdge = fcninfo.factorEdgeOfParameter[parameter.Name];
                string factorParameterName = factorEdge.ParameterName;
                if (!arguments.ContainsKey(factorParameterName))
                {
                    // The operator needs an argument that is not available yet: abandon it.
                    if (forwards)
                    {
                        Error(StringUtil.MethodFullNameToString(fcninfo.Method) + " is not suitable for initialization since it requires '" + parameter.Name +
                              "'. Please provide a separate Init method.");
                    }

                    fcninfo = null;
                    break;
                }

                args.Add(arguments[factorParameterName]);
            }

            if (fcninfo != null)
            {
                IMethodInvokeExpression imie = Builder.StaticMethod(fcninfo.Method, args.ToArray());
                //IExpression initExpr = MessageTransform.GetDistributionArrayCreateExpression(ivd.VariableType.DotNetType, target.GetExpressionType(), imie, vi);
                IExpression initExpr = imie;
                KeyValuePair<IVariableDeclaration, IExpression> key = new KeyValuePair<IVariableDeclaration, IExpression>(msgVar, factor);
                messageInitExprs[key] = initExpr;
                messageVars[ivd] = msgVar;
            }
        }

        if (fcninfo == null)
        {
            if (forwards)
            {
                continue;
            }

            //Console.WriteLine("creating backward message for "+target);
            CreateBackwardMessageFromForward(target, factor);
        }

        // Make the new message visible to the arguments processed after this one.
        IExpression msgExpr = GetMessageExpression(target, messageVars);
        arguments[parameterName] = msgExpr;
        argumentTypes[parameterName] = msgExpr.GetExpressionType();
    }
}