/// <summary>
/// Checks that the <c>Not</c> factor is deterministic and that its AverageConditional operator towards
/// <c>Not</c> takes neither <c>result</c> nor <c>resultIndex</c> and has exactly one dependency and one requirement.
/// </summary>
public void NotFactorInfo()
{
    var notMethod = new MethodReference(typeof(Factor), "Not").GetMethodInfo();
    FactorManager.FactorInfo factorInfo = FactorManager.GetFactorInfo(notMethod);
    Console.WriteLine(factorInfo);
    Assert.True(factorInfo.IsDeterministicFactor);

    IDictionary<string, Type> types = new Dictionary<string, Type>
    {
        ["Not"] = typeof(Bernoulli),
        ["B"] = typeof(Bernoulli),
        ["result"] = typeof(Bernoulli),
    };
    MessageFcnInfo toNot = factorInfo.GetMessageFcnInfo(factorManager, "AverageConditional", "Not", types);

    Assert.False(toNot.PassResult);
    Assert.False(toNot.PassResultIndex);
    Assert.Equal(1, toNot.Dependencies.Count);
    Assert.Equal(1, toNot.Requirements.Count);
    Console.WriteLine(toNot);
}
/// <summary>
/// Requests an AverageConditional operator of the Gaussian factor for target <c>"Sample"</c> with every argument
/// typed as <see cref="Gaussian"/>, and expects <see cref="ArgumentException"/>.
/// </summary>
public void MissingMethodFailure()
{
    var gaussianFactor = FactorManager.GetFactorInfo(new Func<double, double, double>(Factor.Gaussian));
    var types = new Dictionary<string, Type>
    {
        ["sample"] = typeof(Gaussian),
        ["mean"] = typeof(Gaussian),
        ["precision"] = typeof(Gaussian),
    };
    try
    {
        gaussianFactor.GetMessageFcnInfo(factorManager, "AverageConditional", "Sample", types);
    }
    catch (ArgumentException expected)
    {
        Console.WriteLine("Correctly failed with exception: " + expected);
        return;
    }
    // Only reached when no ArgumentException was thrown.
    Assert.True(false, "Did not throw an exception");
}
/// <summary>
/// Requests an AverageConditional operator of <c>ShiftAlpha.ToFactor&lt;&gt;</c> towards <c>"Variable"</c> with
/// <c>double</c>-typed <c>factor</c> and <c>result</c>, and expects <see cref="ArgumentException"/>.
/// </summary>
public void TypeConstraintFailureFactorInfo()
{
    var toFactor = FactorManager.GetFactorInfo(new MethodReference(typeof(ShiftAlpha), "ToFactor<>").GetMethodInfo());
    Console.WriteLine(toFactor);
    var types = new Dictionary<string, Type>
    {
        ["factor"] = typeof(double),
        ["result"] = typeof(double),
    };
    try
    {
        toFactor.GetMessageFcnInfo(factorManager, "AverageConditional", "Variable", types);
    }
    catch (ArgumentException expected)
    {
        Console.WriteLine("Correctly failed with exception: " + expected);
        return;
    }
    // Only reached when no ArgumentException was thrown.
    Assert.True(false, "Did not throw an exception");
}
//[Fact]
// should take 140ms
// turn on PrintStatistics in Binding.cs
/// <summary>
/// Timing harness: looks up the AverageConditional operator of <c>ReplicateWithMarginal&lt;bool[]&gt;</c> towards
/// <c>Uses</c> twice. The first lookup uses a jagged-array <c>result</c>; the second uses a flat-array
/// <c>result</c> plus an <c>int</c> <c>resultIndex</c>. The returned operators are not inspected.
/// </summary>
internal void SpeedTest()
{
    var replicateMethod = new MethodReference(typeof(Factor), "ReplicateWithMarginal<>", typeof(bool[])).GetMethodInfo();
    FactorManager.FactorInfo replicate = FactorManager.GetFactorInfo(replicateMethod);

    Type flatArray = typeof(DistributionStructArray<Bernoulli, bool>);
    Type jaggedArray = typeof(DistributionRefArray<DistributionStructArray<Bernoulli, bool>, bool[]>);
    IDictionary<string, Type> types = new Dictionary<string, Type>
    {
        ["Uses"] = jaggedArray,
        ["Def"] = flatArray,
        ["Marginal"] = flatArray,
        ["result"] = jaggedArray,
    };

    replicate.GetMessageFcnInfo(factorManager, "AverageConditional", "Uses", types);

    types["result"] = flatArray;
    types["resultIndex"] = typeof(int);
    replicate.GetMessageFcnInfo(factorManager, "AverageConditional", "Uses", types);
}
/// <summary>
/// Round-trips a message operator: looks it up again using its own suffix, target parameter and parameter types,
/// and checks that the same method is found.
/// </summary>
/// <param name="fcninfo">The message operator to round-trip.</param>
/// <param name="info">The factor that <paramref name="fcninfo"/> belongs to.</param>
/// <remarks>
/// If the second lookup throws <see cref="NotSupportedException"/>, <paramref name="fcninfo"/> must carry a
/// non-null <c>NotSupportedMessage</c>.
/// </remarks>
internal void CheckMessageFcnInfo(MessageFcnInfo fcninfo, FactorManager.FactorInfo info)
{
    Assert.NotNull(fcninfo.Suffix);
    Assert.NotNull(fcninfo.TargetParameter);
    var parameterTypes = new Dictionary<string, Type>();
    foreach (KeyValuePair<string, Type> parameter in fcninfo.GetParameterTypes())
    {
        parameterTypes[parameter.Key] = parameter.Value;
    }
    try
    {
        MessageFcnInfo fcninfo2 = info.GetMessageFcnInfo(factorManager, fcninfo.Suffix, fcninfo.TargetParameter, parameterTypes);
        // Assert.Equal takes (expected, actual); the original operator is the expected value.
        Assert.Equal(fcninfo.Method, fcninfo2.Method);
    }
    catch (NotSupportedException)
    {
        Assert.NotNull(fcninfo.NotSupportedMessage);
    }
}
/// <summary>
/// Converts a method invocation, with special handling for Infer calls and for out-arguments of deterministic
/// factors.
/// </summary>
/// <param name="imie">The method invocation.</param>
/// <returns>The converted expression.</returns>
/// <remarks>
/// CheckMethodArgumentCount runs first.
/// An Infer call:
/// <list type="bullet">
/// <item>increments <c>inferCount</c>;</item>
/// <item>marks the declaration of its first argument IsInferred, if it has a declaration that is not already marked;</item>
/// <item>is rebuilt with cast-stripped but otherwise unconverted arguments, carrying over the original call's attributes.</item>
/// </list>
/// Any other call goes through the base conversion. If the result is still a method call, each variable passed
/// to it as an <c>out</c> argument is marked DerivedVariable when the call's factor info is non-null and
/// deterministic.
/// </remarks>
protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
{
    CheckMethodArgumentCount(imie);
    if (CodeRecognizer.IsInfer(imie))
    {
        return RebuildInferCall();
    }

    IExpression converted = base.ConvertMethodInvoke(imie);
    if (!(converted is IMethodInvokeExpression convertedCall))
    {
        return converted;
    }
    foreach (IExpression argument in convertedCall.Arguments)
    {
        if (!(argument is IAddressOutExpression outArgument))
        {
            continue;
        }
        IVariableDeclaration outVariable = Recognizer.GetVariableDeclaration(outArgument.Expression);
        if (outVariable == null)
        {
            continue;
        }
        FactorManager.FactorInfo factorInfo = CodeRecognizer.GetFactorInfo(context, convertedCall);
        if (factorInfo != null && factorInfo.IsDeterministicFactor && !context.InputAttributes.Has<DerivedVariable>(outVariable))
        {
            context.InputAttributes.Set(outVariable, new DerivedVariable());
        }
    }
    return converted;

    // Handles an Infer call; see remarks.
    IExpression RebuildInferCall()
    {
        inferCount++;
        object declaration = Recognizer.GetDeclaration(imie.Arguments[0]);
        if (declaration != null && !context.InputAttributes.Has<IsInferred>(declaration))
        {
            context.InputAttributes.Set(declaration, new IsInferred());
        }
        IMethodInvokeExpression rebuilt = Builder.MethodInvkExpr();
        rebuilt.Method = imie.Method;
        // the arguments must not be substituted for their values, so we don't call ConvertExpression
        rebuilt.Arguments.AddRange(imie.Arguments.Select(CodeRecognizer.RemoveCast));
        context.InputAttributes.CopyObjectAttributesTo(imie, context.OutputAttributes, rebuilt);
        return rebuilt;
    }
}
/// <summary>
/// Attach DerivedVariable attributes to newly created variables
/// </summary>
/// <param name="iae">The assignment to convert.</param>
/// <returns>The assignment produced by the base conversion.</returns>
/// <remarks>
/// When the converted right-hand side is a method call, the assigned variable is not yet marked
/// DerivedVariable, and the call's factor info is non-null and deterministic, the variable is marked
/// DerivedVariable.
/// </remarks>
protected override IExpression ConvertAssign(IAssignExpression iae)
{
    iae = (IAssignExpression)base.ConvertAssign(iae);
    if (iae.Expression is IMethodInvokeExpression imie)
    {
        IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(iae.Target);
        if (ivd != null && !context.InputAttributes.Has<DerivedVariable>(ivd))
        {
            FactorManager.FactorInfo info = CodeRecognizer.GetFactorInfo(context, imie);
            // Other callers of GetFactorInfo check for null before using the result; do the same here
            // rather than risk a NullReferenceException.
            if (info != null && info.IsDeterministicFactor)
            {
                context.InputAttributes.Set(ivd, new DerivedVariable());
            }
        }
    }
    return iae;
}
/// <summary>
/// Looks up and round-trips the AverageLogarithm operator of <c>Clone.Replicate&lt;bool&gt;</c> towards
/// <c>Def</c>. The first lookup types <c>Uses</c> as <c>DistributionArray&lt;Bernoulli&gt;</c>, the second as
/// <c>Bernoulli[]</c>. Finally it extracts dependency information from the second operator.
/// </summary>
public void ReplicateMessageFcnInfo()
{
    var replicate = FactorManager.GetFactorInfo(
        new MethodReference(typeof(Clone), "Replicate<>", typeof(bool)).GetMethodInfo());
    var types = new Dictionary<string, Type>
    {
        ["Uses"] = typeof(DistributionArray<Bernoulli>),
        ["Def"] = typeof(Bernoulli),
        ["result"] = typeof(Bernoulli), //typeof(DistributionArray<Bernoulli>);
    };

    // Looks up the operator for the current `types` and round-trips it.
    MessageFcnInfo LookUpAndCheckDefOperator()
    {
        MessageFcnInfo toDef = replicate.GetMessageFcnInfo(factorManager, "AverageLogarithm", "Def", types);
        CheckMessageFcnInfo(toDef, replicate);
        return toDef;
    }

    LookUpAndCheckDefOperator();
    types["Uses"] = typeof(Bernoulli[]);
    MessageFcnInfo arrayUsesOperator = LookUpAndCheckDefOperator();
    _ = FactorManager.GetDependencyInfo(arrayUsesOperator.Method);
}
/// <summary>
/// Exercises operator lookup for the deterministic <c>IsPositive</c> factor.
/// </summary>
/// <remarks>
/// <list type="bullet">
/// <item>The AverageConditional operator towards <c>x</c> is supported, and its dependency info is printed.</item>
/// <item>A direct AverageLogarithm lookup towards <c>x</c> throws <see cref="NotSupportedException"/>.</item>
/// <item>The enumerated AverageLogarithm operator towards <c>x</c> exists and carries a not-supported message.</item>
/// <item>The LogAverageFactor operator round-trips.</item>
/// </list>
/// </remarks>
public void IsPositiveFactorInfo()
{
    var isPositive = FactorManager.GetFactorInfo(new Func<double, bool>(Factor.IsPositive));
    Console.WriteLine(isPositive);
    Assert.True(isPositive.IsDeterministicFactor);

    IDictionary<string, Type> types = new Dictionary<string, Type>
    {
        ["isPositive"] = typeof(bool),
        ["x"] = typeof(Gaussian),
    };

    MessageFcnInfo conditionalToX = isPositive.GetMessageFcnInfo(factorManager, "AverageConditional", "x", types);
    Assert.Null(conditionalToX.NotSupportedMessage);
    CheckMessageFcnInfo(conditionalToX, isPositive);
    DependencyInformation dependencyInfo = FactorManager.GetDependencyInfo(conditionalToX.Method);
    Console.WriteLine("Dependencies:");
    Console.WriteLine(StringUtil.ToString(dependencyInfo.Dependencies));
    Console.WriteLine("Requirements:");
    Console.WriteLine(StringUtil.ToString(dependencyInfo.Requirements));

    Assert.Throws<NotSupportedException>(() =>
    {
        isPositive.GetMessageFcnInfo(factorManager, "AverageLogarithm", "x", types);
    });

    bool sawOperatorToX = false;
    foreach (MessageFcnInfo op in isPositive.GetMessageFcnInfos("AverageLogarithm", null, null))
    {
        CheckMessageFcnInfo(op, isPositive);
        if (!op.TargetParameter.Equals("x"))
        {
            continue;
        }
        Assert.NotNull(op.NotSupportedMessage);
        sawOperatorToX = true;
    }
    Assert.True(sawOperatorToX);

    MessageFcnInfo evidence = isPositive.GetMessageFcnInfo(factorManager, "LogAverageFactor", "", types);
    CheckMessageFcnInfo(evidence, isPositive);
}
/// <summary>
/// Returns the message-operator description for the method invoked by <paramref name="imie"/>.
/// </summary>
/// <param name="context">Transform context whose input attributes are consulted.</param>
/// <param name="imie">The operator invocation.</param>
/// <returns>
/// The <see cref="MessageFcnInfo"/> attached to the method, if any. Otherwise one derived from the
/// <see cref="FactorManager.FactorInfo"/> attached to <paramref name="imie"/>, if any. Otherwise one built from
/// the method's parameters, with its DependencyInfo filled in. Newly built results are not stored back as
/// attributes.
/// </returns>
internal static MessageFcnInfo GetMessageFcnInfo(BasicTransformContext context, IMethodInvokeExpression imie)
{
    var operatorMethod = (MethodInfo)imie.Method.Method.MethodInfo;

    MessageFcnInfo attached = context.InputAttributes.Get<MessageFcnInfo>(operatorMethod);
    if (attached != null)
    {
        return attached;
    }

    FactorManager.FactorInfo factor = context.InputAttributes.Get<FactorManager.FactorInfo>(imie);
    if (factor != null)
    {
        return factor.GetMessageFcnInfoFromFactor();
    }

    //context.InputAttributes.Set(method, fcnInfo);
    return new MessageFcnInfo(operatorMethod, operatorMethod.GetParameters())
    {
        DependencyInfo = FactorManager.GetDependencyInfo(operatorMethod)
    };
}
/// <summary>
/// Writes the XML documentation for a given message operator using a given XML writer.
/// </summary>
/// <param name="writer">The XML writer.</param>
/// <param name="factorInfo">The factor the message operator is for.</param>
/// <param name="messageFunctionInfo">The message operator.</param>
/// <remarks>
/// Emits one <c>message_doc</c> element. Its <c>name</c> attribute is the quoted operator signature (short type
/// names, no parameter names). Its content is, in order: summary, one entry per operator parameter, returns,
/// remarks, exception spec.
/// </remarks>
private static void WriteMessageFunctionDocumentation(
    XmlWriter writer, FactorManager.FactorInfo factorInfo, MessageFcnInfo messageFunctionInfo)
{
    var operatorMethod = messageFunctionInfo.Method;
    string signature = StringUtil.MethodSignatureToString(operatorMethod, useFullName: false, omitParameterNames: true);
    string quotedName = QuoteCodeElementName(signature);

    writer.WriteStartElement("message_doc");
    writer.WriteAttributeString("name", quotedName);

    WriteMessageOperatorSummary(writer, messageFunctionInfo);
    ParameterInfo[] operatorParameters = operatorMethod.GetParameters();
    for (int i = 0; i < operatorParameters.Length; i++)
    {
        WriteMessageOperatorParameterDescription(writer, factorInfo, messageFunctionInfo, operatorParameters[i]);
    }
    WriteMessageOperatorReturns(writer, factorInfo, messageFunctionInfo);
    WriteMessageOperatorRemarks(writer, factorInfo, messageFunctionInfo);
    WriteMessageOperatorExceptionSpec(writer, factorInfo, messageFunctionInfo);

    writer.WriteEndElement();
}
/// <summary>
/// Checks that <c>Constrain.EqualRandom&lt;bool, Bernoulli&gt;</c> has two parameters.
/// Checks that its AverageConditional operator towards <c>value</c> takes neither <c>result</c> nor
/// <c>resultIndex</c> and has one dependency and one requirement.
/// Then round-trips every operator of the factor through <c>CheckMessageFcnInfo</c>.
/// </summary>
public void ConstrainEqualRandomFactorInfo()
{
    var equalRandom = FactorManager.GetFactorInfo(
        new MethodReference(typeof(Constrain), "EqualRandom<,>", typeof(bool), typeof(Bernoulli)).GetMethodInfo());
    Assert.Equal(2, equalRandom.ParameterNames.Count);

    var types = new Dictionary<string, Type>
    {
        ["value"] = typeof(Bernoulli),
        ["dist"] = typeof(Bernoulli),
        ["result"] = typeof(Bernoulli),
    };
    MessageFcnInfo toValue = equalRandom.GetMessageFcnInfo(factorManager, "AverageConditional", "value", types);
    Assert.False(toValue.PassResult);
    Assert.False(toValue.PassResultIndex);
    Assert.Equal(1, toValue.Dependencies.Count);
    Assert.Equal(1, toValue.Requirements.Count);

    // Set to true to print every operator while it is checked.
    bool verbose = false;
    if (verbose)
    {
        Console.WriteLine("All MessageFcnInfos:");
    }
    int operatorCount = 0;
    foreach (MessageFcnInfo op in equalRandom.GetMessageFcnInfos())
    {
        if (verbose)
        {
            Console.WriteLine(op);
        }
        CheckMessageFcnInfo(op, equalRandom);
        operatorCount++;
    }
    //Assert.Equal(4, operatorCount);
}
/// <summary>
/// Checks AverageLogarithm operators of the non-deterministic <c>Clone.UsesEqualDef&lt;bool&gt;</c> factor.
/// </summary>
/// <remarks>
/// <list type="bullet">
/// <item>The operator towards <c>Uses</c> (with <c>resultIndex</c> supplied) is looked up three times. Each time
/// it is checked for PassResult, PassResultIndex, two dependencies, one requirement, SkipIfAllUniform and one
/// trigger.</item>
/// <item>The operator towards <c>Def</c> (without <c>resultIndex</c>) is checked for SkipIfAllUniform and one
/// trigger.</item>
/// <item>Dependency info is extracted from the first Uses operator and from the Def operator.</item>
/// </list>
/// </remarks>
public void UsesEqualDefFactorInfo()
{
    var usesEqualDef = FactorManager.GetFactorInfo(
        new MethodReference(typeof(Clone), "UsesEqualDef<>", typeof(bool)).GetMethodInfo());
    Assert.False(usesEqualDef.IsDeterministicFactor);

    var types = new Dictionary<string, Type>
    {
        ["Uses"] = typeof(Bernoulli[]),
        ["Def"] = typeof(Bernoulli),
        ["resultIndex"] = typeof(int),
        ["result"] = typeof(Bernoulli),
    };

    for (int pass = 1; pass <= 3; pass++)
    {
        MessageFcnInfo toUses = usesEqualDef.GetMessageFcnInfo(factorManager, "AverageLogarithm", "Uses", types);
        Assert.True(toUses.PassResult);
        Assert.True(toUses.PassResultIndex);
        Assert.Equal(2, toUses.Dependencies.Count);
        Assert.Equal(1, toUses.Requirements.Count);
        Assert.True(toUses.SkipIfAllUniform);
        Assert.Equal(1, toUses.Triggers.Count);
        if (pass == 1)
        {
            _ = FactorManager.GetDependencyInfo(toUses.Method);
        }
    }

    types.Remove("resultIndex");
    MessageFcnInfo toDef = usesEqualDef.GetMessageFcnInfo(factorManager, "AverageLogarithm", "Def", types);
    Assert.True(toDef.SkipIfAllUniform);
    Assert.Equal(1, toDef.Triggers.Count);
    _ = FactorManager.GetDependencyInfo(toDef.Method);
}
/// <summary>
/// Looks up a message operator, supplying array-factory parameter types on demand.
/// </summary>
/// <param name="info">The factor.</param>
/// <param name="methodSuffix">The operator method suffix, e.g. "AverageConditional".</param>
/// <param name="targetParameter">The factor parameter the message is sent to.</param>
/// <param name="parameterTypes">
/// Known parameter types. Factory entries are added to it temporarily. Every entry added here is removed again
/// before the method returns or throws, so the caller's dictionary is left as it was passed in.
/// </param>
/// <returns>The resolved message operator.</returns>
/// <remarks>
/// The lookup is repeated while the resolved operator has a parameter that is not yet in
/// <paramref name="parameterTypes"/> and whose type satisfies IsFactoryType. For each such parameter, an
/// <c>IArrayFactory&lt;item, array&gt;</c> type is added. The item type is the first generic argument of the
/// parameter's type; the array type is <c>Distribution.MakeDistributionArrayType(item, 1)</c>.
/// </remarks>
private MessageFcnInfo GetMessageFcnInfo(FactorManager.FactorInfo info, string methodSuffix, string targetParameter, IDictionary<string, Type> parameterTypes)
{
    var factoryParameters = new List<string>();
    try
    {
        while (true)
        {
            MessageFcnInfo fcninfo = info.GetMessageFcnInfo(factorManager, methodSuffix, targetParameter, parameterTypes);
            bool addedFactoryParameter = false;
            foreach (ParameterInfo parameter in fcninfo.Method.GetParameters())
            {
                if (parameterTypes.ContainsKey(parameter.Name) || !IsFactoryType(parameter.ParameterType))
                {
                    continue;
                }
                Type itemType = parameter.ParameterType.GetGenericArguments()[0];
                Type arrayType = Distribution.MakeDistributionArrayType(itemType, 1);
                parameterTypes[parameter.Name] = typeof(IArrayFactory<,>).MakeGenericType(itemType, arrayType);
                factoryParameters.Add(parameter.Name);
                addedFactoryParameter = true;
            }
            if (!addedFactoryParameter)
            {
                return fcninfo;
            }
            // New factory types were added: repeat the lookup so it can take them into account.
        }
    }
    finally
    {
        // Undo the temporary additions on every exit path, including when a lookup throws.
        foreach (string factoryParameter in factoryParameters)
        {
            parameterTypes.Remove(factoryParameter);
        }
    }
}
/// <summary>
/// Exercises operator lookup for the deterministic <c>Difference</c> factor.
/// </summary>
/// <remarks>
/// <list type="bullet">
/// <item>A valid AverageConditional lookup towards <c>Difference</c>.</item>
/// <item>A lookup with suffix "Rubbish" must throw <see cref="ArgumentException"/> mentioning MissingMethodException.</item>
/// <item>A lookup with a <c>double</c> result must throw <see cref="ArgumentException"/>.</item>
/// <item>Operators are enumerated by target and by suffix. There must be 8 messages to Difference and
/// 12 AverageConditional operators.</item>
/// </list>
/// </remarks>
public void DifferenceFactorInfo()
{
    var difference = FactorManager.GetFactorInfo(new MethodReference(typeof(Factor), "Difference").GetMethodInfo());
    Console.WriteLine(difference);
    Assert.True(difference.IsDeterministicFactor);

    IDictionary<string, Type> types = new Dictionary<string, Type>
    {
        ["Difference"] = typeof(Gaussian),
        ["A"] = typeof(Gaussian),
        ["B"] = typeof(Gaussian),
        //["result"] = typeof(Gaussian),
    };

    MessageFcnInfo toDifference = difference.GetMessageFcnInfo(factorManager, "AverageConditional", "Difference", types);
    Assert.False(toDifference.PassResult);
    Assert.False(toDifference.PassResultIndex);
    Assert.Equal(2, toDifference.Dependencies.Count);
    Assert.Equal(2, toDifference.Requirements.Count);
    Console.WriteLine(toDifference);
    Console.WriteLine("Parameter types:");
    Console.WriteLine(StringUtil.CollectionToString(toDifference.GetParameterTypes(), ","));
    Console.WriteLine();

    bool rubbishRejected = false;
    try
    {
        difference.GetMessageFcnInfo(factorManager, "Rubbish", "Difference", types);
    }
    catch (ArgumentException ex)
    {
        if (!ex.Message.Contains("MissingMethodException"))
        {
            Assert.True(false, "Correctly threw exception, but with wrong message");
        }
        Console.WriteLine("Different exception: " + ex);
        rubbishRejected = true;
    }
    Assert.True(rubbishRejected, "Did not throw an exception");
    Console.WriteLine();

    // This entry stays in the dictionary after the failed lookup.
    types["result"] = typeof(double);
    bool doubleResultRejected = false;
    try
    {
        difference.GetMessageFcnInfo(factorManager, "AverageConditional", "Difference", types);
    }
    catch (ArgumentException ex)
    {
        Console.WriteLine("Correctly failed with exception: " + ex);
        doubleResultRejected = true;
    }
    Assert.True(doubleResultRejected, "Did not throw an exception");

    Console.WriteLine();
    Console.WriteLine("All messages to A:");
    int toACount = PrintAndCheck(difference.GetMessageFcnInfos(null, "A", null));
    //Assert.Equal(8, toACount);
    Console.WriteLine();
    Console.WriteLine("All messages to Difference:");
    Assert.Equal(8, PrintAndCheck(difference.GetMessageFcnInfos(null, "Difference", null)));
    Console.WriteLine();
    Console.WriteLine("All AverageConditionals:");
    Assert.Equal(12, PrintAndCheck(difference.GetMessageFcnInfos("AverageConditional", null, null)));
    Console.WriteLine();
    Console.WriteLine("All MessageFcnInfos:");
    int allCount = PrintAndCheck(difference.GetMessageFcnInfos());
    //Assert.Equal(26, allCount);

    // Prints and round-trips each operator; returns how many operators there were.
    int PrintAndCheck(IEnumerable<MessageFcnInfo> operators)
    {
        int seen = 0;
        foreach (MessageFcnInfo op in operators)
        {
            Console.WriteLine(op);
            CheckMessageFcnInfo(op, difference);
            seen++;
        }
        return seen;
    }
}
/// <summary>
/// Records how <paramref name="targetVar"/> is defined by <paramref name="expr"/>, tracking whether it is defined
/// as a forward point mass.
/// </summary>
/// <param name="expr">The right-hand side defining the variable.</param>
/// <param name="targetVar">The variable being defined.</param>
/// <param name="isLhs">Not used by this method.</param>
/// <remarks>
/// The target counts as a point mass in either of these cases:
/// <list type="bullet">
/// <item><paramref name="expr"/> is a call to Clone.VariablePoint.</item>
/// <item><paramref name="expr"/> is a call to a deterministic factor, and either its
/// ReturnedInAllElementsParameterIndex argument is a point mass or all of its arguments are point masses.</item>
/// </list>
/// A point-mass target gets a ForwardPointMass output attribute (unless the input attributes already have one)
/// and an entry in variablesDefinedPointMass. Calls to Replicate and the listed Collection.Get*Items* methods are
/// appended to that entry.
/// Any other definition except an array creation adds the target to variablesDefinedNonPointMass. It also
/// removes any earlier point-mass registration of the target.
/// </remarks>
protected void ProcessDefinition(IExpression expr, IVariableDeclaration targetVar, bool isLhs)
{
    bool targetIsPointMass = false;
    if (expr is IMethodInvokeExpression imie)
    {
        // TODO: consider using a method attribute for this
        if (Recognizer.IsStaticGenericMethod(imie, new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint)))
        {
            targetIsPointMass = true;
        }
        else
        {
            FactorManager.FactorInfo info = CodeRecognizer.GetFactorInfo(context, imie);
            // Other callers of GetFactorInfo check for null before using the result. A call without factor info
            // is not treated as a point mass, instead of throwing a NullReferenceException here.
            targetIsPointMass = info != null && info.IsDeterministicFactor && (
                (info.ReturnedInAllElementsParameterIndex != -1 && ArgumentIsPointMass(imie.Arguments[info.ReturnedInAllElementsParameterIndex])) ||
                imie.Arguments.All(ArgumentIsPointMass)
            );
        }
        if (targetIsPointMass)
        {
            // do this immediately so all uses are updated
            if (!context.InputAttributes.Has<ForwardPointMass>(targetVar))
            {
                context.OutputAttributes.Set(targetVar, new ForwardPointMass());
            }
            // the rest is done later
            List<IMethodInvokeExpression> list;
            if (!variablesDefinedPointMass.TryGetValue(targetVar, out list))
            {
                list = new List<IMethodInvokeExpression>();
                variablesDefinedPointMass.Add(targetVar, list);
            }
            // this code needs to be synchronized with MessageTransform.ConvertMethodInvoke
            if (Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate)) ||
                Recognizer.IsStaticGenericMethod(imie, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems)) ||
                Recognizer.IsStaticGenericMethod(imie, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems)) ||
                Recognizer.IsStaticGenericMethod(imie, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems)) ||
                Recognizer.IsStaticGenericMethod(imie, new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromJagged)) ||
                Recognizer.IsStaticGenericMethod(imie, new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged)) ||
                Recognizer.IsStaticGenericMethod(imie, new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<IReadOnlyList<int>>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged))
                )
            {
                list.Add(imie);
            }
        }
    }
    if (!targetIsPointMass && !(expr is IArrayCreateExpression))
    {
        variablesDefinedNonPointMass.Add(targetVar);
        if (variablesDefinedPointMass.ContainsKey(targetVar))
        {
            variablesDefinedPointMass.Remove(targetVar);
            context.OutputAttributes.Remove<ForwardPointMass>(targetVar);
        }
    }

    // An argument is a point mass when it is deterministic, when it is an out-argument, or when its variable
    // has a ForwardPointMass input attribute.
    bool ArgumentIsPointMass(IExpression arg)
    {
        bool isOut = (arg is IAddressOutExpression);
        if (CodeRecognizer.IsStochastic(context, arg) && !isOut)
        {
            IVariableDeclaration argVar = Recognizer.GetVariableDeclaration(arg);
            return (argVar != null) && context.InputAttributes.Has<ForwardPointMass>(argVar);
        }
        else
        {
            return true;
        }
    }
}
/// <summary>
/// Converts an assignment whose target is a local variable or a method parameter.
/// </summary>
/// <param name="iae">The assignment expression.</param>
/// <returns>The converted assignment.</returns>
/// <remarks>
/// Targets that are neither a variable nor a parameter go straight to the base conversion.
/// For an array-element target, the left-hand-side index variables are registered with the target's
/// VariableInformation. Failures in that step are reported through Error.
/// After conversion:
/// <list type="bullet">
/// <item>A local variable assigned a call to a deterministic factor is marked DerivedVariable.</item>
/// <item>A local variable assigned a literal of an inlinable type, outside a loop initializer, gets a
/// ConditionBinding in conditionContext.</item>
/// <item>An assignment to a method parameter marks the enclosing statement as a Constraint.</item>
/// </list>
/// In every handled case, a FactorAlgorithm attribute on the target becomes an Algorithm attribute on the
/// right-hand side, and GivePriorityTo attributes are copied to the right-hand side.
/// </remarks>
protected override IExpression ConvertAssign(IAssignExpression iae)
{
    IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(iae.Target);
    IParameterDeclaration ipd = (ivd == null) ? Recognizer.GetParameterDeclaration(iae.Target) : null;
    if (ivd == null && ipd == null)
    {
        return base.ConvertAssign(iae);
    }
    object decl = (ivd != null) ? (object)ivd : ipd;

    if (iae.Target is IArrayIndexerExpression)
    {
        RecordLhsIndexVariables();
    }

    var ae = (IAssignExpression)base.ConvertAssign(iae);
    if (ipd != null)
    {
        MarkEnclosingStatementAsConstraint();
    }
    else
    {
        HandleLocalVariableAssignment();
    }

    // a FactorAlgorithm attribute on a variable turns into an Algorithm attribute on its right hand side.
    var factorAlgorithm = context.InputAttributes.Get<FactorAlgorithm>(decl);
    if (factorAlgorithm != null)
    {
        context.OutputAttributes.Set(ae.Expression, new Algorithm(factorAlgorithm.algorithm));
    }
    context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(decl, context.OutputAttributes, ae.Expression);
    return ae;

    // Gather index variables from the left-hand side of the assignment
    void RecordLhsIndexVariables()
    {
        VariableInformation vi = VariableInformation.GetVariableInformation(context, decl);
        try
        {
            var indexVarsByDepth = new List<IVariableDeclaration[]>();
            Recognizer.AddIndexers(context, indexVarsByDepth, iae.Target);
            int depth = Recognizer.GetIndexingDepth(iae.Target);
            // if this statement is actually a constraint, then we don't need to enforce matching of index variables
            bool allowMismatch = context.InputAttributes.Has<Models.Constraint>(context.FindAncestor<IStatement>());
            for (int level = 0; level < depth; level++)
            {
                vi.SetIndexVariablesAtDepth(level, indexVarsByDepth[level], allowMismatch: allowMismatch);
            }
        }
        catch (Exception ex)
        {
            Error(ex.Message, ex);
        }
    }

    // assignment to a method parameter
    void MarkEnclosingStatementAsConstraint()
    {
        IStatement ist = context.FindAncestor<IStatement>();
        if (!context.InputAttributes.Has<Models.Constraint>(ist))
        {
            // mark this statement as a constraint
            context.OutputAttributes.Set(ist, new Models.Constraint());
        }
    }

    // assignment to a local variable
    void HandleLocalVariableAssignment()
    {
        if (ae.Expression is IMethodInvokeExpression imie)
        {
            // this unfortunately duplicates some of the work done by SetStoch and IsStoch.
            FactorManager.FactorInfo info = CodeRecognizer.GetFactorInfo(context, imie);
            if (info != null && info.IsDeterministicFactor && !context.InputAttributes.Has<DerivedVariable>(ivd))
            {
                context.InputAttributes.Set(ivd, new DerivedVariable());
            }
        }
        if (ae.Expression is ILiteralExpression)
        {
            bool isLoopInitializer = Recognizer.GetAncestorIndexOfLoopBeingInitialized(context) != -1;
            if (isLoopInitializer)
            {
                return;
            }
            Type valueType = ae.Expression.GetExpressionType();
            if (Quoter.ShouldInlineType(valueType))
            {
                // inline all future occurrences of this variable with the rhs expression
                conditionContext.Add(new ConditionBinding(ae.Target, ae.Expression));
            }
        }
    }
}
/// <summary>
/// Converts a method invocation.
/// </summary>
/// <param name="imie">The method invocation.</param>
/// <returns>The converted expression, a folded literal or operand, or null for Attrib calls.</returns>
/// <remarks>
/// Handling by call kind:
/// <list type="bullet">
/// <item><c>Attrib.Var</c>: passes the resolved variable and the cast-stripped second argument to AddAttribute,
/// and returns null.</item>
/// <item><c>Attrib.InitialiseTo</c>: sets an InitialiseTo output attribute on the resolved variable, and returns
/// null.</item>
/// <item>Infer calls: increment <c>inferCount</c>, mark the declaration IsInferred, and are rebuilt with
/// cast-stripped, unconverted arguments.</item>
/// <item>Other calls: go through the base conversion. The result is then processed as described below.</item>
/// </list>
/// If the converted call has a literal argument:
/// <list type="bullet">
/// <item><c>And</c>/<c>Or</c> are folded: a dominating literal gives that literal; otherwise the single
/// non-literal operand, or the identity literal.</item>
/// <item><c>Not</c> with all-literal arguments is evaluated.</item>
/// </list>
/// Otherwise, variables passed as <c>out</c> arguments to a deterministic factor are marked DerivedVariable.
/// </remarks>
protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
{
    if (Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, ICompilerAttribute, PlaceHolder>(Attrib.Var)))
    {
        var variableRef = imie.Arguments[0] as IVariableReferenceExpression;
        IVariableDeclaration variable = variableRef.Variable.Resolve();
        IExpression attributeExpr = CodeRecognizer.RemoveCast(imie.Arguments[1]);
        AddAttribute(variable, attributeExpr);
        return null;
    }
    if (Recognizer.IsStaticMethod(imie, new Action<object, object>(Attrib.InitialiseTo)))
    {
        var variableRef = CodeRecognizer.RemoveCast(imie.Arguments[0]) as IVariableReferenceExpression;
        IVariableDeclaration variable = variableRef.Variable.Resolve();
        context.OutputAttributes.Set(variable, new InitialiseTo(imie.Arguments[1]));
        return null;
    }
    if (CodeRecognizer.IsInfer(imie))
    {
        inferCount++;
        object declaration = Recognizer.GetDeclaration(imie.Arguments[0]);
        if (declaration != null && !context.InputAttributes.Has<IsInferred>(declaration))
        {
            context.InputAttributes.Set(declaration, new IsInferred());
        }
        // the arguments must not be substituted for their values, so we don't call ConvertExpression
        List<IExpression> strippedArgs = imie.Arguments.Select(CodeRecognizer.RemoveCast).ToList();
        IMethodInvokeExpression inferCall = Builder.MethodInvkExpr();
        inferCall.Method = imie.Method;
        inferCall.Arguments.AddRange(strippedArgs);
        context.InputAttributes.CopyObjectAttributesTo(imie, context.OutputAttributes, inferCall);
        return inferCall;
    }

    IExpression converted = base.ConvertMethodInvoke(imie);
    var call = converted as IMethodInvokeExpression;
    if (call == null)
    {
        return converted;
    }

    bool isAnd = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Factors.Factor.And));
    bool isOr = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Factors.Factor.Or));
    if (call.Arguments.Any(arg => arg is ILiteralExpression))
    {
        if (isAnd)
        {
            return FoldLiteralOperands(false);
        }
        if (isOr)
        {
            return FoldLiteralOperands(true);
        }
        if (Recognizer.IsStaticMethod(converted, new Func<bool, bool>(Factors.Factor.Not)) &&
            call.Arguments.All(arg => arg is ILiteralExpression))
        {
            return Builder.LiteralExpr(evaluator.Evaluate(call));
        }
    }

    foreach (IExpression arg in call.Arguments)
    {
        var outArgument = arg as IAddressOutExpression;
        if (outArgument == null)
        {
            continue;
        }
        IVariableDeclaration outVariable = Recognizer.GetVariableDeclaration(outArgument.Expression);
        if (outVariable == null)
        {
            continue;
        }
        FactorManager.FactorInfo info = CodeRecognizer.GetFactorInfo(context, call);
        if (info != null && info.IsDeterministicFactor && !context.InputAttributes.Has<DerivedVariable>(outVariable))
        {
            context.InputAttributes.Set(outVariable, new DerivedVariable());
        }
    }
    return converted;

    // Folds a boolean And (dominant = false) or Or (dominant = true) that has at least one literal operand.
    IExpression FoldLiteralOperands(bool dominant)
    {
        if (call.Arguments.Any(arg => arg is ILiteralExpression literal && literal.Value.Equals(dominant)))
        {
            return Builder.LiteralExpr(dominant);
        }
        // any remaining literals must equal the identity value (true for And, false for Or), and therefore can be ignored.
        List<IExpression> nonLiterals = call.Arguments.Where(arg => !(arg is ILiteralExpression)).ToList();
        if (nonLiterals.Count == 1)
        {
            return nonLiterals[0];
        }
        return Builder.LiteralExpr(!dominant);
    }
}
/// <summary>
/// Initializes a new instance of the <see cref="FactorInfoWrapper"/> class.
/// </summary>
/// <param name="factorInfo">
/// The wrapped factor info. Must not be null; this is only checked by a debug assertion, so release builds do
/// not reject null here.
/// </param>
public FactorInfoWrapper(FactorManager.FactorInfo factorInfo)
{
    Debug.Assert(factorInfo != null, "The given factor info cannot be null.");
    FactorInfo = factorInfo;
}
/// <summary>
/// Writes the XML documentation for a parameter of given message operator using a given XML writer.
/// </summary>
/// <param name="writer">The XML writer.</param>
/// <param name="factorInfo">The factor the message operator is for.</param>
/// <param name="messageFunctionInfo">The message operator.</param>
/// <param name="parameter">The parameter.</param>
/// <remarks>
/// Emits one <c>param</c> element. <c>result</c> and <c>resultIndex</c> get fixed text. Other parameters are
/// described through their factor edge:
/// <list type="bullet">
/// <item>Outgoing message: "Outgoing" or "Previous outgoing", depending on <see cref="FreshAttribute"/>.</item>
/// <item>Constant value: the operator's type equals the factor's declared type.</item>
/// <item>Incoming message: followed by a uniformity note per matching requirement.</item>
/// <item>Buffer: the edge is not among the factor's parameter types.</item>
/// </list>
/// A parameter with no factor edge gets only the <c>name</c> attribute.
/// </remarks>
private static void WriteMessageOperatorParameterDescription(
    XmlWriter writer, FactorManager.FactorInfo factorInfo, MessageFcnInfo messageFunctionInfo, ParameterInfo parameter)
{
    writer.WriteStartElement("param");
    writer.WriteAttributeString("name", parameter.Name);

    switch (parameter.Name)
    {
        case "result":
            writer.WriteString("Modified to contain the outgoing message.");
            break;
        case "resultIndex":
            writer.WriteString("Index of the ");
            writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
            writer.WriteString(" for which a message is desired.");
            break;
        default:
            // TODO: skip methods which aren't message operators
            if (messageFunctionInfo.factorEdgeOfParameter.TryGetValue(parameter.Name, out FactorEdge parameterEdge))
            {
                DescribeEdge(parameterEdge);
            }
            break;
    }

    writer.WriteEndElement();

    // Writes the description for a parameter that corresponds to a factor edge.
    void DescribeEdge(FactorEdge factorEdge)
    {
        string field = factorEdge.ToString();
        if (factorEdge.IsOutgoingMessage)
        {
            string kind = parameter.IsDefined(typeof(FreshAttribute), false) ? "Outgoing" : "Previous outgoing";
            writer.WriteFormatString("{0} message to ", kind);
            // Drops the first three characters of the parameter name (presumably a fixed prefix — TODO confirm).
            writer.WriteElementFormatString("c", parameter.Name.Substring(3));
            writer.WriteString(".");
            return;
        }
        if (!factorInfo.ParameterTypes.TryGetValue(field, out Type factorParameterType))
        {
            writer.WriteString("Buffer ");
            writer.WriteElementFormatString("c", parameter.Name);
            writer.WriteString(".");
            return;
        }
        if (parameter.ParameterType == factorParameterType)
        {
            writer.WriteString("Constant value for ");
            writer.WriteElementFormatString("c", field);
            writer.WriteString(".");
            return;
        }

        writer.WriteString("Incoming message from ");
        writer.WriteElementFormatString("c", field);
        writer.WriteString(".");
        foreach (FactorEdge requirement in messageFunctionInfo.Requirements)
        {
            if (requirement.ParameterName != field)
            {
                continue;
            }
            writer.WriteString(" Must be a proper distribution. If ");
            if (Util.IsIList(factorParameterType))
            {
                // Narrow the requirement by every dependency on the same field.
                FactorEdge range = requirement;
                foreach (FactorEdge dependency in messageFunctionInfo.Dependencies)
                {
                    if (dependency.ParameterName == field)
                    {
                        range = range.Intersect(dependency);
                    }
                }
                if (range.MinCount == 1)
                {
                    writer.WriteString(range.ContainsIndex ? "the element at resultIndex is " : "all elements are ");
                }
                else if (range.ContainsAllOthers)
                {
                    writer.WriteString(range.ContainsIndex ? "any element is " : "any element besides resultIndex is ");
                }
            }
            writer.WriteString("uniform, the result will be uniform.");
        }
    }
}
/// <summary>
/// Writes the remarks section of the XML documentation for a given message operator using a given XML writer.
/// </summary>
/// <param name="writer">The XML writer.</param>
/// <param name="factorInfo">The factor the message operator is for.</param>
/// <param name="messageFunctionInfo">The message operator.</param>
/// <remarks>
/// Emits <c>remarks/para</c> containing a prose description and a formula. The text depends on the operator's
/// method name (AverageLogFactor, LogAverageFactor, LogEvidenceRatio) or, failing those, on its suffix
/// (Conditional, AverageConditional, AverageLogarithm). Other operators get an empty paragraph.
/// </remarks>
private static void WriteMessageOperatorRemarks(
    XmlWriter writer, FactorManager.FactorInfo factorInfo, MessageFcnInfo messageFunctionInfo)
{
    // A factor with a void return type is treated as a constraint, i.e. it has no child argument.
    bool isConstraint = factorInfo.Method.ReturnType == typeof(void);
    // must use the parameter names of the factor, not the operator method, because the operator method may not have parameters for all of the factor edges
    string argsString = StringUtil.CollectionToString(factorInfo.ParameterNames, ",");
    // For a non-constraint factor, the first factor parameter name is used as the child.
    string childString = isConstraint ? string.Empty : factorInfo.ParameterNames[0];
    bool childIsRandom = false;
    ParameterInfo[] parameters = messageFunctionInfo.Method.GetParameters();
    // Comma-separated lists of the "random" fields. A field counts as random when it is an incoming (not outgoing)
    // edge whose operator parameter type differs from the factor's declared type for that field.
    string randomFieldsString = string.Empty;
    string randomFieldsMinusTarget = string.Empty;
    string randomFieldsMinusTargetAndChild = string.Empty;
    foreach (ParameterInfo parameter in parameters)
    {
        if (parameter.Name != "result" && parameter.Name != "resultIndex")
        {
            FactorEdge edge;
            if (messageFunctionInfo.factorEdgeOfParameter.TryGetValue(parameter.Name, out edge))
            {
                string field = edge.ToString();
                Type type = parameter.ParameterType;
                if (!edge.IsOutgoingMessage && factorInfo.ParameterTypes.ContainsKey(field) && type != factorInfo.ParameterTypes[field])
                {
                    if (!isConstraint && factorInfo.ParameterNames[0] == field)
                    {
                        childIsRandom = true;
                    }
                    randomFieldsString += randomFieldsString == string.Empty ? field : "," + field;
                    if (field != messageFunctionInfo.TargetParameter)
                    {
                        randomFieldsMinusTarget += randomFieldsMinusTarget == string.Empty ? field : "," + field;
                        if (factorInfo.ParameterNames[0] != field)
                        {
                            randomFieldsMinusTargetAndChild += randomFieldsMinusTargetAndChild == string.Empty ? field : "," + field;
                        }
                    }
                }
            }
        }
    }
    writer.WriteStartElement("remarks");
    writer.WriteStartElement("para");
    if (messageFunctionInfo.Method.Name == "AverageLogFactor")
    {
        // VMP evidence contribution.
        if (randomFieldsString == string.Empty)
        {
            writer.WriteString("The formula for the result is ");
            writer.WriteElementFormatString("c", "log(factor({0}))", argsString);
        }
        else
        {
            if (factorInfo.IsDeterministicFactor)
            {
                writer.WriteString("In Variational Message Passing, the evidence contribution of a deterministic factor is zero");
            }
            else
            {
                writer.WriteString("The formula for the result is ");
                writer.WriteElementFormatString("c", "sum_({0}) p({0}) log(factor({1}))", randomFieldsString, argsString);
            }
        }
        writer.WriteString(". Adding up these values across all factors and variables gives the log-evidence estimate for VMP.");
    }
    else if (messageFunctionInfo.Method.Name == "LogAverageFactor")
    {
        writer.WriteString("The formula for the result is ");
        if (randomFieldsString == string.Empty)
        {
            writer.WriteElementFormatString("c", "log(factor({0}))", argsString);
        }
        else
        {
            writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}))", randomFieldsString, argsString);
        }
        writer.WriteString(".");
    }
    else if (messageFunctionInfo.Method.Name == "LogEvidenceRatio")
    {
        // EP evidence contribution. When the child is random, the formula divides by the child's message term.
        writer.WriteString("The formula for the result is ");
        if (randomFieldsString == string.Empty)
        {
            writer.WriteElementFormatString("c", "log(factor({0}))", argsString);
        }
        else
        {
            if (childIsRandom)
            {
                writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}) / sum_{2} p({2}) messageTo({2}))", randomFieldsString, argsString, childString);
            }
            else
            {
                writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}))", randomFieldsString, argsString);
            }
        }
        writer.WriteString(". Adding up these values across all factors and variables gives the log-evidence estimate for EP.");
    }
    else if (messageFunctionInfo.Suffix == "Conditional")
    {
        writer.WriteString("The outgoing message is the factor viewed as a function of ");
        writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
        writer.WriteString(" conditioned on the given values.");
    }
    else if (messageFunctionInfo.Suffix == "AverageConditional")
    {
        // With no random field other than the target, the text matches the Conditional case.
        if (randomFieldsMinusTarget == string.Empty)
        {
            writer.WriteString("The outgoing message is the factor viewed as a function of ");
            writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
            writer.WriteString(" conditioned on the given values.");
        }
        else
        {
            writer.WriteString("The outgoing message is a distribution matching the moments of ");
            writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
            writer.WriteString(" as the random arguments are varied. The formula is ");
            writer.WriteElementFormatString("c", "proj[p({2}) sum_({1}) p({1}) factor({0})]/p({2})", argsString, randomFieldsMinusTarget, messageFunctionInfo.TargetParameter);
            writer.WriteString(".");
        }
    }
    else if (messageFunctionInfo.Suffix == "AverageLogarithm")
    {
        if (randomFieldsMinusTarget == string.Empty)
        {
            writer.WriteString("The outgoing message is the factor viewed as a function of ");
            writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
            writer.WriteString(" conditioned on the given values.");
        }
        else
        {
            if (factorInfo.IsDeterministicFactor)
            {
                if (childString == messageFunctionInfo.TargetParameter)
                {
                    // Message to the child of a deterministic factor.
                    writer.WriteString("The outgoing message is a distribution matching the moments of ");
                    writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                    writer.WriteString(" as the random arguments are varied. The formula is ");
                    writer.WriteElementFormatString("c", "proj[sum_({1}) p({1}) factor({0})]", argsString, randomFieldsMinusTarget);
                }
                else
                {
                    if (childIsRandom && randomFieldsMinusTargetAndChild == string.Empty)
                    {
                        // The child is the only other random field.
                        writer.WriteString("The outgoing message is the factor viewed as a function of ");
                        writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                        writer.WriteString(" with ");
                        writer.WriteElementFormatString("c", childString);
                        writer.WriteString(" integrated out. The formula is ");
                        writer.WriteElementFormatString("c", "sum_{1} p({1}) factor({0})", argsString, childString);
                    }
                    else
                    {
                        writer.WriteString("The outgoing message is the exponential of the average log-factor value, where the average is over all arguments except ");
                        writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                        writer.WriteString(". ");
                        if (childIsRandom)
                        {
                            writer.WriteString("Because the factor is deterministic, ");
                            writer.WriteElementFormatString("c", childString);
                            writer.WriteString(" is integrated out before taking the logarithm. The formula is ");
                            writer.WriteElementFormatString(
                                "c", "exp(sum_({1}) p({1}) log(sum_{2} p({2}) factor({0})))", argsString, randomFieldsMinusTargetAndChild, childString);
                        }
                        else
                        {
                            writer.WriteString("The formula is ");
                            writer.WriteElementFormatString("c", "exp(sum_({1}) p({1}) log(factor({0})))", argsString, randomFieldsMinusTarget);
                        }
                    }
                }
            }
            else
            {
                writer.WriteString(
                    "The outgoing message is the exponential of the average log-factor value, where the average is over all arguments except ");
                writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                writer.WriteString(". The formula is ");
                writer.WriteElementFormatString("c", "exp(sum_({1}) p({1}) log(factor({0})))", argsString, randomFieldsMinusTarget);
            }
            // Terminates whichever formula sentence was written above.
            writer.WriteString(".");
        }
    }
    writer.WriteEndElement();
    writer.WriteEndElement();
}