Ejemplo n.º 1
0
        /// <summary>
        /// Handles an assignment <c>target = rhs</c>.
        /// The assignment is ignored when the target has no variable declaration, or when rhs allocates a
        /// non-zero-length array without an initializer (the elements are assigned later).
        /// Otherwise the variable is recorded in <c>variablesAssigned</c>, and deterministic variables stop there.
        /// For a stochastic variable, on its first assignment this method attaches a definition
        /// <see cref="ChannelInfo"/> and calls <c>SetMarginalPrototype</c>. When the algorithm is VMP with
        /// <c>UseDerivMessages</c> and the variable is derived, it also declares a <c>_deriv</c> variable.
        /// It then declares the marginal and use channels and inserts, after the current statement, an
        /// assignment of the variable factor to the use expression.
        /// Variables listed in <c>analysis.variablesExcludingVariableFactor</c>, and derived variables that
        /// are neither inferred nor point-estimated, get no variable factor.
        /// </summary>
        /// <param name="target">Left-hand side of the assignment.</param>
        /// <param name="rhs">Right-hand side of the assignment.</param>
        /// <param name="shouldDelete">
        /// Set to <c>true</c> when <paramref name="rhs"/> is a <c>Factor.Copy</c> of a variable whose
        /// <c>MarginalPrototype</c> attribute is the same object as the target's, and
        /// <c>ancIndex &lt; InputStack.Count - 2</c>. Otherwise it is left unchanged.
        /// Presumably the caller then drops the original statement (confirm at call site).
        /// </param>
        protected void ProcessAssign(IExpression target, IExpression rhs, ref bool shouldDelete)
        {
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);

            if (ivd == null)
            {
                return;
            }
            if (rhs is IArrayCreateExpression)
            {
                IArrayCreateExpression iace = (IArrayCreateExpression)rhs;
                bool zeroLength             = iace.Dimensions.All(dimExpr =>
                                                                  (dimExpr is ILiteralExpression) && ((ILiteralExpression)dimExpr).Value.Equals(0));
                if (!zeroLength && iace.Initializer == null)
                {
                    return; // variable will have assignments to elements
                }
            }
            bool firstTime = !variablesAssigned.Contains(ivd);

            variablesAssigned.Add(ivd);
            if (!CodeRecognizer.IsStochastic(context, ivd))
            {
                return;
            }
            // definition of a stochastic variable
            VariableInformation vi  = VariableInformation.GetVariableInformation(context, ivd);
            IExpression         lhs = target;

            if (lhs is IVariableDeclarationExpression)
            {
                lhs = Builder.VarRefExpr(ivd);
            }
            IExpression defExpr       = lhs;
            Containers  defContainers = context.InputAttributes.Get <Containers>(ivd);
            int         ancIndex      = defContainers.GetMatchingAncestorIndex(context);
            Containers  missing       = defContainers.GetContainersNotInContext(context, ancIndex);

            if (firstTime)
            {
                // Create a ChannelInfo attribute for use by later transforms, e.g. MessageTransform
                ChannelInfo defChannel = ChannelInfo.DefChannel(vi);
                defChannel.decl = ivd;
                context.OutputAttributes.Set(ivd, defChannel);
                SetMarginalPrototype(ivd);
            }
            bool       isDerived = context.InputAttributes.Has <DerivedVariable>(ivd);
            IAlgorithm algorithm = this.algorithmDefault;
            Algorithm  algAttr   = context.InputAttributes.Get <Algorithm>(ivd);

            if (algAttr != null)
            {
                // A per-variable Algorithm attribute overrides the transform's default algorithm.
                algorithm = algAttr.algorithm;
            }
            if (algorithm is VariationalMessagePassing && ((VariationalMessagePassing)algorithm).UseDerivMessages && isDerived && firstTime)
            {
                // Declare a companion "<name>_deriv" variable and link it to ivd via a DerivMessage attribute.
                vi.DefineAllIndexVars(context);
                IList <IStatement>   stmts     = Builder.StmtCollection();
                IVariableDeclaration derivDecl = vi.DeriveIndexedVariable(stmts, context, ivd.Name + "_deriv");
                context.OutputAttributes.Set(ivd, new DerivMessage(derivDecl));
                ChannelInfo derivChannel = ChannelInfo.DefChannel(vi);
                derivChannel.decl = derivDecl;
                context.OutputAttributes.Set(derivChannel.decl, derivChannel);
                context.OutputAttributes.Set(derivChannel.decl, new DescriptionAttribute("deriv of '" + ivd.Name + "'"));
                // Add the declarations
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            bool isInferred      = context.InputAttributes.Has <IsInferred>(ivd);
            bool isPointEstimate = context.InputAttributes.Has <PointEstimate>(ivd);

            if (this.analysis.variablesExcludingVariableFactor.Contains(ivd))
            {
                this.variablesLackingVariableFactor.Add(ivd);
                return;
            }
            if (isDerived && !isInferred && !isPointEstimate)
            {
                return;
            }

            // Only assigned by the disabled makeClone branch below; otherwise stays null.
            IExpression useExpr2 = null;

            if (firstTime)
            {
                // create marginal and use channels
                vi.DefineAllIndexVars(context);
                IList <IStatement> stmts = Builder.StmtCollection();

                CreateMarginalChannel(ivd, vi, stmts);
                CreateUseChannel(ivd, vi, stmts);

                // Add the declarations
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
                // Guard the lookup: if CreateUseChannel did not register a use channel, the
                // check below reports it via Error instead of this indexer throwing
                // KeyNotFoundException.
                if (useOfVariable.ContainsKey(ivd))
                {
                    context.InputAttributes.Set(useOfVariable[ivd], defContainers);
                }
            }
            if (!useOfVariable.ContainsKey(ivd))
            {
                Error("cannot find use channel of " + ivd);
                return;
            }
            IExpression  useExpr      = Builder.ReplaceVariable(lhs, ivd, useOfVariable[ivd]);
            IExpression  marginalExpr = Builder.ReplaceVariable(lhs, ivd, marginalOfVariable[ivd]);
            InitialiseTo it           = context.InputAttributes.Get <InitialiseTo>(ivd);

            Type[] genArgs = new Type[] { defExpr.GetExpressionType() };
            if (rhs is IMethodInvokeExpression)
            {
                IMethodInvokeExpression imie = (IMethodInvokeExpression)rhs;
                if (Recognizer.IsStaticGenericMethod(imie, new Func <PlaceHolder, PlaceHolder>(Factor.Copy)) && ancIndex < context.InputStack.Count - 2)
                {
                    IExpression          arg  = imie.Arguments[0];
                    IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(arg);
                    // GetVariableDeclaration can return null (see the same check on target above) when
                    // the copied expression is not a variable reference; such a copy is not substituted.
                    // MarginalPrototype attributes are compared by reference.
                    if (ivd2 != null && context.InputAttributes.Get <MarginalPrototype>(ivd) == context.InputAttributes.Get <MarginalPrototype>(ivd2))
                    {
                        // if a variable is a copy, use the original expression since it will give more precise dependencies.
                        defExpr      = arg;
                        shouldDelete = true;
                        // NOTE: makeClone is always false, so the branch below never runs and
                        // useExpr2 remains null. It is kept as a disabled alternative.
                        bool makeClone = false;
                        if (makeClone)
                        {
                            VariableInformation         vi2      = VariableInformation.GetVariableInformation(context, ivd2);
                            IList <IStatement>          stmts    = Builder.StmtCollection();
                            List <IList <IExpression> > indices  = Recognizer.GetIndices(defExpr);
                            IVariableDeclaration        useDecl2 = vi2.DeriveIndexedVariable(stmts, context, ivd2.Name + "_use", indices);
                            useExpr2 = Builder.VarRefExpr(useDecl2);
                            Containers defContainers2 = context.InputAttributes.Get <Containers>(ivd2);
                            int        ancIndex2      = defContainers2.GetMatchingAncestorIndex(context);
                            Containers missing2       = defContainers2.GetContainersNotInContext(context, ancIndex2);
                            stmts = Containers.WrapWithContainers(stmts, missing2.outputs);
                            context.AddStatementsBeforeAncestorIndex(ancIndex2, stmts);
                            context.InputAttributes.Set(useDecl2, defContainers2);

                            // TODO: call CreateUseChannel
                            ChannelInfo usageChannel = ChannelInfo.UseChannel(vi2);
                            usageChannel.decl = useDecl2;
                            context.InputAttributes.CopyObjectAttributesTo <InitialiseTo>(vi.declaration, context.OutputAttributes, useDecl2);
                            context.InputAttributes.CopyObjectAttributesTo <DerivMessage>(vi.declaration, context.OutputAttributes, useDecl2);
                            context.OutputAttributes.Set(useDecl2, usageChannel);
                            //context.OutputAttributes.Set(useDecl2, new DescriptionAttribute("use of '" + ivd.Name + "'"));
                            context.OutputAttributes.Remove <InitialiseTo>(vi.declaration);
                            SetMarginalPrototype(useDecl2);

                            IExpression copyExpr = Builder.StaticGenericMethod(
                                new Func <PlaceHolder, PlaceHolder>(Factor.Copy), genArgs, useExpr2);
                            var copyStmt = Builder.AssignStmt(useExpr, copyExpr);
                            context.AddStatementAfterCurrent(copyStmt);
                        }
                    }
                }
            }

            // Add the variable factor
            IExpression variableFactorExpr;
            bool        isGateExitRandom = context.InputAttributes.Has <VariationalMessagePassing.GateExitRandomVariable>(ivd);

            if (isGateExitRandom)
            {
                variableFactorExpr = Builder.StaticGenericMethod(
                    new Models.FuncOut <PlaceHolder, PlaceHolder, PlaceHolder>(Gate.ExitingVariable),
                    genArgs, defExpr, marginalExpr);
            }
            else
            {
                // The algorithm chooses the variable factor; PointEstimate overrides it with Factor.VariablePoint.
                Delegate d = algorithm.GetVariableFactor(isDerived, it != null);
                if (isPointEstimate)
                {
                    d = new Models.FuncOut <PlaceHolder, PlaceHolder, PlaceHolder>(Factor.VariablePoint);
                }
                if (it == null)
                {
                    variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, marginalExpr);
                }
                else
                {
                    // Rewrite the InitialiseTo expression so it refers to the same element as lhs.
                    IExpression initExpr = Builder.ReplaceExpression(lhs, Builder.VarRefExpr(ivd), it.initialMessagesExpression);
                    variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, initExpr, marginalExpr);
                }
            }
            context.InputAttributes.CopyObjectAttributesTo <GivePriorityTo>(ivd, context.OutputAttributes, variableFactorExpr);
            context.InputAttributes.CopyObjectAttributesTo <Algorithm>(ivd, context.OutputAttributes, variableFactorExpr);
            context.OutputAttributes.Set(variableFactorExpr, new IsVariableFactor());
            var assignStmt = Builder.AssignStmt(useExpr2 == null ? useExpr : useExpr2, variableFactorExpr);

            context.AddStatementAfterCurrent(assignStmt);
        }