示例#1
0
        /// <summary>
        /// Looks up the factor info for <c>Factor.Not</c> and checks the "AverageConditional"
        /// message function towards "Not" when every argument is a <c>Bernoulli</c>.
        /// </summary>
        public void NotFactorInfo()
        {
            var notMethod = new MethodReference(typeof(Factor), "Not").GetMethodInfo();
            FactorManager.FactorInfo info = FactorManager.GetFactorInfo(notMethod);
            Console.WriteLine(info);
            Assert.True(info.IsDeterministicFactor);

            IDictionary<string, Type> parameterTypes = new Dictionary<string, Type>
            {
                ["Not"] = typeof(Bernoulli),
                ["B"] = typeof(Bernoulli),
                ["result"] = typeof(Bernoulli)
            };
            MessageFcnInfo fcninfo = info.GetMessageFcnInfo(factorManager, "AverageConditional", "Not", parameterTypes);

            // The resolved operator takes neither a result buffer nor a result index,
            // and reports exactly one dependency and one requirement.
            Assert.False(fcninfo.PassResult);
            Assert.False(fcninfo.PassResultIndex);
            Assert.Equal(1, fcninfo.Dependencies.Count);
            Assert.Equal(1, fcninfo.Requirements.Count);
            Console.WriteLine(fcninfo);
        }
示例#2
0
        /// <summary>
        /// Expects an <see cref="ArgumentException"/> when asking the Gaussian factor for an
        /// "AverageConditional" message to "Sample" with Gaussian-typed sample, mean and precision.
        /// </summary>
        public void MissingMethodFailure()
        {
            var gaussianFactor = new Func<double, double, double>(Factor.Gaussian);
            FactorManager.FactorInfo info = FactorManager.GetFactorInfo(gaussianFactor);
            var parameterTypes = new Dictionary<string, Type>
            {
                ["sample"] = typeof(Gaussian),
                ["mean"] = typeof(Gaussian),
                ["precision"] = typeof(Gaussian)
            };

            // Capture the expected failure; any other exception type propagates unchanged.
            ArgumentException failure = null;
            try
            {
                info.GetMessageFcnInfo(factorManager, "AverageConditional", "Sample", parameterTypes);
            }
            catch (ArgumentException ex)
            {
                failure = ex;
            }

            Assert.True(failure != null, "Did not throw an exception");
            Console.WriteLine("Correctly failed with exception: " + failure);
        }
示例#3
0
        /// <summary>
        /// Expects an <see cref="ArgumentException"/> when asking <c>ShiftAlpha.ToFactor&lt;&gt;</c> for an
        /// "AverageConditional" message to "Variable" with <c>double</c>-typed "factor" and "result".
        /// </summary>
        public void TypeConstraintFailureFactorInfo()
        {
            var toFactorMethod = new MethodReference(typeof(ShiftAlpha), "ToFactor<>").GetMethodInfo();
            FactorManager.FactorInfo info = FactorManager.GetFactorInfo(toFactorMethod);
            Console.WriteLine(info);

            var parameterTypes = new Dictionary<string, Type>
            {
                ["factor"] = typeof(double),
                ["result"] = typeof(double)
            };

            // Capture the expected failure; any other exception type propagates unchanged.
            ArgumentException failure = null;
            try
            {
                info.GetMessageFcnInfo(factorManager, "AverageConditional", "Variable", parameterTypes);
            }
            catch (ArgumentException ex)
            {
                failure = ex;
            }

            Assert.True(failure != null, "Did not throw an exception");
            Console.WriteLine("Correctly failed with exception: " + failure);
        }
示例#4
0
        //[Fact]
        // should take 140ms
        // turn on PrintStatistics in Binding.cs
        /// <summary>
        /// Resolves the "AverageConditional" message to "Uses" of <c>Factor.ReplicateWithMarginal&lt;&gt;</c>
        /// twice: once with an array-of-arrays result, and once with a per-element result plus a
        /// "resultIndex" parameter. The results are discarded; only the lookups are exercised.
        /// </summary>
        internal void SpeedTest()
        {
            var replicateMethod = new MethodReference(typeof(Factor), "ReplicateWithMarginal<>", typeof(bool[])).GetMethodInfo();
            FactorManager.FactorInfo info = FactorManager.GetFactorInfo(replicateMethod);

            Type elementArray = typeof(DistributionStructArray<Bernoulli, bool>);
            Type arrayOfArrays = typeof(DistributionRefArray<DistributionStructArray<Bernoulli, bool>, bool[]>);
            IDictionary<string, Type> parameterTypes = new Dictionary<string, Type>
            {
                ["Uses"] = arrayOfArrays,
                ["Def"] = elementArray,
                ["Marginal"] = elementArray,
                ["result"] = arrayOfArrays
            };

            // First lookup: whole-array result.
            info.GetMessageFcnInfo(factorManager, "AverageConditional", "Uses", parameterTypes);

            // Second lookup: single-element result addressed by resultIndex.
            parameterTypes["result"] = elementArray;
            parameterTypes["resultIndex"] = typeof(int);
            info.GetMessageFcnInfo(factorManager, "AverageConditional", "Uses", parameterTypes);
        }
示例#5
0
        /// <summary>
        /// Round-trip check: looks the message function up again from its own suffix, target
        /// parameter and parameter types, and asserts that the same method is found.
        /// </summary>
        /// <param name="fcninfo">The message function info to validate.</param>
        /// <param name="info">The factor info used to repeat the lookup.</param>
        internal void CheckMessageFcnInfo(MessageFcnInfo fcninfo, FactorManager.FactorInfo info)
        {
            Assert.True(fcninfo.Suffix != null);
            Assert.True(fcninfo.TargetParameter != null);

            // Copy the reported parameter types into a fresh dictionary for the second lookup.
            var reportedTypes = new Dictionary<string, Type>();
            foreach (KeyValuePair<string, Type> entry in fcninfo.GetParameterTypes())
            {
                reportedTypes[entry.Key] = entry.Value;
            }

            try
            {
                MessageFcnInfo lookedUpAgain = info.GetMessageFcnInfo(
                    factorManager, fcninfo.Suffix, fcninfo.TargetParameter, reportedTypes);
                Assert.Equal(lookedUpAgain.Method, fcninfo.Method);
            }
            catch (NotSupportedException)
            {
                // A NotSupportedException is acceptable only if the info itself carries a not-supported message.
                Assert.True(fcninfo.NotSupportedMessage != null);
            }
        }
示例#6
0
        /// <summary>
        /// Converts a method invocation. Infer calls are rebuilt with casts stripped from their
        /// arguments and their declaration marked <c>IsInferred</c>; for any other call, variables
        /// passed as out-arguments of a deterministic factor are marked <c>DerivedVariable</c>.
        /// </summary>
        /// <param name="imie">The method invocation to convert.</param>
        /// <returns>The rebuilt infer call, or the result of the base conversion.</returns>
        protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
        {
            CheckMethodArgumentCount(imie);

            if (!CodeRecognizer.IsInfer(imie))
            {
                IExpression converted = base.ConvertMethodInvoke(imie);
                var invoke = converted as IMethodInvokeExpression;
                if (invoke != null)
                {
                    foreach (IExpression arg in invoke.Arguments)
                    {
                        var outArg = arg as IAddressOutExpression;
                        if (outArg == null)
                        {
                            continue;
                        }
                        IVariableDeclaration outVar = Recognizer.GetVariableDeclaration(outArg.Expression);
                        if (outVar == null)
                        {
                            continue;
                        }
                        FactorManager.FactorInfo info = CodeRecognizer.GetFactorInfo(context, invoke);
                        bool shouldMarkDerived = info != null
                            && info.IsDeterministicFactor
                            && !context.InputAttributes.Has<DerivedVariable>(outVar);
                        if (shouldMarkDerived)
                        {
                            context.InputAttributes.Set(outVar, new DerivedVariable());
                        }
                    }
                }
                return converted;
            }

            inferCount++;
            object inferred = Recognizer.GetDeclaration(imie.Arguments[0]);
            if (inferred != null && !context.InputAttributes.Has<IsInferred>(inferred))
            {
                context.InputAttributes.Set(inferred, new IsInferred());
            }
            // the arguments must not be substituted for their values, so we don't call ConvertExpression
            var strippedArgs = imie.Arguments.Select(CodeRecognizer.RemoveCast);
            IMethodInvokeExpression rebuilt = Builder.MethodInvkExpr();
            rebuilt.Method = imie.Method;
            rebuilt.Arguments.AddRange(strippedArgs);
            context.InputAttributes.CopyObjectAttributesTo(imie, context.OutputAttributes, rebuilt);
            return rebuilt;
        }
 /// <summary>
 /// Attach DerivedVariable attributes to newly created variables
 /// </summary>
 /// <param name="iae">The assignment expression to convert.</param>
 /// <returns>The assignment produced by the base conversion.</returns>
 protected override IExpression ConvertAssign(IAssignExpression iae)
 {
     iae = (IAssignExpression)base.ConvertAssign(iae);
     if (iae.Expression is IMethodInvokeExpression imie)
     {
         IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(iae.Target);
         // Only variables that are not already marked need to be examined.
         if (ivd != null && !context.InputAttributes.Has<DerivedVariable>(ivd))
         {
             FactorManager.FactorInfo info = CodeRecognizer.GetFactorInfo(context, imie);
             // NOTE(review): other call sites of CodeRecognizer.GetFactorInfo guard against a null
             // result, so do the same here instead of dereferencing it unconditionally.
             if (info != null && info.IsDeterministicFactor)
             {
                 context.InputAttributes.Set(ivd, new DerivedVariable());
             }
         }
     }
     return(iae);
 }
        /// <summary>
        /// Resolves the "AverageLogarithm" message to "Def" of <c>Clone.Replicate&lt;&gt;</c> for two
        /// different "Uses" types, round-trip checks each result, and finally requests the
        /// dependency information of the last resolved method.
        /// </summary>
        public void ReplicateMessageFcnInfo()
        {
            var replicateMethod = new MethodReference(typeof(Clone), "Replicate<>", typeof(bool)).GetMethodInfo();
            FactorManager.FactorInfo info = FactorManager.GetFactorInfo(replicateMethod);

            var parameterTypes = new Dictionary<string, Type>();
            parameterTypes["Uses"] = typeof(DistributionArray<Bernoulli>);
            parameterTypes["Def"] = typeof(Bernoulli);
            parameterTypes["result"] = typeof(Bernoulli); //typeof(DistributionArray<Bernoulli>);

            // Same lookup and check for each "Uses" representation, in the original order.
            MessageFcnInfo fcninfo = null;
            Type[] usesTypes = { typeof(DistributionArray<Bernoulli>), typeof(Bernoulli[]) };
            foreach (Type usesType in usesTypes)
            {
                parameterTypes["Uses"] = usesType;
                fcninfo = info.GetMessageFcnInfo(factorManager, "AverageLogarithm", "Def", parameterTypes);
                CheckMessageFcnInfo(fcninfo, info);
            }

            // The value is not inspected; the call itself is what is being exercised.
            DependencyInformation depInfo = FactorManager.GetDependencyInfo(fcninfo.Method);
        }
示例#9
0
        /// <summary>
        /// Exercises message-function lookup for <c>Factor.IsPositive</c>: "AverageConditional" to "x"
        /// must be supported, "AverageLogarithm" to "x" must throw <see cref="NotSupportedException"/>
        /// and be listed with a not-supported message, and "LogAverageFactor" must resolve.
        /// </summary>
        public void IsPositiveFactorInfo()
        {
            var isPositiveFactor = new Func<double, bool>(Factor.IsPositive);
            FactorManager.FactorInfo info = FactorManager.GetFactorInfo(isPositiveFactor);
            Console.WriteLine(info);
            Assert.True(info.IsDeterministicFactor);

            IDictionary<string, Type> parameterTypes = new Dictionary<string, Type>
            {
                ["isPositive"] = typeof(bool),
                ["x"] = typeof(Gaussian)
            };

            MessageFcnInfo fcninfo = info.GetMessageFcnInfo(factorManager, "AverageConditional", "x", parameterTypes);
            Assert.True(fcninfo.NotSupportedMessage == null);
            CheckMessageFcnInfo(fcninfo, info);

            DependencyInformation depInfo = FactorManager.GetDependencyInfo(fcninfo.Method);
            Console.WriteLine("Dependencies:");
            Console.WriteLine(StringUtil.ToString(depInfo.Dependencies));
            Console.WriteLine("Requirements:");
            Console.WriteLine(StringUtil.ToString(depInfo.Requirements));

            // Direct lookup of the unsupported operator must throw.
            Assert.Throws<NotSupportedException>(() =>
            {
                info.GetMessageFcnInfo(factorManager, "AverageLogarithm", "x", parameterTypes);
            });

            // Enumeration must still list the operator to "x", flagged as not supported.
            bool sawUnsupportedX = false;
            foreach (MessageFcnInfo candidate in info.GetMessageFcnInfos("AverageLogarithm", null, null))
            {
                CheckMessageFcnInfo(candidate, info);
                if (!candidate.TargetParameter.Equals("x"))
                {
                    continue;
                }
                Assert.True(candidate.NotSupportedMessage != null);
                sawUnsupportedX = true;
            }
            Assert.True(sawUnsupportedX);

            MessageFcnInfo logAverageFactor = info.GetMessageFcnInfo(factorManager, "LogAverageFactor", "", parameterTypes);
            CheckMessageFcnInfo(logAverageFactor, info);
        }
示例#10
0
        /// <summary>
        /// Gets the <c>MessageFcnInfo</c> for a method invocation. Tried in order: the info attached
        /// to the invoked <c>MethodInfo</c>, the info derived from a <c>FactorInfo</c> attached to the
        /// expression, and finally a new info built from the method's parameters and dependency info.
        /// </summary>
        /// <param name="context">Transform context whose input attributes are consulted.</param>
        /// <param name="imie">The method invocation expression.</param>
        /// <returns>The message function info for the invoked method.</returns>
        internal static MessageFcnInfo GetMessageFcnInfo(BasicTransformContext context, IMethodInvokeExpression imie)
        {
            var method = (MethodInfo)imie.Method.Method.MethodInfo;

            MessageFcnInfo attached = context.InputAttributes.Get<MessageFcnInfo>(method);
            if (attached != null)
            {
                return attached;
            }

            FactorManager.FactorInfo factorInfo = context.InputAttributes.Get<FactorManager.FactorInfo>(imie);
            if (factorInfo != null)
            {
                return factorInfo.GetMessageFcnInfoFromFactor();
            }

            // The freshly built info is not written back to the input attributes (the store below is disabled).
            //context.InputAttributes.Set(method, fcnInfo);
            return new MessageFcnInfo(method, method.GetParameters())
            {
                DependencyInfo = FactorManager.GetDependencyInfo(method)
            };
        }
示例#11
0
        /// <summary>
        /// Emits a <c>message_doc</c> XML element documenting one message operator: its summary,
        /// one description per method parameter, then returns, remarks and exception sections.
        /// </summary>
        /// <param name="writer">The XML writer that receives the element.</param>
        /// <param name="factorInfo">The factor the message operator belongs to.</param>
        /// <param name="messageFunctionInfo">The message operator being documented.</param>
        private static void WriteMessageFunctionDocumentation(
            XmlWriter writer, FactorManager.FactorInfo factorInfo, MessageFcnInfo messageFunctionInfo)
        {
            // The element's "name" attribute is the quoted short signature without parameter names.
            string signature = StringUtil.MethodSignatureToString(
                messageFunctionInfo.Method, useFullName: false, omitParameterNames: true);
            string quotedSignature = QuoteCodeElementName(signature);

            writer.WriteStartElement("message_doc");
            writer.WriteAttributeString("name", quotedSignature);

            WriteMessageOperatorSummary(writer, messageFunctionInfo);

            ParameterInfo[] operatorParameters = messageFunctionInfo.Method.GetParameters();
            foreach (ParameterInfo operatorParameter in operatorParameters)
            {
                WriteMessageOperatorParameterDescription(writer, factorInfo, messageFunctionInfo, operatorParameter);
            }

            WriteMessageOperatorReturns(writer, factorInfo, messageFunctionInfo);
            WriteMessageOperatorRemarks(writer, factorInfo, messageFunctionInfo);
            WriteMessageOperatorExceptionSpec(writer, factorInfo, messageFunctionInfo);

            writer.WriteEndElement();
        }
        /// <summary>
        /// Checks the factor info of <c>Constrain.EqualRandom&lt;,&gt;</c> (two parameters), the
        /// "AverageConditional" message to "value", and round-trip checks every message function.
        /// </summary>
        public void ConstrainEqualRandomFactorInfo()
        {
            var equalRandomMethod =
                new MethodReference(typeof(Constrain), "EqualRandom<,>", typeof(bool), typeof(Bernoulli)).GetMethodInfo();
            FactorManager.FactorInfo info = FactorManager.GetFactorInfo(equalRandomMethod);
            Assert.Equal(2, info.ParameterNames.Count);

            var parameterTypes = new Dictionary<string, Type>();
            parameterTypes["value"] = typeof(Bernoulli);
            parameterTypes["dist"] = typeof(Bernoulli);
            parameterTypes["result"] = typeof(Bernoulli);

            MessageFcnInfo toValue = info.GetMessageFcnInfo(factorManager, "AverageConditional", "value", parameterTypes);
            Assert.False(toValue.PassResult);
            Assert.False(toValue.PassResultIndex);
            Assert.Equal(1, toValue.Dependencies.Count);
            Assert.Equal(1, toValue.Requirements.Count);

            // Flip to true to print every message function while checking it.
            bool verbose = false;
            if (verbose)
            {
                Console.WriteLine("All MessageFcnInfos:");
            }

            int checkedCount = 0;
            foreach (MessageFcnInfo candidate in info.GetMessageFcnInfos())
            {
                if (verbose)
                {
                    Console.WriteLine(candidate);
                }
                CheckMessageFcnInfo(candidate, info);
                checkedCount++;
            }
            //Assert.Equal(4, checkedCount);
        }
        /// <summary>
        /// Checks message-function lookup for <c>Clone.UsesEqualDef&lt;&gt;</c>: the "AverageLogarithm"
        /// message to "Uses" is resolved three times with identical expectations, then the message
        /// to "Def" is resolved without the "resultIndex" parameter.
        /// </summary>
        public void UsesEqualDefFactorInfo()
        {
            var usesEqualDefMethod = new MethodReference(typeof(Clone), "UsesEqualDef<>", typeof(bool)).GetMethodInfo();
            FactorManager.FactorInfo info = FactorManager.GetFactorInfo(usesEqualDefMethod);
            Assert.True(!info.IsDeterministicFactor);

            var parameterTypes = new Dictionary<string, Type>();
            parameterTypes["Uses"] = typeof(Bernoulli[]);
            parameterTypes["Def"] = typeof(Bernoulli);
            parameterTypes["resultIndex"] = typeof(int);
            parameterTypes["result"] = typeof(Bernoulli);

            // Repeat the same lookup; every repetition must satisfy the same expectations.
            const int repetitions = 3;
            for (int attempt = 0; attempt < repetitions; attempt++)
            {
                MessageFcnInfo toUses = info.GetMessageFcnInfo(factorManager, "AverageLogarithm", "Uses", parameterTypes);
                Assert.True(toUses.PassResult);
                Assert.True(toUses.PassResultIndex);
                Assert.Equal(2, toUses.Dependencies.Count);
                Assert.Equal(1, toUses.Requirements.Count);
                Assert.True(toUses.SkipIfAllUniform);
                Assert.Equal(1, toUses.Triggers.Count);

                // Dependency info is requested on the first repetition only; the value is not inspected.
                if (attempt == 0)
                {
                    DependencyInformation usesDependencies = FactorManager.GetDependencyInfo(toUses.Method);
                }
            }

            parameterTypes.Remove("resultIndex");
            MessageFcnInfo toDef = info.GetMessageFcnInfo(factorManager, "AverageLogarithm", "Def", parameterTypes);
            Assert.True(toDef.SkipIfAllUniform);
            Assert.Equal(1, toDef.Triggers.Count);

            DependencyInformation defDependencies = FactorManager.GetDependencyInfo(toDef.Method);
        }
示例#14
0
        /// <summary>
        /// Resolves a message function, temporarily adding array-factory entries to
        /// <paramref name="parameterTypes"/> for any factory-typed parameters of the resolved method
        /// that the dictionary lacks. The lookup repeats until a pass adds no entries; the
        /// temporary entries are removed before returning.
        /// </summary>
        /// <param name="info">The factor whose message function is requested.</param>
        /// <param name="methodSuffix">The message operator suffix.</param>
        /// <param name="targetParameter">The parameter the message is sent to.</param>
        /// <param name="parameterTypes">Known parameter types; modified during the call and restored on normal return.</param>
        /// <returns>The message function info from the final lookup.</returns>
        private MessageFcnInfo GetMessageFcnInfo(FactorManager.FactorInfo info, string methodSuffix, string targetParameter, IDictionary<string, Type> parameterTypes)
        {
            var temporaryFactoryKeys = new List<string>();
            MessageFcnInfo resolved;
            bool addedFactoryParameter;

            do
            {
                resolved = info.GetMessageFcnInfo(factorManager, methodSuffix, targetParameter, parameterTypes);
                addedFactoryParameter = false;

                foreach (ParameterInfo parameter in resolved.Method.GetParameters())
                {
                    // Skip parameters that are already typed or are not factories.
                    if (parameterTypes.ContainsKey(parameter.Name) || !IsFactoryType(parameter.ParameterType))
                    {
                        continue;
                    }

                    addedFactoryParameter = true;
                    // The factory's first generic argument is used as the item type of a depth-1 distribution array.
                    Type itemType = parameter.ParameterType.GetGenericArguments()[0];
                    Type arrayType = Distribution.MakeDistributionArrayType(itemType, 1);
                    parameterTypes[parameter.Name] = typeof(IArrayFactory<,>).MakeGenericType(itemType, arrayType);
                    temporaryFactoryKeys.Add(parameter.Name);
                }
            }
            while (addedFactoryParameter);

            foreach (string temporaryKey in temporaryFactoryKeys)
            {
                parameterTypes.Remove(temporaryKey);
            }
            return resolved;
        }
示例#15
0
        /// <summary>
        /// Exercises message-function lookup for <c>Factor.Difference</c>: a successful
        /// "AverageConditional" lookup, two lookups that must fail with <see cref="ArgumentException"/>
        /// (unknown suffix, wrong result type), and enumeration of message functions by target and suffix.
        /// </summary>
        public void DifferenceFactorInfo()
        {
            var differenceMethod = new MethodReference(typeof(Factor), "Difference").GetMethodInfo();
            FactorManager.FactorInfo info = FactorManager.GetFactorInfo(differenceMethod);
            Console.WriteLine(info);
            Assert.True(info.IsDeterministicFactor);

            IDictionary<string, Type> parameterTypes = new Dictionary<string, Type>
            {
                ["Difference"] = typeof(Gaussian),
                ["A"] = typeof(Gaussian),
                ["B"] = typeof(Gaussian)
                //["result"] = typeof(Gaussian)
            };
            MessageFcnInfo fcninfo = info.GetMessageFcnInfo(factorManager, "AverageConditional", "Difference", parameterTypes);

            Assert.False(fcninfo.PassResult);
            Assert.False(fcninfo.PassResultIndex);
            Assert.Equal(2, fcninfo.Dependencies.Count);
            Assert.Equal(2, fcninfo.Requirements.Count);
            Console.WriteLine(fcninfo);

            Console.WriteLine("Parameter types:");
            Console.WriteLine(StringUtil.CollectionToString(fcninfo.GetParameterTypes(), ","));

            // Unknown suffix: must fail, and the message must mention MissingMethodException.
            Console.WriteLine();
            ArgumentException unknownSuffixFailure = null;
            try
            {
                info.GetMessageFcnInfo(factorManager, "Rubbish", "Difference", parameterTypes);
            }
            catch (ArgumentException ex)
            {
                unknownSuffixFailure = ex;
            }
            Assert.True(unknownSuffixFailure != null, "Did not throw an exception");
            Assert.True(
                unknownSuffixFailure.Message.Contains("MissingMethodException"),
                "Correctly threw exception, but with wrong message");
            Console.WriteLine("Different exception: " + unknownSuffixFailure);

            // Wrong result type: must fail as well.
            Console.WriteLine();
            parameterTypes["result"] = typeof(double);
            ArgumentException wrongResultFailure = null;
            try
            {
                info.GetMessageFcnInfo(factorManager, "AverageConditional", "Difference", parameterTypes);
            }
            catch (ArgumentException ex)
            {
                wrongResultFailure = ex;
            }
            Assert.True(wrongResultFailure != null, "Did not throw an exception");
            Console.WriteLine("Correctly failed with exception: " + wrongResultFailure);

            Console.WriteLine();
            Console.WriteLine("All messages to A:");
            int messagesToA = 0;
            foreach (MessageFcnInfo candidate in info.GetMessageFcnInfos(null, "A", null))
            {
                Console.WriteLine(candidate);
                CheckMessageFcnInfo(candidate, info);
                messagesToA++;
            }
            //Assert.Equal(8, messagesToA);

            Console.WriteLine();
            Console.WriteLine("All messages to Difference:");
            int messagesToDifference = 0;
            foreach (MessageFcnInfo candidate in info.GetMessageFcnInfos(null, "Difference", null))
            {
                Console.WriteLine(candidate);
                CheckMessageFcnInfo(candidate, info);
                messagesToDifference++;
            }
            Assert.Equal(8, messagesToDifference);

            Console.WriteLine();
            Console.WriteLine("All AverageConditionals:");
            int averageConditionals = 0;
            foreach (MessageFcnInfo candidate in info.GetMessageFcnInfos("AverageConditional", null, null))
            {
                Console.WriteLine(candidate);
                CheckMessageFcnInfo(candidate, info);
                averageConditionals++;
            }
            Assert.Equal(12, averageConditionals);

            Console.WriteLine();
            Console.WriteLine("All MessageFcnInfos:");
            int allMessageFunctions = 0;
            foreach (MessageFcnInfo candidate in info.GetMessageFcnInfos())
            {
                Console.WriteLine(candidate);
                CheckMessageFcnInfo(candidate, info);
                allMessageFunctions++;
            }
            //Assert.Equal(26, allMessageFunctions);
        }
示例#16
0
        /// <summary>
        /// Records whether a definition makes <paramref name="targetVar"/> a point mass. A point-mass
        /// definition sets <c>ForwardPointMass</c> on the variable and, for the replicate/get-items
        /// family of calls, remembers the call for later processing. Any other definition, except an
        /// array creation, marks the variable as non-point-mass and clears earlier point-mass bookkeeping.
        /// </summary>
        /// <param name="expr">The defining (right-hand side) expression.</param>
        /// <param name="targetVar">The variable being defined.</param>
        /// <param name="isLhs">Not read by this method.</param>
        protected void ProcessDefinition(IExpression expr, IVariableDeclaration targetVar, bool isLhs)
        {
            bool targetIsPointMass = false;
            var invoke = expr as IMethodInvokeExpression;

            if (invoke != null)
            {
                // TODO: consider using a method attribute for this
                if (Recognizer.IsStaticGenericMethod(invoke, new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint)))
                {
                    targetIsPointMass = true;
                }
                else
                {
                    FactorManager.FactorInfo info = CodeRecognizer.GetFactorInfo(context, invoke);
                    if (info.IsDeterministicFactor)
                    {
                        // Point mass if the "returned in all elements" argument is one, or if every argument is.
                        targetIsPointMass =
                            (info.ReturnedInAllElementsParameterIndex != -1 &&
                             ArgumentIsPointMass(invoke.Arguments[info.ReturnedInAllElementsParameterIndex])) ||
                            invoke.Arguments.All(ArgumentIsPointMass);
                    }
                }

                if (targetIsPointMass)
                {
                    // do this immediately so all uses are updated
                    if (!context.InputAttributes.Has<ForwardPointMass>(targetVar))
                    {
                        context.OutputAttributes.Set(targetVar, new ForwardPointMass());
                    }
                    // the rest is done later
                    if (!variablesDefinedPointMass.TryGetValue(targetVar, out List<IMethodInvokeExpression> definitions))
                    {
                        definitions = new List<IMethodInvokeExpression>();
                        variablesDefinedPointMass.Add(targetVar, definitions);
                    }
                    if (IsReplicateOrGetItems(invoke))
                    {
                        definitions.Add(invoke);
                    }
                }
            }

            if (!targetIsPointMass && !(expr is IArrayCreateExpression))
            {
                variablesDefinedNonPointMass.Add(targetVar);
                if (variablesDefinedPointMass.ContainsKey(targetVar))
                {
                    variablesDefinedPointMass.Remove(targetVar);
                    context.OutputAttributes.Remove<ForwardPointMass>(targetVar);
                }
            }

            // True when the call is Clone.Replicate or one of the Collection.Get*Items* methods.
            // this code needs to be synchronized with MessageTransform.ConvertMethodInvoke
            bool IsReplicateOrGetItems(IMethodInvokeExpression call)
            {
                return Recognizer.IsStaticGenericMethod(call, new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate)) ||
                       Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems)) ||
                       Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems)) ||
                       Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems)) ||
                       Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromJagged)) ||
                       Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged)) ||
                       Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<IReadOnlyList<int>>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged));
            }

            // Non-stochastic arguments and out-arguments count as point masses; a stochastic
            // argument counts only if its variable carries ForwardPointMass.
            bool ArgumentIsPointMass(IExpression arg)
            {
                if (!CodeRecognizer.IsStochastic(context, arg) || arg is IAddressOutExpression)
                {
                    return true;
                }
                IVariableDeclaration argVar = Recognizer.GetVariableDeclaration(arg);
                return argVar != null && context.InputAttributes.Has<ForwardPointMass>(argVar);
            }
        }
示例#17
0
        /// <summary>
        /// Converts an assignment whose target resolves to a variable or parameter declaration.
        /// For array-indexed targets the index variables are recorded on the variable information.
        /// After the base conversion, variable targets may be marked <c>DerivedVariable</c> or have
        /// literal values bound for inlining, while parameter targets mark the enclosing statement
        /// as a constraint. Algorithm and priority attributes are then carried to the right-hand side.
        /// </summary>
        /// <param name="iae">The assignment expression to convert.</param>
        /// <returns>The converted assignment (the plain base result when the target has no declaration).</returns>
        protected override IExpression ConvertAssign(IAssignExpression iae)
        {
            IVariableDeclaration targetVariable = Recognizer.GetVariableDeclaration(iae.Target);
            IParameterDeclaration targetParameter =
                targetVariable == null ? Recognizer.GetParameterDeclaration(iae.Target) : null;
            if (targetVariable == null && targetParameter == null)
            {
                return base.ConvertAssign(iae);
            }
            object decl = (object)targetVariable ?? targetParameter;

            if (iae.Target is IArrayIndexerExpression)
            {
                // Gather index variables from the left-hand side of the assignment
                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, decl);
                try
                {
                    var indexVariables = new List<IVariableDeclaration[]>();
                    Recognizer.AddIndexers(context, indexVariables, iae.Target);
                    int indexingDepth = Recognizer.GetIndexingDepth(iae.Target);
                    // if this statement is actually a constraint, then we don't need to enforce matching of index variables
                    bool isConstraint = context.InputAttributes.Has<Models.Constraint>(context.FindAncestor<IStatement>());
                    for (int level = 0; level < indexingDepth; level++)
                    {
                        varInfo.SetIndexVariablesAtDepth(level, indexVariables[level], allowMismatch: isConstraint);
                    }
                }
                catch (Exception ex)
                {
                    Error(ex.Message, ex);
                }
            }

            var converted = (IAssignExpression)base.ConvertAssign(iae);

            if (targetParameter != null)
            {
                // assignment to a method parameter
                IStatement enclosingStatement = context.FindAncestor<IStatement>();
                if (!context.InputAttributes.Has<Models.Constraint>(enclosingStatement))
                {
                    // mark this statement as a constraint
                    context.OutputAttributes.Set(enclosingStatement, new Models.Constraint());
                }
            }
            else
            {
                // assignment to a local variable
                if (converted.Expression is IMethodInvokeExpression imie)
                {
                    // this unfortunately duplicates some of the work done by SetStoch and IsStoch.
                    FactorManager.FactorInfo info = CodeRecognizer.GetFactorInfo(context, imie);
                    bool shouldMarkDerived = info != null
                        && info.IsDeterministicFactor
                        && !context.InputAttributes.Has<DerivedVariable>(targetVariable);
                    if (shouldMarkDerived)
                    {
                        context.InputAttributes.Set(targetVariable, new DerivedVariable());
                    }
                }
                // A literal assigned outside a loop initializer, with an inlinable type, is bound for inlining.
                if (converted.Expression is ILiteralExpression
                    && Recognizer.GetAncestorIndexOfLoopBeingInitialized(context) == -1
                    && Quoter.ShouldInlineType(converted.Expression.GetExpressionType()))
                {
                    // inline all future occurrences of this variable with the rhs expression
                    conditionContext.Add(new ConditionBinding(converted.Target, converted.Expression));
                }
            }

            // a FactorAlgorithm attribute on a variable turns into an Algorithm attribute on its right hand side.
            var factorAlgorithm = context.InputAttributes.Get<FactorAlgorithm>(decl);
            if (factorAlgorithm != null)
            {
                context.OutputAttributes.Set(converted.Expression, new Algorithm(factorAlgorithm.algorithm));
            }
            context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(decl, context.OutputAttributes, converted.Expression);
            return converted;
        }
示例#18
0
        /// <summary>
        /// Converts a method invocation, giving special treatment to attribute-setting calls,
        /// calls recognized by <c>CodeRecognizer.IsInfer</c>, and boolean factors with literal arguments.
        /// </summary>
        /// <param name="imie">The method invocation to convert.</param>
        /// <returns>
        /// <c>null</c> for <c>Attrib.Var</c> and <c>Attrib.InitialiseTo</c> calls (they are recorded as attributes instead);
        /// a rebuilt invocation with casts removed from its arguments for infer calls;
        /// a simplified expression for <c>And</c>/<c>Or</c>/<c>Not</c> factors when their literal arguments allow it;
        /// otherwise the result of the base conversion.
        /// </returns>
        protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
        {
            // Attrib.Var(variable, attribute): attach the attribute to the variable and drop the call.
            if (Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, ICompilerAttribute, PlaceHolder>(Attrib.Var)))
            {
                var variableRef = imie.Arguments[0] as IVariableReferenceExpression;
                IVariableDeclaration attributedVariable = variableRef.Variable.Resolve();
                AddAttribute(attributedVariable, CodeRecognizer.RemoveCast(imie.Arguments[1]));
                return null;
            }

            // Attrib.InitialiseTo(variable, initialiser): record the initialiser as an output attribute and drop the call.
            if (Recognizer.IsStaticMethod(imie, new Action<object, object>(Attrib.InitialiseTo)))
            {
                var variableRef = CodeRecognizer.RemoveCast(imie.Arguments[0]) as IVariableReferenceExpression;
                IVariableDeclaration initialisedVariable = variableRef.Variable.Resolve();
                context.OutputAttributes.Set(initialisedVariable, new InitialiseTo(imie.Arguments[1]));
                return null;
            }

            if (CodeRecognizer.IsInfer(imie))
            {
                inferCount++;
                object inferredDeclaration = Recognizer.GetDeclaration(imie.Arguments[0]);
                bool alreadyMarked = inferredDeclaration == null || context.InputAttributes.Has<IsInferred>(inferredDeclaration);
                if (!alreadyMarked)
                {
                    context.InputAttributes.Set(inferredDeclaration, new IsInferred());
                }

                // the arguments must not be substituted for their values, so we don't call ConvertExpression
                var uncastArguments = new List<IExpression>();
                foreach (var argument in imie.Arguments)
                {
                    uncastArguments.Add(CodeRecognizer.RemoveCast(argument));
                }

                IMethodInvokeExpression rebuiltCall = Builder.MethodInvkExpr();
                rebuiltCall.Method = imie.Method;
                rebuiltCall.Arguments.AddRange(uncastArguments);
                context.InputAttributes.CopyObjectAttributesTo(imie, context.OutputAttributes, rebuiltCall);
                return rebuiltCall;
            }

            IExpression converted = base.ConvertMethodInvoke(imie);
            var call = converted as IMethodInvokeExpression;
            if (call == null)
            {
                return converted;
            }

            bool isAnd = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Factors.Factor.And));
            bool isOr = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Factors.Factor.Or));
            if (call.Arguments.Any(arg => arg is ILiteralExpression))
            {
                if (isAnd || isOr)
                {
                    // And is decided by a false literal, Or by a true literal.
                    bool decidingValue = !isAnd;
                    bool hasDecidingLiteral = call.Arguments.Any(
                        arg => arg is ILiteralExpression && ((ILiteralExpression)arg).Value.Equals(decidingValue));
                    if (hasDecidingLiteral)
                    {
                        return Builder.LiteralExpr(decidingValue);
                    }

                    // any remaining literals must be the opposite value, and therefore can be ignored.
                    var nonLiteralArguments = call.Arguments.Where(arg => !(arg is ILiteralExpression)).ToList();
                    if (nonLiteralArguments.Count == 1)
                    {
                        return nonLiteralArguments[0];
                    }

                    return Builder.LiteralExpr(!decidingValue);
                }

                // Not(literal) is evaluated directly.
                if (Recognizer.IsStaticMethod(converted, new Func<bool, bool>(Factors.Factor.Not))
                    && call.Arguments.All(arg => arg is ILiteralExpression))
                {
                    return Builder.LiteralExpr(evaluator.Evaluate(call));
                }
            }

            // Variables passed as 'out' arguments to a deterministic factor are marked as derived.
            foreach (IExpression argument in call.Arguments)
            {
                var outArgument = argument as IAddressOutExpression;
                if (outArgument == null)
                {
                    continue;
                }

                IVariableDeclaration outVariable = Recognizer.GetVariableDeclaration(outArgument.Expression);
                if (outVariable == null)
                {
                    continue;
                }

                FactorManager.FactorInfo info = CodeRecognizer.GetFactorInfo(context, call);
                if (info != null && info.IsDeterministicFactor && !context.InputAttributes.Has<DerivedVariable>(outVariable))
                {
                    context.InputAttributes.Set(outVariable, new DerivedVariable());
                }
            }

            return converted;
        }
示例#19
0
 /// <summary>
 /// Creates a <see cref="FactorInfoWrapper"/> around the given factor info.
 /// </summary>
 /// <param name="factorInfo">The factor info to wrap. Expected to be non-null; this is checked by a debug assertion only.</param>
 public FactorInfoWrapper(FactorManager.FactorInfo factorInfo)
 {
     bool hasFactorInfo = factorInfo != null;
     Debug.Assert(hasFactorInfo, "The given factor info cannot be null.");

     this.FactorInfo = factorInfo;
 }
示例#20
0
        /// <summary>
        /// Emits the <c>param</c> element that documents one parameter of a message operator.
        /// </summary>
        /// <param name="writer">The XML writer that receives the output.</param>
        /// <param name="factorInfo">The factor which the message operator belongs to.</param>
        /// <param name="messageFunctionInfo">The message operator whose parameter is documented.</param>
        /// <param name="parameter">The parameter to describe.</param>
        private static void WriteMessageOperatorParameterDescription(
            XmlWriter writer, FactorManager.FactorInfo factorInfo, MessageFcnInfo messageFunctionInfo, ParameterInfo parameter)
        {
            string name = parameter.Name;

            writer.WriteStartElement("param");
            writer.WriteAttributeString("name", name);

            FactorEdge edge;
            if (name == "result")
            {
                writer.WriteString("Modified to contain the outgoing message.");
            }
            else if (name == "resultIndex")
            {
                writer.WriteString("Index of the ");
                writer.WriteElementFormatString("c", messageFunctionInfo.TargetParameter);
                writer.WriteString(" for which a message is desired.");
            }
            else if (messageFunctionInfo.factorEdgeOfParameter.TryGetValue(name, out edge))
            {
                string field = edge.ToString();
                if (edge.IsOutgoingMessage)
                {
                    string qualifier = parameter.IsDefined(typeof(FreshAttribute), false) ? "Outgoing" : "Previous outgoing";
                    writer.WriteFormatString("{0} message to ", qualifier);

                    // The first three characters of the parameter name are dropped to obtain the edge name.
                    writer.WriteElementFormatString("c", name.Substring(3));
                    writer.WriteString(".");
                }
                else if (!factorInfo.ParameterTypes.ContainsKey(field))
                {
                    // Not an argument of the factor, so it is described as a buffer.
                    writer.WriteString("Buffer ");
                    writer.WriteElementFormatString("c", name);
                    writer.WriteString(".");
                }
                else if (parameter.ParameterType == factorInfo.ParameterTypes[field])
                {
                    // Same type as the factor argument: the operator receives the value itself.
                    writer.WriteString("Constant value for ");
                    writer.WriteElementFormatString("c", field);
                    writer.WriteString(".");
                }
                else
                {
                    writer.WriteString("Incoming message from ");
                    writer.WriteElementFormatString("c", field);
                    writer.WriteString(".");

                    // One sentence is added for every requirement that refers to this field.
                    foreach (FactorEdge requirement in messageFunctionInfo.Requirements)
                    {
                        if (requirement.ParameterName != field)
                        {
                            continue;
                        }

                        writer.WriteString(" Must be a proper distribution. If ");
                        if (Util.IsIList(factorInfo.ParameterTypes[field]))
                        {
                            // For list-valued arguments, narrow the requirement by the dependencies on the same field.
                            FactorEdge range = requirement;
                            foreach (FactorEdge dependency in messageFunctionInfo.Dependencies)
                            {
                                if (dependency.ParameterName == field)
                                {
                                    range = range.Intersect(dependency);
                                }
                            }

                            if (range.MinCount == 1)
                            {
                                writer.WriteString(range.ContainsIndex ? "the element at resultIndex is " : "all elements are ");
                            }
                            else if (range.ContainsAllOthers)
                            {
                                writer.WriteString(range.ContainsIndex ? "any element is " : "any element besides resultIndex is ");
                            }
                        }

                        writer.WriteString("uniform, the result will be uniform.");
                    }
                }
            }

            // TODO: skip methods which aren't message operators
            // (a parameter without a factor edge currently gets an empty description).
            writer.WriteEndElement();
        }
示例#21
0
        /// <summary>
        /// Emits the <c>remarks</c> element of the XML documentation of a message operator.
        /// </summary>
        /// <param name="writer">The XML writer that receives the output.</param>
        /// <param name="factorInfo">The factor which the message operator belongs to.</param>
        /// <param name="messageFunctionInfo">The message operator being documented.</param>
        private static void WriteMessageOperatorRemarks(
            XmlWriter writer, FactorManager.FactorInfo factorInfo, MessageFcnInfo messageFunctionInfo)
        {
            // A factor method returning void is treated as a constraint: it has no child argument.
            bool isConstraint = factorInfo.Method.ReturnType == typeof(void);

            // must use the parameter names of the factor, not the operator method, because the operator method may not have parameters for all of the factor edges
            string allArguments = StringUtil.CollectionToString(factorInfo.ParameterNames, ",");
            string child = isConstraint ? string.Empty : factorInfo.ParameterNames[0];
            string target = messageFunctionInfo.TargetParameter;

            bool childIsRandom = false;
            string randomFields = string.Empty;
            string randomFieldsExceptTarget = string.Empty;
            string randomFieldsExceptTargetAndChild = string.Empty;

            // Collect the factor arguments that the operator receives with a type different from the factor's own.
            foreach (ParameterInfo parameter in messageFunctionInfo.Method.GetParameters())
            {
                if (parameter.Name == "result" || parameter.Name == "resultIndex")
                {
                    continue;
                }

                FactorEdge edge;
                if (!messageFunctionInfo.factorEdgeOfParameter.TryGetValue(parameter.Name, out edge))
                {
                    continue;
                }

                string field = edge.ToString();
                bool isRandom = !edge.IsOutgoingMessage
                    && factorInfo.ParameterTypes.ContainsKey(field)
                    && parameter.ParameterType != factorInfo.ParameterTypes[field];
                if (!isRandom)
                {
                    continue;
                }

                if (!isConstraint && factorInfo.ParameterNames[0] == field)
                {
                    childIsRandom = true;
                }

                randomFields = AppendRandomField(randomFields, field);
                if (field == target)
                {
                    continue;
                }

                randomFieldsExceptTarget = AppendRandomField(randomFieldsExceptTarget, field);
                if (factorInfo.ParameterNames[0] != field)
                {
                    randomFieldsExceptTargetAndChild = AppendRandomField(randomFieldsExceptTargetAndChild, field);
                }
            }

            writer.WriteStartElement("remarks");
            writer.WriteStartElement("para");

            bool noRandomFields = randomFields == string.Empty;
            switch (messageFunctionInfo.Method.Name)
            {
                case "AverageLogFactor":
                    if (noRandomFields)
                    {
                        writer.WriteString("The formula for the result is ");
                        writer.WriteElementFormatString("c", "log(factor({0}))", allArguments);
                    }
                    else if (factorInfo.IsDeterministicFactor)
                    {
                        writer.WriteString("In Variational Message Passing, the evidence contribution of a deterministic factor is zero");
                    }
                    else
                    {
                        writer.WriteString("The formula for the result is ");
                        writer.WriteElementFormatString("c", "sum_({0}) p({0}) log(factor({1}))", randomFields, allArguments);
                    }

                    writer.WriteString(". Adding up these values across all factors and variables gives the log-evidence estimate for VMP.");
                    break;

                case "LogAverageFactor":
                    writer.WriteString("The formula for the result is ");
                    if (noRandomFields)
                    {
                        writer.WriteElementFormatString("c", "log(factor({0}))", allArguments);
                    }
                    else
                    {
                        writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}))", randomFields, allArguments);
                    }

                    writer.WriteString(".");
                    break;

                case "LogEvidenceRatio":
                    writer.WriteString("The formula for the result is ");
                    if (noRandomFields)
                    {
                        writer.WriteElementFormatString("c", "log(factor({0}))", allArguments);
                    }
                    else if (childIsRandom)
                    {
                        writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}) / sum_{2} p({2}) messageTo({2}))", randomFields, allArguments, child);
                    }
                    else
                    {
                        writer.WriteElementFormatString("c", "log(sum_({0}) p({0}) factor({1}))", randomFields, allArguments);
                    }

                    writer.WriteString(". Adding up these values across all factors and variables gives the log-evidence estimate for EP.");
                    break;

                default:
                {
                    // Operators that compute an outgoing message are distinguished by their suffix.
                    string suffix = messageFunctionInfo.Suffix;
                    bool isAverageConditional = suffix == "AverageConditional";
                    bool isAverageLogarithm = suffix == "AverageLogarithm";
                    bool nothingToAverageOver = randomFieldsExceptTarget == string.Empty;

                    if (suffix == "Conditional" || ((isAverageConditional || isAverageLogarithm) && nothingToAverageOver))
                    {
                        WriteConditionalMessageRemark(writer, target);
                    }
                    else if (isAverageConditional)
                    {
                        writer.WriteString("The outgoing message is a distribution matching the moments of ");
                        writer.WriteElementFormatString("c", target);
                        writer.WriteString(" as the random arguments are varied. The formula is ");
                        writer.WriteElementFormatString("c", "proj[p({2}) sum_({1}) p({1}) factor({0})]/p({2})", allArguments, randomFieldsExceptTarget, target);
                        writer.WriteString(".");
                    }
                    else if (isAverageLogarithm)
                    {
                        if (!factorInfo.IsDeterministicFactor)
                        {
                            writer.WriteString(
                                "The outgoing message is the exponential of the average log-factor value, where the average is over all arguments except ");
                            writer.WriteElementFormatString("c", target);
                            writer.WriteString(". The formula is ");
                            writer.WriteElementFormatString("c", "exp(sum_({1}) p({1}) log(factor({0})))", allArguments, randomFieldsExceptTarget);
                        }
                        else if (child == target)
                        {
                            writer.WriteString("The outgoing message is a distribution matching the moments of ");
                            writer.WriteElementFormatString("c", target);
                            writer.WriteString(" as the random arguments are varied. The formula is ");
                            writer.WriteElementFormatString("c", "proj[sum_({1}) p({1}) factor({0})]", allArguments, randomFieldsExceptTarget);
                        }
                        else if (childIsRandom && randomFieldsExceptTargetAndChild == string.Empty)
                        {
                            writer.WriteString("The outgoing message is the factor viewed as a function of ");
                            writer.WriteElementFormatString("c", target);
                            writer.WriteString(" with ");
                            writer.WriteElementFormatString("c", child);
                            writer.WriteString(" integrated out. The formula is ");
                            writer.WriteElementFormatString("c", "sum_{1} p({1}) factor({0})", allArguments, child);
                        }
                        else
                        {
                            writer.WriteString("The outgoing message is the exponential of the average log-factor value, where the average is over all arguments except ");
                            writer.WriteElementFormatString("c", target);
                            writer.WriteString(". ");

                            if (childIsRandom)
                            {
                                writer.WriteString("Because the factor is deterministic, ");
                                writer.WriteElementFormatString("c", child);
                                writer.WriteString(" is integrated out before taking the logarithm. The formula is ");
                                writer.WriteElementFormatString(
                                    "c", "exp(sum_({1}) p({1}) log(sum_{2} p({2}) factor({0})))", allArguments, randomFieldsExceptTargetAndChild, child);
                            }
                            else
                            {
                                writer.WriteString("The formula is ");
                                writer.WriteElementFormatString("c", "exp(sum_({1}) p({1}) log(factor({0})))", allArguments, randomFieldsExceptTarget);
                            }
                        }

                        writer.WriteString(".");
                    }

                    break;
                }
            }

            writer.WriteEndElement();
            writer.WriteEndElement();
        }

        /// <summary>
        /// Appends a field name to a comma-separated list, omitting the comma while the list is still empty.
        /// </summary>
        /// <param name="fields">The list built so far.</param>
        /// <param name="field">The field name to append.</param>
        /// <returns>The extended list.</returns>
        private static string AppendRandomField(string fields, string field)
        {
            return fields + (fields == string.Empty ? field : "," + field);
        }

        /// <summary>
        /// Writes the remark used when the outgoing message is the factor conditioned on the given values.
        /// </summary>
        /// <param name="writer">The XML writer that receives the output.</param>
        /// <param name="target">The name of the argument the message is sent to.</param>
        private static void WriteConditionalMessageRemark(XmlWriter writer, string target)
        {
            writer.WriteString("The outgoing message is the factor viewed as a function of ");
            writer.WriteElementFormatString("c", target);
            writer.WriteString(" conditioned on the given values.");
        }