コード例 #1
0
        /// <summary>
        /// Builds a new <c>Containers</c> holding the same input/output pairs as
        /// <paramref name="containers"/>, with every condition statement whose condition is
        /// stochastic moved to the end. Relative order inside each group is kept.
        /// </summary>
        /// <param name="containers">Source containers; its inputs and outputs lists are read in parallel by index.</param>
        /// <param name="context">Transform context handed to <c>CodeRecognizer.IsStochastic</c>.</param>
        /// <returns>A freshly created <c>Containers</c> with the reordered pairs.</returns>
        internal static Containers SortStochasticConditionals(Containers containers, BasicTransformContext context)
        {
            Containers sorted = new Containers();
            Containers deferred = new Containers();

            for (int index = 0; index < containers.inputs.Count; index++)
            {
                IStatement input = containers.inputs[index];
                bool isStochasticConditional =
                    input is IConditionStatement conditional &&
                    CodeRecognizer.IsStochastic(context, conditional.Condition);

                // Stochastic conditionals are parked; everything else goes straight to the result.
                Containers destination = isStochasticConditional ? deferred : sorted;
                destination.inputs.Add(input);
                destination.outputs.Add(containers.outputs[index]);
            }

            // Append the parked stochastic conditionals after all other containers.
            for (int index = 0; index < deferred.inputs.Count; index++)
            {
                sorted.inputs.Add(deferred.inputs[index]);
                sorted.outputs.Add(deferred.outputs[index]);
            }
            return sorted;
        }
コード例 #2
0
        /// <summary>
        /// Replaces the <c>Containers</c> attribute of the declared variable with one built from the
        /// current context. A non-stochastic variable is then processed as a constant and given a
        /// default description; a stochastic one has its marginal prototype set from its attribute.
        /// </summary>
        /// <param name="ivde">The variable declaration expression being converted.</param>
        /// <returns>The same expression, <paramref name="ivde"/>.</returns>
        protected override IExpression ConvertVariableDeclExpr(IVariableDeclarationExpression ivde)
        {
            IVariableDeclaration declaration = ivde.Variable;

            // Drop any existing Containers attribute and attach one for the current context.
            context.InputAttributes.Remove<Containers>(declaration);
            context.InputAttributes.Set(declaration, new Containers(context));

            if (CodeRecognizer.IsStochastic(context, declaration))
            {
                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, declaration);
                // Ensure the marginal prototype is set.
                MarginalPrototype prototypeAttr = Context.InputAttributes.Get<MarginalPrototype>(declaration);
                try
                {
                    varInfo.SetMarginalPrototypeFromAttribute(prototypeAttr);
                }
                catch (ArgumentException ex)
                {
                    // Reported through Error rather than propagated.
                    Error(ex.Message);
                }
                return ivde;
            }

            ProcessConstant(declaration);
            bool hasDescription = context.InputAttributes.Has<DescriptionAttribute>(declaration);
            if (!hasDescription)
            {
                context.OutputAttributes.Set(declaration, new DescriptionAttribute("The constant '" + declaration.Name + "'"));
            }
            return ivde;
        }
コード例 #3
0
        /// <summary>
        /// Classifies a definition of <paramref name="targetVar"/> as point-mass or not and updates
        /// <c>variablesDefinedPointMass</c>, <c>variablesDefinedNonPointMass</c> and the
        /// <c>ForwardPointMass</c> attribute accordingly.
        /// </summary>
        /// <param name="expr">The defining expression.</param>
        /// <param name="targetVar">The variable being defined.</param>
        /// <param name="isLhs">Not read by this method.</param>
        protected void ProcessDefinition(IExpression expr, IVariableDeclaration targetVar, bool isLhs)
        {
            bool definesPointMass = false;

            if (expr is IMethodInvokeExpression invoke)
            {
                // TODO: consider using a method attribute for this
                if (Recognizer.IsStaticGenericMethod(invoke, new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint)))
                {
                    definesPointMass = true;
                }
                else
                {
                    FactorManager.FactorInfo factorInfo = CodeRecognizer.GetFactorInfo(context, invoke);
                    if (factorInfo.IsDeterministicFactor)
                    {
                        // Point mass if the argument returned in all elements is a point mass,
                        // or if every argument is.
                        int returnedIndex = factorInfo.ReturnedInAllElementsParameterIndex;
                        bool returnedArgIsPointMass = returnedIndex != -1 && ArgumentIsPointMass(invoke.Arguments[returnedIndex]);
                        definesPointMass = returnedArgIsPointMass || invoke.Arguments.All(ArgumentIsPointMass);
                    }
                }

                if (definesPointMass)
                {
                    // do this immediately so all uses are updated
                    bool alreadyMarked = context.InputAttributes.Has<ForwardPointMass>(targetVar);
                    if (!alreadyMarked)
                    {
                        context.OutputAttributes.Set(targetVar, new ForwardPointMass());
                    }

                    // the rest is done later
                    if (!variablesDefinedPointMass.TryGetValue(targetVar, out List<IMethodInvokeExpression> definitions))
                    {
                        definitions = new List<IMethodInvokeExpression>();
                        variablesDefinedPointMass.Add(targetVar, definitions);
                    }

                    // this code needs to be synchronized with MessageTransform.ConvertMethodInvoke
                    bool isRecordedFactor =
                        Recognizer.IsStaticGenericMethod(invoke, new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate)) ||
                        Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems)) ||
                        Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems)) ||
                        Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems)) ||
                        Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromJagged)) ||
                        Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged)) ||
                        Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<IReadOnlyList<int>>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged));
                    if (isRecordedFactor)
                    {
                        definitions.Add(invoke);
                    }
                }
            }

            // Array creation neither confirms nor revokes point-mass status.
            if (definesPointMass || expr is IArrayCreateExpression)
            {
                return;
            }

            // A non-point-mass definition revokes any earlier point-mass classification.
            variablesDefinedNonPointMass.Add(targetVar);
            if (variablesDefinedPointMass.ContainsKey(targetVar))
            {
                variablesDefinedPointMass.Remove(targetVar);
                context.OutputAttributes.Remove<ForwardPointMass>(targetVar);
            }
        }
コード例 #4
0
 /// <summary>
 /// Sends Infer calls to <c>ConvertInfer</c>; every other invocation gets the base conversion.
 /// </summary>
 /// <param name="imie">The method invocation being converted.</param>
 /// <returns>The result of <c>ConvertInfer</c> or of the base conversion.</returns>
 protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
 {
     if (!CodeRecognizer.IsInfer(imie))
     {
         return base.ConvertMethodInvoke(imie);
     }
     return ConvertInfer(imie);
 }
コード例 #5
0
 /// <summary>
 /// Returns the bindings of <paramref name="conditionContext"/> whose left-hand side is not
 /// stochastic and whose two sides both pass the <c>ContainsLocalVars</c> check for
 /// <paramref name="gateBlock"/> (i.e. it returns false for each side).
 /// </summary>
 /// <param name="gateBlock">Gate block passed to <c>ContainsLocalVars</c>.</param>
 /// <param name="conditionContext">Bindings to filter; not modified.</param>
 /// <returns>A new list with the retained bindings in their original order.</returns>
 private List<ConditionBinding> FilterConditionContext(GateBlock gateBlock, List<ConditionBinding> conditionContext)
 {
     List<ConditionBinding> retained = new List<ConditionBinding>();
     foreach (ConditionBinding candidate in conditionContext)
     {
         // Checks run in the same order, and with the same short-circuiting, as a chained &&.
         if (CodeRecognizer.IsStochastic(context, candidate.lhs))
         {
             continue;
         }
         if (ContainsLocalVars(gateBlock, candidate.lhs))
         {
             continue;
         }
         if (ContainsLocalVars(gateBlock, candidate.rhs))
         {
             continue;
         }
         retained.Add(candidate);
     }
     return retained;
 }
コード例 #6
0
        /// <summary>
        /// Decides whether an expression statement should stay wrapped in the enclosing
        /// <c>parent</c> condition. When it should, a new condition statement with
        /// <c>parent.Condition</c> and the statement as its only body entry is returned;
        /// otherwise the statement is returned bare.
        /// </summary>
        /// <param name="ies">The expression statement being converted.</param>
        /// <returns><paramref name="ies"/> itself, or a condition statement containing it.</returns>
        protected override IStatement ConvertExpressionStatement(IExpressionStatement ies)
        {
            if (parent == null)
            {
                return ies;
            }

            // Only keep the surrounding if statement when a factor or constraint is being added.
            bool wrapInCondition = false;
            IExpression inspected = ies.Expression;

            if (inspected is IMethodInvokeExpression)
            {
                // Any bare method call is kept, except Infer.
                wrapInCondition = !CodeRecognizer.IsInfer(inspected);
            }
            else if (inspected is IAssignExpression assign)
            {
                if (assign.Expression is IMethodInvokeExpression factorCall)
                {
                    wrapInCondition = true;
                    if (factorCall.Arguments.Count > 0)
                    {
                        // Statements that copy evidence variables should not send evidence messages.
                        IVariableDeclaration targetDecl = Recognizer.GetVariableDeclaration(assign.Target);
                        IVariableDeclaration firstArgDecl = Recognizer.GetVariableDeclaration(factorCall.Arguments[0]);
                        bool copiesEvidenceVariable =
                            targetDecl != null && context.InputAttributes.Has<DoNotSendEvidence>(targetDecl) &&
                            firstArgDecl != null && context.InputAttributes.Has<DoNotSendEvidence>(firstArgDecl);
                        if (copiesEvidenceVariable)
                        {
                            wrapInCondition = false;
                        }
                    }
                }
                else
                {
                    // Non-factor assignment: fall through to the declaration check on its target.
                    inspected = assign.Target;
                }
            }

            if (inspected is IVariableDeclarationExpression declExpr)
            {
                IVariableDeclaration declared = declExpr.Variable;
                wrapInCondition = CodeRecognizer.IsStochastic(context, declared) &&
                                  !context.InputAttributes.Has<DoNotSendEvidence>(declared);
            }

            if (!wrapInCondition)
            {
                return ies;
            }

            IConditionStatement wrapper = Builder.CondStmt(parent.Condition, Builder.BlockStmt());
            wrapper.Then.Statements.Add(ies);
            return wrapper;
        }
コード例 #7
0
 /// <summary>
 /// Returns Infer calls untouched so that their argument is not counted as a use;
 /// all other invocations get the base conversion.
 /// </summary>
 /// <param name="imie">The method invocation being converted.</param>
 /// <returns><paramref name="imie"/> for Infer calls, otherwise the base conversion result.</returns>
 protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
 {
     bool isInferCall = CodeRecognizer.IsInfer(imie);
     if (!isInferCall)
     {
         return base.ConvertMethodInvoke(imie);
     }
     // do not count argument of Infer as a use
     return imie;
 }
コード例 #8
0
 /// <summary>
 /// Returns true if <paramref name="ivd"/> is constant once the loop variables in
 /// <paramref name="constantLoopVars"/> are held fixed.
 /// </summary>
 /// <param name="ivd">Variable to test; null is treated as constant.</param>
 /// <param name="constantLoopVars">Loop variables that are regarded as fixed.</param>
 /// <returns>
 /// False for a stochastic variable. Otherwise true when <paramref name="ivd"/> is itself in the
 /// set, or when every loop variable of its <c>LoopContext</c> attribute is in the set.
 /// </returns>
 private bool IsConstantWrtLoops(IVariableDeclaration ivd, Set<IVariableDeclaration> constantLoopVars)
 {
     if (ivd == null)
     {
         return true;
     }
     if (CodeRecognizer.IsStochastic(context, ivd))
     {
         return false;
     }
     // The right operand is false as soon as the variable's loop context has a loop
     // that is not in constantLoopVars.
     return constantLoopVars.Contains(ivd) ||
            constantLoopVars.ContainsAll(context.InputAttributes.Get<LoopContext>(ivd).loopVariables);
 }
コード例 #9
0
 /// <summary>
 /// Leaves Infer calls as they are (the variable use inside them is ignored) and converts
 /// every other invocation with the base implementation.
 /// </summary>
 /// <param name="imie">The method invocation being converted.</param>
 /// <returns><paramref name="imie"/> for Infer calls, otherwise the base conversion result.</returns>
 protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
 {
     // ignore the variable use inside Infer
     return CodeRecognizer.IsInfer(imie)
         ? imie
         : base.ConvertMethodInvoke(imie);
 }
コード例 #10
0
 /// <summary>
 /// Returns Infer calls unchanged. Any other invocation is first passed to
 /// <c>CheckForDuplicateArguments</c> and then converted by the base implementation.
 /// </summary>
 /// <param name="imie">The method invocation being converted.</param>
 /// <returns><paramref name="imie"/> for Infer calls, otherwise the base conversion result.</returns>
 protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
 {
     if (!CodeRecognizer.IsInfer(imie))
     {
         CheckForDuplicateArguments(imie);
         return base.ConvertMethodInvoke(imie);
     }
     return imie;
 }
コード例 #11
0
        /// <summary>
        /// Tests whether an argument counts as a point mass. Non-stochastic arguments and
        /// <c>out</c> arguments always count; a stochastic argument counts only when its variable
        /// declaration is found and carries the <c>ForwardPointMass</c> attribute.
        /// </summary>
        /// <param name="arg">The argument expression to test.</param>
        /// <returns>True if the argument is treated as a point mass.</returns>
        private bool ArgumentIsPointMass(IExpression arg)
        {
            if (!CodeRecognizer.IsStochastic(context, arg) || arg is IAddressOutExpression)
            {
                return true;
            }

            IVariableDeclaration declaration = Recognizer.GetVariableDeclaration(arg);
            if (declaration == null)
            {
                return false;
            }
            return context.InputAttributes.Has<ForwardPointMass>(declaration);
        }
コード例 #12
0
 /// <summary>
 /// Reports an error for every stochastic variable found in any of the given index expressions.
 /// </summary>
 /// <param name="exprs">Index expressions to inspect.</param>
 private void CheckIndicesAreNotStochastic(IList<IExpression> exprs)
 {
     foreach (IExpression indexExpr in exprs)
     {
         foreach (var variable in Recognizer.GetVariables(indexExpr))
         {
             if (!CodeRecognizer.IsStochastic(context, variable))
             {
                 continue;
             }
             // One error per stochastic variable, naming the offending index expression.
             Error("Indexing by a random variable '" + variable.Name + "'.  You must wrap this statement with Variable.Switch(" + indexExpr + ")");
         }
     }
 }
コード例 #13
0
        /// <summary>
        /// Converts a random (non-deterministic) assignment to a derived variable into an assignment
        /// to a fresh non-derived variable followed by a Copy into the original target.
        /// </summary>
        /// <param name="iae">The assignment being converted.</param>
        /// <returns>
        /// The base-converted assignment, or a new assignment of <c>Clone.Copy</c> of the fresh
        /// variable to the original target when the rewrite applies.
        /// </returns>
        protected override IExpression ConvertAssign(IAssignExpression iae)
        {
            iae = (IAssignExpression)base.ConvertAssign(iae);
            IStatement enclosing = context.FindAncestor<IStatement>();

            // Constraint statements and non-method right-hand sides are left alone.
            if (context.InputAttributes.Has<Models.Constraint>(enclosing))
            {
                return iae;
            }
            if (!(iae.Expression is IMethodInvokeExpression factorCall))
            {
                return iae;
            }

            IVariableDeclaration targetDecl = Recognizer.GetVariableDeclaration(iae.Target);
            if (targetDecl == null || !context.InputAttributes.Has<DerivedVariable>(targetDecl))
            {
                return iae;
            }

            FactorManager.FactorInfo factorInfo = CodeRecognizer.GetFactorInfo(context, factorCall);
            if (factorInfo.IsDeterministicFactor)
            {
                return iae;
            }

            // The variable is derived, but this definition is not derived.
            // Thus we must convert
            //   y = sample()
            // into
            //   y_random = sample()
            //   y = Copy(y_random)
            // where y_random is not derived.
            VariableInformation targetInfo = VariableInformation.GetVariableInformation(context, targetDecl);
            IList<IStatement> prelude = Builder.StmtCollection();
            string cloneName = targetDecl.Name + "_random" + (Count++);
            List<IList<IExpression>> targetIndices = Recognizer.GetIndices(iae.Target);
            IVariableDeclaration cloneVar = targetInfo.DeriveIndexedVariable(prelude, context, cloneName, targetIndices, copyInitializer: true);
            context.OutputAttributes.Remove<DerivedVariable>(cloneVar);
            prelude.Add(Builder.AssignStmt(Builder.VarRefExpr(cloneVar), iae.Expression));

            // Emit the clone's statements ahead of the statement currently being converted.
            int ancestorIndex = context.FindAncestorIndex<IStatement>();
            context.AddStatementsBeforeAncestorIndex(ancestorIndex, prelude);

            Type targetType = iae.Target.GetExpressionType();
            if (targetType == null)
            {
                Error("Could not determine type of expression: " + iae.Target);
                return iae;
            }

            IExpression copyCall = Builder.StaticGenericMethod(
                new Func<PlaceHolder, PlaceHolder>(Clone.Copy),
                new Type[] { targetType },
                Builder.VarRefExpr(cloneVar));
            return Builder.AssignExpr(iae.Target, copyCall);
        }
コード例 #14
0
 /// <summary>
 /// Replaces an IsIncreasing call with a boolean literal: <c>false</c> when the name of the
 /// variable behind the first argument is in <c>backwardLoops</c>, <c>true</c> otherwise.
 /// Other invocations get the base conversion.
 /// </summary>
 /// <param name="imie">The method invocation being converted.</param>
 /// <returns>A literal expression for IsIncreasing calls, otherwise the base conversion result.</returns>
 protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
 {
     if (!CodeRecognizer.IsIsIncreasing(imie))
     {
         return base.ConvertMethodInvoke(imie);
     }

     IVariableDeclaration argumentVar = Recognizer.GetVariableDeclaration(imie.Arguments[0]);
     bool isBackward = backwardLoops.Contains(argumentVar.Name);
     return isBackward ? Builder.LiteralExpr(false) : Builder.LiteralExpr(true);
 }
コード例 #15
0
        /// <summary>
        /// Collects a <c>ConditionBinding</c> for every condition statement in
        /// <paramref name="containers"/> whose condition is not stochastic.
        /// </summary>
        /// <param name="context">Transform context handed to <c>CodeRecognizer.IsStochastic</c>.</param>
        /// <param name="containers">Container statements to scan; statements that are not conditions are ignored.</param>
        /// <returns>A new list of bindings, in the order the conditions were encountered.</returns>
        internal static List<ConditionBinding> GetBindings(BasicTransformContext context, IEnumerable<IStatement> containers)
        {
            var collected = new List<ConditionBinding>();

            foreach (IStatement container in containers)
            {
                if (!(container is IConditionStatement conditional))
                {
                    continue;
                }
                if (CodeRecognizer.IsStochastic(context, conditional.Condition))
                {
                    continue;
                }
                collected.Add(new ConditionBinding(conditional.Condition));
            }
            return collected;
        }
コード例 #16
0
 /// <summary>
 /// Records factor expressions in <c>factorExprs</c>: assignments whose right-hand side is a
 /// method invocation, and bare method invocations other than Infer. The statement itself is
 /// always returned unchanged.
 /// </summary>
 /// <param name="ies">The expression statement being visited.</param>
 /// <returns><paramref name="ies"/>.</returns>
 protected override IStatement ConvertExpressionStatement(IExpressionStatement ies)
 {
     switch (ies.Expression)
     {
         case IAssignExpression assign:
             if (assign.Expression is IMethodInvokeExpression)
             {
                 factorExprs.Add(assign);
             }
             break;

         case IMethodInvokeExpression call:
             // Infer calls are not recorded.
             if (!CodeRecognizer.IsInfer(call))
             {
                 factorExprs.Add(ies.Expression);
             }
             break;
     }
     return ies;
 }
コード例 #17
0
        /// <summary>
        /// Processes a reference to a stochastic variable: registers a definition when the
        /// expression is being mutated but not allocated, then records the use via
        /// <c>ProcessUse</c>. Expressions without a variable declaration, or whose variable is not
        /// stochastic, are ignored.
        /// </summary>
        /// <param name="expr">The expression to process.</param>
        protected void ProcessExpression(IExpression expr)
        {
            bool isDefinition = Recognizer.IsBeingMutated(context, expr);
            IVariableDeclaration declaration = Recognizer.GetVariableDeclaration(expr);

            if (declaration == null || !CodeRecognizer.IsStochastic(context, declaration))
            {
                return;
            }

            bool isAllocation = isDefinition && Recognizer.IsBeingAllocated(context, expr);
            if (isDefinition && !isAllocation)
            {
                RegisterDefinition(declaration);
            }
            ProcessUse(expr, isDefinition, conditionContext);
        }
コード例 #18
0
        /// <summary>
        /// Builds a <c>NodeInfo</c> for a factor expression, which is either a method invocation or
        /// an assignment whose right-hand side is one. The argument list is filled in this order:
        /// the assignment target (if any), the method's target object (if the method is not
        /// static), then the call arguments with any <c>out</c> wrapper removed.
        /// <c>isReturnOrOut</c> is filled in parallel: true for the assignment target and for
        /// <c>out</c> arguments, false for everything else.
        /// </summary>
        /// <param name="factor">An assignment of a method invocation, or a method invocation.</param>
        /// <returns>The populated <c>NodeInfo</c>.</returns>
        private NodeInfo GetNodeInfo(IExpression factor)
        {
            IExpression assignedTo = null;
            IMethodInvokeExpression call;

            if (factor is IAssignExpression assignment)
            {
                assignedTo = assignment.Target;
                if (assignedTo is IVariableDeclarationExpression)
                {
                    // Refer to a declared variable by reference rather than by its declaration.
                    assignedTo = Builder.VarRefExpr(Recognizer.GetVariableDeclaration(assignedTo));
                }
                call = (IMethodInvokeExpression)assignment.Expression;
            }
            else
            {
                call = (IMethodInvokeExpression)factor;
            }

            NodeInfo node = new NodeInfo(call)
            {
                info = CodeRecognizer.GetFactorInfo(context, call)
            };

            if (assignedTo != null)
            {
                node.isReturnOrOut.Add(true);
                node.arguments.Add(assignedTo);
            }

            if (!node.info.Method.IsStatic)
            {
                node.isReturnOrOut.Add(false);
                node.arguments.Add(call.Method.Target);
            }

            foreach (IExpression argument in call.Arguments)
            {
                if (argument is IAddressOutExpression outArgument)
                {
                    node.isReturnOrOut.Add(true);
                    node.arguments.Add(outArgument.Expression);
                }
                else
                {
                    node.isReturnOrOut.Add(false);
                    node.arguments.Add(argument);
                }
            }
            return node;
        }
コード例 #19
0
ファイル: VariableTransform.cs プロジェクト: zouwuhe/infer
        /// <summary>
        /// Converts a method invocation. Infer calls go to <c>ConvertInfer</c>. For other calls,
        /// the expressions behind <c>out</c> arguments are added to
        /// <c>targetsOfCurrentAssignment</c>. A <c>Factor.JaggedSubarray</c> call is replaced by
        /// <c>Factor.JaggedSubarrayWithMarginal</c> when its array argument is a plain variable
        /// reference that is in <c>variablesLackingVariableFactor</c> and has no entry in
        /// <c>marginalOfVariable</c> yet; the marginal channel is created first. Everything else
        /// gets the base conversion.
        /// </summary>
        /// <param name="imie">The method invocation being converted.</param>
        /// <returns>The converted expression.</returns>
        protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
        {
            if (CodeRecognizer.IsInfer(imie))
            {
                return ConvertInfer(imie);
            }

            foreach (IExpression argument in imie.Arguments)
            {
                if (argument is IAddressOutExpression outArgument)
                {
                    targetsOfCurrentAssignment.Add(outArgument.Expression);
                }
            }

            bool isJaggedSubarray = Recognizer.IsStaticGenericMethod(
                imie, new Func<IList<PlaceHolder>, int[][], PlaceHolder[][]>(Factor.JaggedSubarray));
            if (isJaggedSubarray)
            {
                IExpression arrayExpr = imie.Arguments[0];
                IVariableDeclaration arrayDecl = Recognizer.GetVariableDeclaration(arrayExpr);
                bool needsMarginalChannel =
                    arrayDecl != null &&
                    (arrayExpr is IVariableReferenceExpression) &&
                    this.variablesLackingVariableFactor.Contains(arrayDecl) &&
                    !marginalOfVariable.ContainsKey(arrayDecl);

                if (needsMarginalChannel)
                {
                    VariableInformation arrayInfo = VariableInformation.GetVariableInformation(context, arrayDecl);
                    IList<IStatement> channelStmts = Builder.StmtCollection();
                    CreateMarginalChannel(arrayDecl, arrayInfo, channelStmts);

                    // Wrap the new statements in the definition's containers that the current
                    // context lacks, and insert them before the matching ancestor.
                    Containers defContainers = context.InputAttributes.Get<Containers>(arrayDecl);
                    int ancestorIndex = defContainers.GetMatchingAncestorIndex(context);
                    Containers missingContainers = defContainers.GetContainersNotInContext(context, ancestorIndex);
                    channelStmts = Containers.WrapWithContainers(channelStmts, missingContainers.outputs);
                    context.AddStatementsBeforeAncestorIndex(ancestorIndex, channelStmts);

                    // none of the arguments should need to be transformed
                    IExpression indicesExpr = imie.Arguments[1];
                    IExpression marginalExpr = Builder.VarRefExpr(marginalOfVariable[arrayDecl]);
                    IMethodInvokeExpression withMarginal = Builder.StaticGenericMethod(
                        new Models.FuncOut<IList<PlaceHolder>, int[][], IList<PlaceHolder>, PlaceHolder[][]>(Factor.JaggedSubarrayWithMarginal),
                        new Type[] { Utilities.Util.GetElementType(arrayExpr.GetExpressionType()) },
                        arrayExpr, indicesExpr, marginalExpr);
                    return withMarginal;
                }
            }
            return base.ConvertMethodInvoke(imie);
        }
コード例 #20
0
        /// <summary>
        /// Converts a method invocation after checking its argument count.
        /// An Infer call is counted, the declaration behind its first argument is marked
        /// <c>IsInferred</c>, and the call is rebuilt with casts removed from its arguments (the
        /// arguments are deliberately not converted). For any other call, the base conversion runs
        /// and every variable passed as an <c>out</c> argument of a deterministic factor is marked
        /// <c>DerivedVariable</c>.
        /// </summary>
        /// <param name="imie">The method invocation being converted.</param>
        /// <returns>The rebuilt Infer call, or the base conversion result.</returns>
        protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
        {
            CheckMethodArgumentCount(imie);

            if (!CodeRecognizer.IsInfer(imie))
            {
                IExpression converted = base.ConvertMethodInvoke(imie);
                if (!(converted is IMethodInvokeExpression convertedCall))
                {
                    return converted;
                }

                foreach (IExpression argument in convertedCall.Arguments)
                {
                    if (!(argument is IAddressOutExpression outArgument))
                    {
                        continue;
                    }
                    IVariableDeclaration outDecl = Recognizer.GetVariableDeclaration(outArgument.Expression);
                    if (outDecl == null)
                    {
                        continue;
                    }
                    FactorManager.FactorInfo factorInfo = CodeRecognizer.GetFactorInfo(context, convertedCall);
                    bool shouldMarkDerived =
                        factorInfo != null &&
                        factorInfo.IsDeterministicFactor &&
                        !context.InputAttributes.Has<DerivedVariable>(outDecl);
                    if (shouldMarkDerived)
                    {
                        context.InputAttributes.Set(outDecl, new DerivedVariable());
                    }
                }
                return converted;
            }

            inferCount++;
            object inferredDecl = Recognizer.GetDeclaration(imie.Arguments[0]);
            if (inferredDecl != null && !context.InputAttributes.Has<IsInferred>(inferredDecl))
            {
                context.InputAttributes.Set(inferredDecl, new IsInferred());
            }

            // the arguments must not be substituted for their values, so we don't call ConvertExpression
            var strippedArgs = imie.Arguments.Select(CodeRecognizer.RemoveCast);
            IMethodInvokeExpression rebuilt = Builder.MethodInvkExpr();
            rebuilt.Method = imie.Method;
            rebuilt.Arguments.AddRange(strippedArgs);
            context.InputAttributes.CopyObjectAttributesTo(imie, context.OutputAttributes, rebuilt);
            return rebuilt;
        }
コード例 #21
0
 /// <summary>
 /// Attaches a <c>DerivedVariable</c> attribute to the target variable of an assignment whose
 /// right-hand side is a deterministic factor, unless the variable already has one.
 /// </summary>
 /// <param name="iae">The assignment being converted.</param>
 /// <returns>The base-converted assignment.</returns>
 protected override IExpression ConvertAssign(IAssignExpression iae)
 {
     iae = (IAssignExpression)base.ConvertAssign(iae);

     if (!(iae.Expression is IMethodInvokeExpression factorCall))
     {
         return iae;
     }

     IVariableDeclaration targetDecl = Recognizer.GetVariableDeclaration(iae.Target);
     if (targetDecl == null || context.InputAttributes.Has<DerivedVariable>(targetDecl))
     {
         return iae;
     }

     FactorManager.FactorInfo factorInfo = CodeRecognizer.GetFactorInfo(context, factorCall);
     if (factorInfo.IsDeterministicFactor)
     {
         context.InputAttributes.Set(targetDecl, new DerivedVariable());
     }
     return iae;
 }
コード例 #22
0
        /// <summary>
        /// Analyses the body of a non-stochastic condition with <c>conditionContext</c> extended by
        /// a binding for the condition (flipped for the else branch). Stochastic conditions are
        /// delegated to the base implementation. On return, <c>conditionContext</c> is truncated
        /// back to the length it had on entry.
        /// </summary>
        /// <param name="ics">The condition statement being visited.</param>
        /// <returns>
        /// The base result for stochastic conditions; otherwise <paramref name="ics"/> itself.
        /// </returns>
        protected override IStatement ConvertCondition(IConditionStatement ics)
        {
            if (CodeRecognizer.IsStochastic(context, ics.Condition))
            {
                return base.ConvertCondition(ics);
            }

            // ics.Condition is not stochastic
            context.SetPrimaryOutput(ics);
            ConvertExpression(ics.Condition);
            ConditionBinding thenBinding = new ConditionBinding(ics.Condition);
            int mark = conditionContext.Count;

            // Drops everything added to conditionContext since 'mark' was taken.
            void RestoreContext()
            {
                conditionContext.RemoveRange(mark, conditionContext.Count - mark);
            }

            conditionContext.Add(thenBinding);
            ConvertBlock(ics.Then);

            if (ics.Else != null)
            {
                RestoreContext();
                conditionContext.Add(thenBinding.FlipCondition());
                ConvertBlock(ics.Else);
            }

            RestoreContext();
            return ics;
        }
コード例 #23
0
        /// <summary>
        /// Collects an <c>IndexInfo</c> in <c>indexInfoOf</c> for each array indexer expression on a
        /// stochastic variable that has at least one index which is not a loop variable.
        /// For an expression seen before, the existing entry is updated: its bindings and
        /// containers are merged with the current context, its count is incremented, and
        /// <c>IsAssignedTo</c> is set if this occurrence is being mutated.
        /// </summary>
        /// <param name="iaie">The array indexer expression being visited.</param>
        /// <returns>Always <paramref name="iaie"/>; this method only records information.</returns>
        protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
        {
            bool isDef = Recognizer.IsBeingMutated(context, iaie);

            if (isDef)
            {
                // do not clone the lhs of an array create assignment.
                // NOTE(review): FindAncestor presumably returns null when the mutation is not
                // inside an assignment (e.g. an out argument); the null-conditional access keeps
                // that case from throwing and lets the normal analysis proceed. Confirm.
                IAssignExpression assignExpr = context.FindAncestor<IAssignExpression>();
                if (assignExpr?.Expression is IArrayCreateExpression)
                {
                    return iaie;
                }
            }

            // Visit sub-expressions; the converted result is intentionally not used.
            base.ConvertArrayIndexer(iaie);
            IndexInfo info;

            // TODO: Instead of storing an IndexInfo for each distinct expression, we should try to unify expressions, as in GateAnalysisTransform.
            // For example, we could unify a[0,i] and a[0,0] and use the same clone array for both.
            if (indexInfoOf.TryGetValue(iaie, out info))
            {
                Containers containers = new Containers(context);
                if (info.bindings.Count > 0)
                {
                    List<ConditionBinding> bindings = GetBindings(context, containers.inputs);
                    if (bindings.Count == 0)
                    {
                        // This occurrence has no bindings, so all recorded bindings are discarded.
                        info.bindings.Clear();
                    }
                    else
                    {
                        info.bindings.Add(bindings);
                    }
                }
                info.containers = Containers.Intersect(info.containers, containers);
                info.count++;
                if (isDef)
                {
                    info.IsAssignedTo = true;
                }
                return iaie;
            }

            CheckIndicesAreNotStochastic(iaie.Indices);
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(iaie);

            // If not an indexed variable reference, skip it (e.g. an indexed argument reference)
            if (baseVar == null)
            {
                return iaie;
            }
            // If the variable is not stochastic, skip it
            if (!CodeRecognizer.IsStochastic(context, baseVar))
            {
                return iaie;
            }
            // If the indices are all loop variables, skip it
            var indices = Recognizer.GetIndices(iaie);
            bool allLoopIndices = indices.All(bracket => bracket.All(indexExpr =>
                indexExpr is IVariableReferenceExpression ivre &&
                Recognizer.GetLoopForVariable(context, ivre) != null));

            if (allLoopIndices)
            {
                return iaie;
            }

            // First occurrence of this expression: start a new record.
            info = new IndexInfo();
            info.containers = new Containers(context);
            List<ConditionBinding> bindings2 = GetBindings(context, info.containers.inputs);

            if (bindings2.Count > 0)
            {
                info.bindings.Add(bindings2);
            }
            info.count = 1;
            info.IsAssignedTo = isDef;
            indexInfoOf[iaie] = info;
            return iaie;
        }
コード例 #24
0
        /// <summary>
        /// Adds a Replicate statement to the output code, if needed.
        /// </summary>
        /// <param name="target">LHS of assignment</param>
        /// <param name="rhs">RHS of assignment</param>
        /// <param name="shouldDelete">True if the original assignment statement should be deleted, i.e. it can be optimized away.</param>
        protected void AddReplicateStatement(IExpression target, IExpression rhs, ref bool shouldDelete)
        {
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);

            if (ivd == null)
            {
                return;
            }
            VariableInformation vi = VariableInformation.GetVariableInformation(context, ivd);

            if (!vi.IsStochastic)
            {
                ProcessConstant(ivd);
                return;
            }
            bool isEvidenceVar = context.InputAttributes.Has <DoNotSendEvidence>(ivd);

            if (SwapIndices && replicateEvidenceVars != isEvidenceVar)
            {
                return;
            }
            // iae defines a stochastic variable
            ChannelAnalysisTransform.UsageInfo usageInfo;
            if (!analysis.usageInfo.TryGetValue(ivd, out usageInfo))
            {
                return;
            }
            int useCount = usageInfo.NumberOfUses;

            if (useCount <= 1)
            {
                return;
            }

            VariableToChannelInformation vtci;
            bool firstTime   = !usesOfVariable.TryGetValue(ivd, out vtci);
            int  targetDepth = Recognizer.GetIndexingDepth(target);
            int  minDepth    = usageInfo.indexingDepths[0];
            int  usageDepth  = minDepth;

            if (!SwapIndices)
            {
                usageDepth = 0;
            }
            Containers defContainers = context.InputAttributes.Get <Containers>(ivd);
            int        ancIndex      = defContainers.GetMatchingAncestorIndex(context);
            Containers missing       = defContainers.GetContainersNotInContext(context, ancIndex);

            if (firstTime)
            {
                // declaration of uses array
                IList <IStatement> stmts = Builder.StmtCollection();
                vtci  = DeclareUsesArray(stmts, ivd, vi, useCount, usageDepth);
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            // check that extra literal indices in the target are zero.
            // for example, if iae is x[i][0] = (...) then it is safe to add x_uses[i] = Rep(x[i])
            // if iae is x[i][1] = (...) then it is safe to add x_uses[i][1] = Rep(x[i][1])
            // but not x_uses[i] = Rep(x[i]) since this will be a duplicate.
            bool extraLiteralsAreZero = CheckExtraLiteralsAreZero(target, targetDepth, usageDepth);

            if (extraLiteralsAreZero)
            {
                // definition of uses array
                IExpression       defExpr  = Builder.VarRefExpr(ivd);
                IExpression       usesExpr = Builder.VarRefExpr(vtci.usesDecl);
                List <IStatement> loops    = new List <IStatement>();
                if (usageDepth == targetDepth)
                {
                    defExpr = target;
                    if (defExpr is IVariableDeclarationExpression)
                    {
                        defExpr = Builder.VarRefExpr(ivd);
                    }
                    usesExpr = Builder.ReplaceVariable(defExpr, ivd, vtci.usesDecl);
                }
                else
                {
                    // loops over the last indexing bracket
                    for (int d = 0; d < usageDepth; d++)
                    {
                        List <IExpression> indices = new List <IExpression>();
                        for (int i = 0; i < vi.sizes[d].Length; i++)
                        {
                            IVariableDeclaration v    = vi.indexVars[d][i];
                            IStatement           loop = Builder.ForStmt(v, vi.sizes[d][i]);
                            loops.Add(loop);
                            indices.Add(Builder.VarRefExpr(v));
                        }
                        defExpr  = Builder.ArrayIndex(defExpr, indices);
                        usesExpr = Builder.ArrayIndex(usesExpr, indices);
                    }
                }

                if (rhs != null && rhs is IMethodInvokeExpression)
                {
                    IMethodInvokeExpression imie = (IMethodInvokeExpression)rhs;
                    bool copyPropagation         = false;
                    if (Recognizer.IsStaticGenericMethod(imie, new Func <PlaceHolder, PlaceHolder>(Factor.Copy)) && copyPropagation)
                    {
                        // if a variable is a copy, use the original expression since it will give more precise dependencies.
                        defExpr      = imie.Arguments[0];
                        shouldDelete = true;
                    }
                }

                // Add the statement:
                //   x_uses = Replicate(x, useCount)
                var genArgs = new Type[] { ivd.VariableType.DotNetType };
                IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
                    new Func <PlaceHolder, int, PlaceHolder[]>(Factor.Replicate),
                    genArgs, defExpr, Builder.LiteralExpr(useCount));
                bool isGateExitRandom = context.InputAttributes.Has <Algorithms.VariationalMessagePassing.GateExitRandomVariable>(ivd);
                if (isGateExitRandom)
                {
                    repMethod = Builder.StaticGenericMethod(
                        new Func <PlaceHolder, int, PlaceHolder[]>(Gate.ReplicateExiting),
                        genArgs, defExpr, Builder.LiteralExpr(useCount));
                }
                context.InputAttributes.CopyObjectAttributesTo <Algorithm>(ivd, context.OutputAttributes, repMethod);
                if (context.InputAttributes.Has <DivideMessages>(ivd))
                {
                    context.InputAttributes.CopyObjectAttributesTo <DivideMessages>(ivd, context.OutputAttributes, repMethod);
                }
                else if (useCount == 2)
                {
                    // division has no benefit for 2 uses, and degrades the schedule
                    context.OutputAttributes.Set(repMethod, new DivideMessages(false));
                }
                context.InputAttributes.CopyObjectAttributesTo <GivePriorityTo>(ivd, context.OutputAttributes, repMethod);
                IStatement repSt = Builder.AssignStmt(usesExpr, repMethod);
                if (usageDepth == targetDepth)
                {
                    if (isEvidenceVar)
                    {
                        // place the Replicate after the current statement, but outside of any evidence conditionals
                        ancIndex = context.FindAncestorIndex <IStatement>();
                        for (int i = ancIndex - 2; i > 0; i--)
                        {
                            object ancestor = context.GetAncestor(i);
                            if (ancestor is IConditionStatement)
                            {
                                IConditionStatement ics = (IConditionStatement)ancestor;
                                if (CodeRecognizer.IsStochastic(context, ics.Condition))
                                {
                                    ancIndex = i;
                                    break;
                                }
                            }
                        }
                        context.AddStatementAfterAncestorIndex(ancIndex, repSt);
                    }
                    else
                    {
                        context.AddStatementAfterCurrent(repSt);
                    }
                }
                else if (firstTime)
                {
                    // place Replicate after assignment but outside of definition loops (can't have any uses there)
                    repSt = Containers.WrapWithContainers(repSt, loops);
                    repSt = Containers.WrapWithContainers(repSt, missing.outputs);
                    context.AddStatementAfterAncestorIndex(ancIndex, repSt);
                }
            }
        }
コード例 #25
0
        /// <summary>
        /// Converts an array indexer expression on a stochastic variable into a reference to a
        /// derived "_item" variable, emitting the statements that declare and fill that variable.
        /// </summary>
        /// <remarks>
        /// <para>
        /// Expressions that have no entry in <c>analysis.indexInfoOf</c> are delegated to the base
        /// implementation. Expressions whose replacement was already built (<c>info.clone</c>) reuse it.
        /// </para>
        /// <para>
        /// When <c>UseGetItems</c> is set, the forms <c>array[indices[k]]</c>,
        /// <c>array[indices[k]][indices2[k]]</c> and <c>array[indices[k]][indices2[k]][indices3[k]]</c>
        /// are replaced by an element of a derived array that is filled by a single
        /// <c>Collection.GetItems*</c> call. Any other expression is replaced by a derived indexed
        /// variable that is linked to the original array element through a <c>Clone.Copy</c> statement.
        /// </para>
        /// <para>
        /// The generated statements are wrapped in whichever containers are missing from the current
        /// context and attached before (declarations, reads) or after (write-back copies) the
        /// best-matching ancestor.
        /// </para>
        /// </remarks>
        /// <param name="iaie">The array indexer expression being converted.</param>
        /// <returns>
        /// The expression that replaces <paramref name="iaie"/>, or the (possibly base-converted)
        /// indexer itself when no replacement is made or an error has been reported.
        /// </returns>
        protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
        {
            IndexAnalysisTransform.IndexInfo info;
            if (!analysis.indexInfoOf.TryGetValue(iaie, out info))
            {
                return base.ConvertArrayIndexer(iaie);
            }
            // Determine if this is a definition i.e. the variable is on the left hand side of an assignment
            // This must be done before base.ConvertArrayIndexer changes the expression!
            bool isDef = Recognizer.IsBeingMutated(context, iaie);

            if (info.clone != null)
            {
                // A replacement was already built for this expression; reuse it.
                // (A former "extra literal indices are zero" scan lived here, but its result was only
                // consumed by a branch guarded by the constant `false`, so it has been removed.)
                return info.clone;
            }

            if (isDef)
            {
                // do not clone the lhs of an array create assignment.
                // The null check guards against a mutation that is not nested inside an assignment.
                IAssignExpression assignExpr = context.FindAncestor<IAssignExpression>();
                if (assignExpr != null && assignExpr.Expression is IArrayCreateExpression)
                {
                    return iaie;
                }
            }

            IVariableDeclaration originalBaseVar = Recognizer.GetVariableDeclaration(iaie);

            // If the variable is not stochastic, return
            if (!CodeRecognizer.IsStochastic(context, originalBaseVar))
            {
                return iaie;
            }

            IExpression newExpr = null;
            IVariableDeclaration baseVar = originalBaseVar;
            IVariableDeclaration newvd = null;
            IExpression rhsExpr = null;
            Containers containers = info.containers;
            Type tp = iaie.GetExpressionType();

            if (tp == null)
            {
                Error("Could not determine type of expression: " + iaie);
                return iaie;
            }
            // Statements to insert before / after the matching ancestor statement.
            var stmts = Builder.StmtCollection();
            var stmtsAfter = Builder.StmtCollection();

            // does the expression have the form array[indices[k]][indices2[k]][indices3[k]]?
            if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression &&
                iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
            {
                IArrayIndexerExpression index3 = (IArrayIndexerExpression)iaie.Indices[0];
                IArrayIndexerExpression iaie2 = (IArrayIndexerExpression)iaie.Target;
                if (index3.Indices.Count == 1 && index3.Indices[0] is IVariableReferenceExpression &&
                    iaie2.Target is IArrayIndexerExpression &&
                    iaie2.Indices.Count == 1 && iaie2.Indices[0] is IArrayIndexerExpression)
                {
                    IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie2.Indices[0];
                    IArrayIndexerExpression iaie3 = (IArrayIndexerExpression)iaie2.Target;
                    if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
                        iaie3.Indices.Count == 1 && iaie3.Indices[0] is IArrayIndexerExpression)
                    {
                        IArrayIndexerExpression index = (IArrayIndexerExpression)iaie3.Indices[0];
                        // Validate the innermost index BEFORE casting: it must be a single variable
                        // reference. Otherwise this pattern does not apply and we fall through to the
                        // remaining cases instead of throwing InvalidCastException.
                        IVariableReferenceExpression innerIndex = (index.Indices.Count == 1)
                            ? index.Indices[0] as IVariableReferenceExpression
                            : null;
                        IForStatement innerLoop = (innerIndex == null)
                            ? null
                            : Recognizer.GetLoopForVariable(context, innerIndex);
                        if (innerIndex != null && index2.Indices[0].Equals(innerIndex) &&
                            index3.Indices[0].Equals(innerIndex) &&
                            innerLoop != null && AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target))
                        {
                            // expression has the form array[indices[k]][indices2[k]][indices3[k]]
                            if (isDef)
                            {
                                Error("fancy indexing not allowed on left hand side");
                                return iaie;
                            }
                            WarnIfLocal(index.Target, iaie3.Target, iaie);
                            WarnIfLocal(index2.Target, iaie3.Target, iaie);
                            WarnIfLocal(index3.Target, iaie3.Target, iaie);
                            containers = RemoveReferencesTo(containers, innerIndex);
                            IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
                            var indices = Recognizer.GetIndices(iaie);
                            // Build name of replacement variable from index values
                            StringBuilder sb = new StringBuilder("_item");
                            AppendIndexString(sb, iaie3);
                            AppendIndexString(sb, iaie2);
                            AppendIndexString(sb, iaie);
                            string name = ToString(iaie3.Target) + sb.ToString();
                            VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                            newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices);
                            if (!context.InputAttributes.Has<DerivedVariable>(newvd))
                            {
                                context.InputAttributes.Set(newvd, new DerivedVariable());
                            }
                            IExpression getItems = Builder.StaticGenericMethod(
                                new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged),
                                new Type[] { tp }, iaie3.Target, index.Target, index2.Target, index3.Target);
                            context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
                            stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                            newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
                            rhsExpr = getItems;
                        }
                    }
                }
            }
            // does the expression have the form array[indices[k]][indices2[k]]?
            if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression &&
                iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
            {
                IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie.Indices[0];
                IArrayIndexerExpression target = (IArrayIndexerExpression)iaie.Target;
                if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
                    target.Indices.Count == 1 && target.Indices[0] is IArrayIndexerExpression)
                {
                    IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index2.Indices[0];
                    IArrayIndexerExpression index = (IArrayIndexerExpression)target.Indices[0];
                    IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                    if (index.Indices.Count == 1 && index.Indices[0].Equals(innerIndex) &&
                        innerLoop != null && AreLoopsDisjoint(innerLoop, target.Target, index.Target))
                    {
                        // expression has the form array[indices[k]][indices2[k]]
                        if (isDef)
                        {
                            Error("fancy indexing not allowed on left hand side");
                            return iaie;
                        }
                        var innerLoops = new List<IForStatement>();
                        innerLoops.Add(innerLoop);
                        var indexTarget = index.Target;
                        var index2Target = index2.Target;
                        // check if the index array is jagged, i.e. array[indices[k][j]]
                        // Both index arrays are peeled in lock-step and must use the same loop variable at each level.
                        while (indexTarget is IArrayIndexerExpression && index2Target is IArrayIndexerExpression)
                        {
                            IArrayIndexerExpression indexTargetExpr = (IArrayIndexerExpression)indexTarget;
                            IArrayIndexerExpression index2TargetExpr = (IArrayIndexerExpression)index2Target;
                            if (indexTargetExpr.Indices.Count == 1 && indexTargetExpr.Indices[0] is IVariableReferenceExpression &&
                                index2TargetExpr.Indices.Count == 1 && index2TargetExpr.Indices[0] is IVariableReferenceExpression)
                            {
                                IVariableReferenceExpression innerIndexTarget = (IVariableReferenceExpression)indexTargetExpr.Indices[0];
                                IVariableReferenceExpression innerIndex2Target = (IVariableReferenceExpression)index2TargetExpr.Indices[0];
                                IForStatement indexTargetLoop = Recognizer.GetLoopForVariable(context, innerIndexTarget);
                                if (indexTargetLoop != null && AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) &&
                                    innerIndexTarget.Equals(innerIndex2Target))
                                {
                                    innerLoops.Add(indexTargetLoop);
                                    indexTarget = indexTargetExpr.Target;
                                    index2Target = index2TargetExpr.Target;
                                }
                                else
                                {
                                    break;
                                }
                            }
                            else
                            {
                                break;
                            }
                        }
                        WarnIfLocal(indexTarget, target.Target, iaie);
                        WarnIfLocal(index2Target, target.Target, iaie);
                        // Loops were collected innermost-first; reverse to get outermost-first order.
                        innerLoops.Reverse();
                        var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                        var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                        // Build name of replacement variable from index values
                        StringBuilder sb = new StringBuilder("_item");
                        AppendIndexString(sb, target);
                        AppendIndexString(sb, iaie);
                        string name = ToString(target.Target) + sb.ToString();
                        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                        var indices = Recognizer.GetIndices(iaie);
                        newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                        if (!context.InputAttributes.Has<DerivedVariable>(newvd))
                        {
                            context.InputAttributes.Set(newvd, new DerivedVariable());
                        }
                        IExpression getItems;
                        if (innerLoops.Count == 1)
                        {
                            getItems = Builder.StaticGenericMethod(
                                new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromJagged),
                                new Type[] { tp }, target.Target, indexTarget, index2Target);
                        }
                        else if (innerLoops.Count == 2)
                        {
                            getItems = Builder.StaticGenericMethod(
                                new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<IReadOnlyList<int>>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged),
                                new Type[] { tp }, target.Target, indexTarget, index2Target);
                        }
                        else
                        {
                            throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                        }
                        context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
                        stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                        var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                        newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                        rhsExpr = getItems;
                    }
                    else if (HasAnyCommonLoops(index, index2))
                    {
                        Warning($"This model will consume excess memory due to the indexing expression {iaie} since {index} and {index2} have larger depth than the compiler can handle.");
                    }
                }
            }
            if (newvd == null)
            {
                IArrayIndexerExpression originalExpr = iaie;
                if (UseGetItems)
                {
                    iaie = (IArrayIndexerExpression)base.ConvertArrayIndexer(iaie);
                }
                if (!object.ReferenceEquals(iaie.Target, originalExpr.Target) && false)
                {
                    // TODO: determine if this warning is useful or not
                    string warningText = "This model may consume excess memory due to the jagged indexing expression {0}";
                    Warning(string.Format(warningText, originalExpr));
                }

                // get the baseVar of the new expression.
                baseVar = Recognizer.GetVariableDeclaration(iaie);
                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);

                var indices = Recognizer.GetIndices(iaie);
                // Build name of replacement variable from index values
                StringBuilder sb = new StringBuilder("_item");
                AppendIndexString(sb, iaie);
                string name = ToString(iaie.Target) + sb.ToString();

                // does the expression have the form array[indices[k]]?
                if (UseGetItems && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
                {
                    IArrayIndexerExpression index = (IArrayIndexerExpression)iaie.Indices[0];
                    if (index.Indices.Count == 1 && index.Indices[0] is IVariableReferenceExpression)
                    {
                        // expression has the form array[indices[k]]
                        IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
                        IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                        if (innerLoop != null && AreLoopsDisjoint(innerLoop, iaie.Target, index.Target))
                        {
                            if (isDef)
                            {
                                Error("fancy indexing not allowed on left hand side");
                                return iaie;
                            }
                            var innerLoops = new List<IForStatement>();
                            innerLoops.Add(innerLoop);
                            var indexTarget = index.Target;
                            // check if the index array is jagged, i.e. array[indices[k][j]]
                            while (indexTarget is IArrayIndexerExpression)
                            {
                                IArrayIndexerExpression index2 = (IArrayIndexerExpression)indexTarget;
                                if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression)
                                {
                                    IVariableReferenceExpression innerIndex2 = (IVariableReferenceExpression)index2.Indices[0];
                                    IForStatement innerLoop2 = Recognizer.GetLoopForVariable(context, innerIndex2);
                                    if (innerLoop2 != null && AreLoopsDisjoint(innerLoop2, iaie.Target, index2.Target))
                                    {
                                        innerLoops.Add(innerLoop2);
                                        indexTarget = index2.Target;
                                        // This limit must match the number of handled cases below.
                                        if (innerLoops.Count == 3)
                                        {
                                            break;
                                        }
                                    }
                                    else
                                    {
                                        break;
                                    }
                                }
                                else
                                {
                                    break;
                                }
                            }
                            WarnIfLocal(indexTarget, iaie.Target, originalExpr);
                            // Loops were collected innermost-first; reverse to get outermost-first order.
                            innerLoops.Reverse();
                            var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                            var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                            newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                            if (!context.InputAttributes.Has<DerivedVariable>(newvd))
                            {
                                context.InputAttributes.Set(newvd, new DerivedVariable());
                            }
                            IExpression getItems;
                            if (innerLoops.Count == 1)
                            {
                                getItems = Builder.StaticGenericMethod(
                                    new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems),
                                    new Type[] { tp }, iaie.Target, indexTarget);
                            }
                            else if (innerLoops.Count == 2)
                            {
                                getItems = Builder.StaticGenericMethod(
                                    new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems),
                                    new Type[] { tp }, iaie.Target, indexTarget);
                            }
                            else if (innerLoops.Count == 3)
                            {
                                getItems = Builder.StaticGenericMethod(
                                    new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems),
                                    new Type[] { tp }, iaie.Target, indexTarget);
                            }
                            else
                            {
                                throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                            }
                            context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
                            stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                            var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                            newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                            rhsExpr = getItems;
                        }
                    }
                }
                if (newvd == null)
                {
                    // No GetItems pattern matched: fall back to a plain derived indexed variable.
                    if (UseGetItems && info.count < 2)
                    {
                        return iaie;
                    }
                    try
                    {
                        newvd = varInfo.DeriveIndexedVariable(stmts, context, name, indices, copyInitializer: isDef);
                    }
                    catch (Exception ex)
                    {
                        Error(ex.Message, ex);
                        return iaie;
                    }
                    context.OutputAttributes.Remove<DerivedVariable>(newvd);
                    newExpr = Builder.VarRefExpr(newvd);
                    rhsExpr = iaie;
                    if (isDef)
                    {
                        // change:
                        //   array[0] = f()
                        // into:
                        //   array_item0 = f()
                        //   array[0] = Copy(array_item0)
                        IExpression copy = Builder.StaticGenericMethod(new Func<PlaceHolder, PlaceHolder>(Clone.Copy), new Type[] { tp }, newExpr);
                        IStatement copySt = Builder.AssignStmt(iaie, copy);
                        stmtsAfter.Add(copySt);
                        if (!context.InputAttributes.Has<DerivedVariable>(baseVar))
                        {
                            context.InputAttributes.Set(baseVar, new DerivedVariable());
                        }
                    }
                    else if (!info.IsAssignedTo)
                    {
                        // change:
                        //   x = f(array[0])
                        // into:
                        //   array_item0 = Copy(array[0])
                        //   x = f(array_item0)
                        IExpression copy = Builder.StaticGenericMethod(new Func<PlaceHolder, PlaceHolder>(Clone.Copy), new Type[] { tp }, iaie);
                        IStatement copySt = Builder.AssignStmt(Builder.VarRefExpr(newvd), copy);
                        stmts.Add(copySt);
                        context.InputAttributes.Set(newvd, new DerivedVariable());
                    }
                }
            }

            // Reduce memory consumption by declaring the clone outside of unnecessary loops.
            // This way, the item is cloned outside the loop and then replicated, instead of replicating the entire array and cloning the item.
            containers = Containers.RemoveUnusedLoops(containers, context, rhsExpr);
            if (context.InputAttributes.Has<DoNotSendEvidence>(originalBaseVar))
            {
                containers = Containers.RemoveStochasticConditionals(containers, context);
            }
            // Keep only the bindings whose expression appears in the remaining containers, and add
            // the resulting binding container (if any).
            IStatement st = GetBindingSetContainer(FilterBindingSet(info.bindings,
                binding => Containers.ContainsExpression(containers.inputs, context, binding.GetExpression())));
            if (st != null)
            {
                containers.Add(st);
            }
            // To put the declaration in the desired containers, we find an ancestor which includes as many of the containers as possible,
            // then wrap the declaration with the remaining containers.
            int ancIndex = containers.GetMatchingAncestorIndex(context);
            Containers missing = containers.GetContainersNotInContext(context, ancIndex);

            stmts = Containers.WrapWithContainers(stmts, missing.outputs);
            context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            stmtsAfter = Containers.WrapWithContainers(stmtsAfter, missing.outputs);
            context.AddStatementsAfterAncestorIndex(ancIndex, stmtsAfter);
            context.InputAttributes.Set(newvd, containers);
            // Remember the replacement so later visits of the same expression reuse it.
            info.clone = newExpr;
            return newExpr;
        }
コード例 #26
0
        /// <summary>
        /// Rewrites a reference to a stochastic variable so that it is replicated (via <c>Clone.Replicate</c>)
        /// across every loop of its reference loop context whose loop variable is not found to be constant
        /// with respect to the expression.
        /// </summary>
        /// <param name="expr">The expression referring to the variable, possibly indexed.</param>
        /// <returns>
        /// <paramref name="expr"/> unchanged when it has no variable declaration, is not stochastic, has no loop context,
        /// or is referenced in the same loop context as its declaration. Otherwise the expression after the replication
        /// loop: each replication replaces it with the new replicated array indexed by the loop variable,
        /// followed by any indices that were not replicated.
        /// </returns>
        protected IExpression ConvertWithReplication(IExpression expr)
        {
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
            // Check if this is an index local variable
            if (baseVar == null) return expr;
            // Check if the variable is stochastic
            if (!CodeRecognizer.IsStochastic(context, baseVar)) return expr;

            // Get the loop context for this variable
            LoopContext lc = context.InputAttributes.Get<LoopContext>(baseVar);
            if (lc == null)
            {
                Error("Loop context not found for '" + baseVar.Name + "'.");
                return expr;
            }

            // Get the reference loop context for this expression
            RefLoopContext rlc = lc.GetReferenceLoopContext(context);
            // If the reference is in the same loop context as the declaration, do nothing.
            if (rlc.loops.Count == 0) return expr;

            // the set of loop variables that are constant wrt the expr
            Set<IVariableDeclaration> constantLoopVars = new Set<IVariableDeclaration>();
            constantLoopVars.AddRange(lc.loopVariables);

            // collect set of all loop variable indices in the expression
            Set<int> embeddedLoopIndices = new Set<int>();
            List<IList<IExpression>> brackets = Recognizer.GetIndices(expr);
            foreach (IList<IExpression> bracket in brackets)
            {
                foreach (IExpression index in bracket)
                {
                    IExpression indExpr = index;
                    // For a binary index expression only the left operand is inspected.
                    if (indExpr is IBinaryExpression ibe)
                    {
                        indExpr = ibe.Left;
                    }
                    IVariableDeclaration indVar = Recognizer.GetVariableDeclaration(indExpr);
                    if (indVar != null)
                    {
                        if (!constantLoopVars.Contains(indVar))
                        {
                            int loopIndex = rlc.loopVariables.IndexOf(indVar);
                            if (loopIndex != -1)
                            {
                                // indVar is a loop variable
                                constantLoopVars.Add(rlc.loopVariables[loopIndex]);
                            }
                            else
                            {
                                // indVar is not a loop variable
                                AddEmbeddedLoopIndices(indVar, rlc, constantLoopVars, embeddedLoopIndices, expr);
                            }
                        }
                    }
                    else
                    {
                        // The index is not a plain variable reference: examine every variable it mentions.
                        foreach (var ivd in Recognizer.GetVariables(indExpr))
                        {
                            if (!constantLoopVars.Contains(ivd))
                            {
                                AddEmbeddedLoopIndices(ivd, rlc, constantLoopVars, embeddedLoopIndices, expr);
                            }
                        }
                    }
                }
            }

            // Find loop variables that must be constant due to condition statements.
            List<IStatement> ancestors = context.FindAncestors<IStatement>();
            foreach (IStatement ancestor in ancestors)
            {
                if (!(ancestor is IConditionStatement ics))
                    continue;
                ConditionBinding binding = new ConditionBinding(ics.Condition);
                IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(binding.lhs);
                IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(binding.rhs);
                // A loop variable bound by the condition to something constant wrt the loops is itself constant.
                int index = rlc.loopVariables.IndexOf(ivd);
                if (index >= 0 && IsConstantWrtLoops(ivd2, constantLoopVars))
                {
                    constantLoopVars.Add(ivd);
                    continue;
                }
                int index2 = rlc.loopVariables.IndexOf(ivd2);
                if (index2 >= 0 && IsConstantWrtLoops(ivd, constantLoopVars))
                {
                    constantLoopVars.Add(ivd2);
                    continue;
                }
            }

            // Determine if this expression is being defined (is on the LHS of an assignment)
            bool isDef = Recognizer.IsBeingMutated(context, expr);

            // NOTE(review): this initial value is never read; 'containers' is reassigned inside the loop before its first use.
            Containers containers = context.InputAttributes.Get<Containers>(baseVar);

            // Kept for the warning message, since 'expr' is rewritten on every replication.
            IExpression originalExpr = expr;

            for (int currentLoop = 0; currentLoop < rlc.loopVariables.Count; currentLoop++)
            {
                IVariableDeclaration loopVar = rlc.loopVariables[currentLoop];
                if (constantLoopVars.Contains(loopVar))
                    continue;
                IForStatement loop = rlc.loops[currentLoop];
                // must replicate across this loop.
                if (isDef)
                {
                    Error("Cannot re-define a variable in a loop.  Variables on the left hand side of an assignment must be indexed by all containing loops.");
                    continue;
                }
                if (embeddedLoopIndices.Contains(currentLoop))
                {
                    string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
                    Warning(string.Format(warningText, originalExpr, loopVar.Name));
                }
                // split expr into a target and extra indices, where target will be replicated and extra indices will be added later
                var extraIndices = new List<IEnumerable<IExpression>>();
                AddUnreplicatedIndices(rlc.loops[currentLoop], expr, extraIndices, out IExpression exprToReplicate);

                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                IExpression loopSize = Recognizer.LoopSizeExpression(loop);
                IList<IStatement> stmts = Builder.StmtCollection();
                List<IList<IExpression>> inds = Recognizer.GetIndices(exprToReplicate);
                IVariableDeclaration newIndexVar = loopVar;
                // if loopVar is already an indexVar of varInfo, create a new variable
                if (varInfo.HasIndexVar(loopVar))
                {
                    newIndexVar = VariableInformation.GenerateLoopVar(context, "_a");
                    context.InputAttributes.CopyObjectAttributesTo(loopVar, context.OutputAttributes, newIndexVar);
                }
                IVariableDeclaration repVar = varInfo.DeriveArrayVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rep"),
                                                                          loopSize, newIndexVar, inds, useArrays: true);
                if (!context.InputAttributes.Has<DerivedVariable>(repVar))
                    context.OutputAttributes.Set(repVar, new DerivedVariable());
                if (context.InputAttributes.Has<ChannelInfo>(baseVar))
                {
                    VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
                    ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
                    ci.decl = repVar;
                    context.OutputAttributes.Set(repVar, ci);
                }

                // Create replicate factor
                Type returnType = Builder.ToType(repVar.VariableType);
                IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
                    new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate),
                    new Type[] {returnType.GetElementType()}, exprToReplicate, loopSize);

                IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), repMethod);
                // Copy attributes across from variable to replication expression
                context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, repMethod);
                stmts.Add(Builder.ExprStatement(assignExpression));

                // add any containers missing from context.
                containers = new Containers(context);
                // RemoveUnusedLoops will also remove conditionals involving those loop variables.
                // TODO: investigate whether removing these conditionals could cause a problem, e.g. when the condition is a conjunction of many terms.
                containers = Containers.RemoveUnusedLoops(containers, context, repMethod);
                if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar)) containers = Containers.RemoveStochasticConditionals(containers, context);
                //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
                //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
                if (containers.Contains(loop))
                {
                    Error("Internal: invalid containers for replicating " + baseVar);
                    break;
                }
                int ancIndex = containers.GetMatchingAncestorIndex(context);
                Containers missing = containers.GetContainersNotInContext(context, ancIndex);
                stmts = Containers.WrapWithContainers(stmts, missing.inputs);
                context.OutputAttributes.Set(repVar, containers);
                // The replicated variable's loop context is the loops above the insertion point plus the missing loops.
                List<IForStatement> loops = context.FindAncestors<IForStatement>(ancIndex);
                foreach (IStatement container in missing.inputs)
                {
                    if (container is IForStatement ifs) loops.Add(ifs);
                }
                context.OutputAttributes.Set(repVar, new LoopContext(loops));
                // must convert the output since it may contain 'if' conditions
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
                // Subsequent loops replicate the variable that was just created.
                baseVar = repVar;
                expr = Builder.ArrayIndex(Builder.VarRefExpr(repVar), Builder.VarRefExpr(loopVar));
                expr = Builder.JaggedArrayIndex(expr, extraIndices);
            }

            return expr;
        }

        /// <summary>
        /// For a variable used inside an index expression, records the positions (within <paramref name="rlc"/>)
        /// of the loops in that variable's own loop context, skipping loop variables already known to be constant.
        /// Reports an error if the variable has no loop context, or if one of its loop variables is not in <paramref name="rlc"/>.
        /// </summary>
        /// <param name="indexVar">A variable appearing in an index of <paramref name="expr"/>.</param>
        /// <param name="rlc">The reference loop context of the expression being converted.</param>
        /// <param name="constantLoopVars">Loop variables already known to be constant with respect to the expression.</param>
        /// <param name="embeddedLoopIndices">Receives the positions of the loops found.</param>
        /// <param name="expr">The expression being converted; used only in the error message.</param>
        private void AddEmbeddedLoopIndices(
            IVariableDeclaration indexVar,
            RefLoopContext rlc,
            Set<IVariableDeclaration> constantLoopVars,
            Set<int> embeddedLoopIndices,
            IExpression expr)
        {
            LoopContext indexLoopContext = context.InputAttributes.Get<LoopContext>(indexVar);
            if (indexLoopContext == null)
            {
                // Report a transform error rather than dereferencing null, as is done for the base variable.
                Error("Loop context not found for '" + indexVar.Name + "'.");
                return;
            }
            foreach (var ivd in indexLoopContext.loopVariables)
            {
                if (constantLoopVars.Contains(ivd))
                {
                    continue;
                }
                int loopIndex = rlc.loopVariables.IndexOf(ivd);
                if (loopIndex != -1)
                {
                    embeddedLoopIndices.Add(loopIndex);
                }
                else
                {
                    Error($"Index {ivd} is not in {rlc} for expression {expr}");
                }
            }
        }
コード例 #27
0
        /// <summary>
        /// Processes an assignment to a variable. For a stochastic target it attaches a definition <c>ChannelInfo</c>
        /// and creates the marginal and use channels the first time the variable is assigned, then adds, after the
        /// current statement, an assignment of the variable factor to the use expression.
        /// </summary>
        /// <remarks>
        /// Returns without adding a variable factor when the target has no variable declaration, when the right-hand side
        /// is an array creation without initializer whose dimensions are not all literal zero, when the variable is not
        /// stochastic, when it is listed in <c>analysis.variablesExcludingVariableFactor</c>, or when it is derived and
        /// neither inferred nor a point estimate.
        /// </remarks>
        /// <param name="target">Left-hand side of the assignment.</param>
        /// <param name="rhs">Right-hand side of the assignment.</param>
        /// <param name="shouldDelete">
        /// Set to true when <paramref name="rhs"/> is a <c>Clone.Copy</c> call whose argument is used in place of the
        /// definition expression; otherwise left as passed in.
        /// </param>
        protected void ProcessAssign(IExpression target, IExpression rhs, ref bool shouldDelete)
        {
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);

            if (ivd == null)
            {
                return;
            }
            if (rhs is IArrayCreateExpression iace)
            {
                bool zeroLength = iace.Dimensions.All(dimExpr =>
                                                      (dimExpr is ILiteralExpression) && ((ILiteralExpression)dimExpr).Value.Equals(0));
                if (!zeroLength && iace.Initializer == null)
                {
                    return; // variable will have assignments to elements
                }
            }
            bool firstTime = !variablesAssigned.Contains(ivd);

            variablesAssigned.Add(ivd);
            bool isInferred = context.InputAttributes.Has<IsInferred>(ivd);

            // Everything below applies to stochastic variables only.
            if (!CodeRecognizer.IsStochastic(context, ivd))
            {
                return;
            }
            VariableInformation vi            = VariableInformation.GetVariableInformation(context, ivd);
            Containers          defContainers = context.InputAttributes.Get<Containers>(ivd);
            int        ancIndex = defContainers.GetMatchingAncestorIndex(context);
            Containers missing  = defContainers.GetContainersNotInContext(context, ancIndex);
            // definition of a stochastic variable
            IExpression lhs = target;

            if (lhs is IVariableDeclarationExpression)
            {
                lhs = Builder.VarRefExpr(ivd);
            }
            IExpression defExpr = lhs;

            if (firstTime)
            {
                // Create a ChannelInfo attribute for use by later transforms, e.g. MessageTransform
                ChannelInfo defChannel = ChannelInfo.DefChannel(vi);
                defChannel.decl = ivd;
                context.OutputAttributes.Set(ivd, defChannel);
            }
            bool       isDerived = context.InputAttributes.Has<DerivedVariable>(ivd);
            // The default algorithm is used unless the variable carries its own Algorithm attribute.
            IAlgorithm algorithm = this.algorithmDefault;
            Algorithm  algAttr   = context.InputAttributes.Get<Algorithm>(ivd);

            if (algAttr != null)
            {
                algorithm = algAttr.algorithm;
            }
            if (algorithm is VariationalMessagePassing vmp && vmp.UseDerivMessages && isDerived && firstTime)
            {
                // Declare a "_deriv" variable with its own definition channel and attach it to ivd as a DerivMessage.
                vi.DefineAllIndexVars(context);
                IList<IStatement>    stmts     = Builder.StmtCollection();
                IVariableDeclaration derivDecl = vi.DeriveIndexedVariable(stmts, context, ivd.Name + "_deriv");
                context.OutputAttributes.Set(ivd, new DerivMessage(derivDecl));
                ChannelInfo derivChannel = ChannelInfo.DefChannel(vi);
                derivChannel.decl = derivDecl;
                context.OutputAttributes.Set(derivChannel.decl, derivChannel);
                context.OutputAttributes.Set(derivChannel.decl, new DescriptionAttribute("deriv of '" + ivd.Name + "'"));
                // Add the declarations
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            bool isPointEstimate = context.InputAttributes.Has<PointEstimate>(ivd);

            if (this.analysis.variablesExcludingVariableFactor.Contains(ivd))
            {
                this.variablesLackingVariableFactor.Add(ivd);
                // ivd will get a marginal channel in ConvertMethodInvoke
                useOfVariable[ivd] = ivd;
                return;
            }
            if (isDerived && !isInferred && !isPointEstimate)
            {
                return;
            }

            // Only ever assigned by the disabled 'makeClone' path below, so it is null in practice.
            IExpression useExpr2 = null;

            if (firstTime)
            {
                // create marginal and use channels
                vi.DefineAllIndexVars(context);
                IList<IStatement> stmts = Builder.StmtCollection();

                CreateMarginalChannel(ivd, vi, stmts);
                CreateUseChannel(ivd, vi, stmts);
                context.InputAttributes.Set(useOfVariable[ivd], defContainers);

                // Add the declarations
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            if (!useOfVariable.ContainsKey(ivd))
            {
                Error("cannot find use channel of " + ivd);
                return;
            }
            IExpression  marginalExpr = Builder.ReplaceVariable(lhs, ivd, marginalOfVariable[ivd]);
            IExpression  useExpr      = Builder.ReplaceVariable(lhs, ivd, useOfVariable[ivd]);
            InitialiseTo it           = context.InputAttributes.Get<InitialiseTo>(ivd);

            Type[] genArgs = new Type[] { defExpr.GetExpressionType() };
            if (rhs is IMethodInvokeExpression imie)
            {
                if (Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Clone.Copy)) && ancIndex < context.InputStack.Count - 2)
                {
                    IExpression          arg  = imie.Arguments[0];
                    IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(arg);
                    // Reference comparison: both variables must carry the same MarginalPrototype attribute object (or both none).
                    if (ivd2 != null && context.InputAttributes.Get<MarginalPrototype>(ivd) == context.InputAttributes.Get<MarginalPrototype>(ivd2))
                    {
                        // if a variable is a copy, use the original expression since it will give more precise dependencies.
                        defExpr      = arg;
                        shouldDelete = true;
                        // Disabled path: makeClone is hard-coded to false, so the block below never runs.
                        bool makeClone = false;
                        if (makeClone)
                        {
                            VariableInformation       vi2      = VariableInformation.GetVariableInformation(context, ivd2);
                            IList<IStatement>         stmts    = Builder.StmtCollection();
                            List<IList<IExpression>>  indices  = Recognizer.GetIndices(defExpr);
                            IVariableDeclaration      useDecl2 = vi2.DeriveIndexedVariable(stmts, context, ivd2.Name + "_use", indices);
                            useExpr2 = Builder.VarRefExpr(useDecl2);
                            Containers defContainers2 = context.InputAttributes.Get<Containers>(ivd2);
                            int        ancIndex2      = defContainers2.GetMatchingAncestorIndex(context);
                            Containers missing2       = defContainers2.GetContainersNotInContext(context, ancIndex2);
                            stmts = Containers.WrapWithContainers(stmts, missing2.outputs);
                            context.AddStatementsBeforeAncestorIndex(ancIndex2, stmts);
                            context.InputAttributes.Set(useDecl2, defContainers2);

                            // TODO: call CreateUseChannel
                            ChannelInfo usageChannel = ChannelInfo.UseChannel(vi2);
                            usageChannel.decl = useDecl2;
                            context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(vi.declaration, context.OutputAttributes, useDecl2);
                            context.InputAttributes.CopyObjectAttributesTo<DerivMessage>(vi.declaration, context.OutputAttributes, useDecl2);
                            context.OutputAttributes.Set(useDecl2, usageChannel);
                            //context.OutputAttributes.Set(useDecl2, new DescriptionAttribute("use of '" + ivd.Name + "'"));
                            context.OutputAttributes.Remove<InitialiseTo>(vi.declaration);

                            IExpression copyExpr = Builder.StaticGenericMethod(
                                new Func<PlaceHolder, PlaceHolder>(Clone.Copy), genArgs, useExpr2);
                            var copyStmt = Builder.AssignStmt(useExpr, copyExpr);
                            context.AddStatementAfterCurrent(copyStmt);
                        }
                    }
                }
            }

            // Add the variable factor
            IExpression variableFactorExpr;
            bool        isGateExitRandom = context.InputAttributes.Has<VariationalMessagePassing.GateExitRandomVariable>(ivd);

            if (isGateExitRandom)
            {
                variableFactorExpr = Builder.StaticGenericMethod(
                    new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Gate.ExitingVariable),
                    genArgs, defExpr, marginalExpr);
            }
            else
            {
                Delegate d = algorithm.GetVariableFactor(isDerived, it != null);
                // A point estimate overrides the algorithm's variable factor.
                if (isPointEstimate)
                {
                    d = new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint);
                }
                if (it == null)
                {
                    variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, marginalExpr);
                }
                else
                {
                    // With an InitialiseTo attribute, the initial-messages expression is passed as an extra argument.
                    IExpression initExpr = Builder.ReplaceExpression(lhs, Builder.VarRefExpr(ivd), it.initialMessagesExpression);
                    variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, initExpr, marginalExpr);
                }
            }
            context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(ivd, context.OutputAttributes, variableFactorExpr);
            context.InputAttributes.CopyObjectAttributesTo<Algorithm>(ivd, context.OutputAttributes, variableFactorExpr);
            context.OutputAttributes.Set(variableFactorExpr, new IsVariableFactor());
            var assignStmt = Builder.AssignStmt(useExpr2 == null ? useExpr : useExpr2, variableFactorExpr);

            context.AddStatementAfterCurrent(assignStmt);
        }
コード例 #28
0
        /// <summary>
        /// Updates the <c>depthInfos</c> entry of the stochastic variable referenced by <paramref name="expr"/>
        /// with the indexing depth of this reference.
        /// </summary>
        /// <param name="expr">The (possibly indexed) variable reference.</param>
        /// <param name="isDef">
        /// True when the reference is a definition: only the definition depth is stored.
        /// False for a use: the use count, minimum depth and per-depth index information are updated.
        /// </param>
        private void RegisterDepth(IExpression expr, bool isDef)
        {
            // Nothing is registered while the current expression is itself being indexed.
            if (Recognizer.IsBeingIndexed(context))
            {
                return;
            }

            int indexingDepth = Recognizer.GetIndexingDepth(expr);
            IVariableDeclaration variable = Recognizer.GetVariableDeclaration(expr);

            // Skip anything that is not a variable reference (e.g. an indexed argument reference)
            // and any variable that is not stochastic.
            if (variable == null || !CodeRecognizer.IsStochastic(context, variable))
            {
                return;
            }

            // Marginal channels are not tracked.
            ChannelInfo channel = context.InputAttributes.Get<ChannelInfo>(variable);
            if (channel?.IsMarginal == true)
            {
                return;
            }

            // Fetch the record for this variable, creating it on first sight.
            if (!depthInfos.TryGetValue(variable, out DepthInfo record))
            {
                record = new DepthInfo();
                depthInfos[variable] = record;
            }

            if (isDef)
            {
                record.definitionDepth = indexingDepth;
                return;
            }

            record.useCount++;
            if (indexingDepth < record.minDepth)
            {
                record.minDepth = indexingDepth;
            }

            // Number of leading brackets whose indices are all literals.
            int literalPrefixLength = Recognizer.GetIndices(expr)
                .TakeWhile(bracket => bracket.All(index => index is ILiteralExpression))
                .Count();

            Containers currentContainers = new Containers(context);
            if (record.indexInfoOfDepth.TryGetValue(indexingDepth, out IndexInfo existing))
            {
                // Seen at this depth before: keep only the shared containers and the smaller literal depth.
                existing.containers = Containers.Intersect(existing.containers, currentContainers);
                existing.literalIndexingDepth = System.Math.Min(existing.literalIndexingDepth, literalPrefixLength);
            }
            else
            {
                record.indexInfoOfDepth[indexingDepth] = new IndexInfo
                {
                    containers = currentContainers,
                    literalIndexingDepth = literalPrefixLength
                };
            }
        }
コード例 #29
0
 /// <summary>
 /// Inspects an assignment whose right-hand side is a recognised copying call (<c>Clone.Copy</c>,
 /// <c>ArrayHelper.SetTo</c>, <c>ArrayHelper.SetAllElementsTo</c>, or one of the <c>Get*ItemsPoint</c>
 /// <c>ItemsAverageConditional</c> operators) and records the copy relationship in the
 /// <c>CopyOfAttribute</c> of the assigned variable.
 /// </summary>
 /// <remarks>
 /// Nothing is recorded when an ancestor statement has an <c>Initializer</c> attribute, or when the assigned
 /// variable has an <c>InitialiseTo</c> / non-array-creation <c>InitialiseBackwardTo</c> initialiser that the
 /// source variable does not share.
 /// </remarks>
 /// <param name="iae">The assignment expression.</param>
 /// <returns>Always <paramref name="iae"/>, unchanged.</returns>
 protected override IExpression ConvertAssign(IAssignExpression iae)
 {
     foreach (IStatement stmt in context.FindAncestors <IStatement>())
     {
         // an initializer statement may perform a copy, but it is not valid to replace the lhs
         // in that case.
         if (context.InputAttributes.Has <Initializer>(stmt))
         {
             return(iae);
         }
     }
     // Look for assignments where the right hand side is a SetTo call
     if (iae.Expression is IMethodInvokeExpression imie)
     {
         bool isCopy                          = Recognizer.IsStaticGenericMethod(imie, new Func <PlaceHolder, PlaceHolder>(Clone.Copy));
         bool isSetTo                         = Recognizer.IsStaticGenericMethod(imie, typeof(ArrayHelper), "SetTo");
         bool isSetAllElementsTo              = Recognizer.IsStaticGenericMethod(imie, typeof(ArrayHelper), "SetAllElementsTo");
         bool isGetItemsPoint                 = Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsPointOp <>), "ItemsAverageConditional");
         bool isGetJaggedItemsPoint           = Recognizer.IsStaticGenericMethod(imie, typeof(GetJaggedItemsPointOp <>), "ItemsAverageConditional");
         bool isGetDeepJaggedItemsPoint       = Recognizer.IsStaticGenericMethod(imie, typeof(GetDeepJaggedItemsPointOp <>), "ItemsAverageConditional");
         bool isGetItemsFromJaggedPoint       = Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsFromJaggedPointOp <>), "ItemsAverageConditional");
         bool isGetItemsFromDeepJaggedPoint   = Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsFromDeepJaggedPointOp <>), "ItemsAverageConditional");
         bool isGetJaggedItemsFromJaggedPoint = Recognizer.IsStaticGenericMethod(imie, typeof(GetJaggedItemsFromJaggedPointOp <>), "ItemsAverageConditional");
         if (isCopy || isSetTo || isSetAllElementsTo || isGetItemsPoint ||
             isGetJaggedItemsPoint || isGetDeepJaggedItemsPoint || isGetJaggedItemsFromJaggedPoint ||
             isGetItemsFromJaggedPoint || isGetItemsFromDeepJaggedPoint)
         {
             IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(iae.Target);
             // Find the condition context: the enclosing conditions that are not stochastic.
             var ifs         = context.FindAncestors <IConditionStatement>();
             var condContext = new List <IConditionStatement>();
             foreach (var ifSt in ifs)
             {
                 if (!CodeRecognizer.IsStochastic(context, ifSt.Condition))
                 {
                     condContext.Add(ifSt);
                 }
             }
             var         copyAttr = context.InputAttributes.GetOrCreate <CopyOfAttribute>(ivd, () => new CopyOfAttribute());
             IExpression rhs;
             if (isSetTo || isSetAllElementsTo)
             {
                 // Mark as copy of the second argument
                 rhs = imie.Arguments[1];
             }
             else
             {
                 // Mark as copy of the first argument
                 rhs = imie.Arguments[0];
             }
             InitialiseTo init = context.InputAttributes.Get <InitialiseTo>(ivd);
             if (init != null)
             {
                 IVariableDeclaration ivdRhs  = Recognizer.GetVariableDeclaration(rhs);
                 InitialiseTo         initRhs = (ivdRhs == null) ? null : context.InputAttributes.Get <InitialiseTo>(ivdRhs);
                 if (initRhs == null || !initRhs.initialMessagesExpression.Equals(init.initialMessagesExpression))
                 {
                     // Do not replace a variable with a unique initialiser
                     return(iae);
                 }
             }
             var initBack = context.InputAttributes.Get <InitialiseBackwardTo>(ivd);
             if (initBack != null && !(initBack.initialMessagesExpression is IArrayCreateExpression))
             {
                 IVariableDeclaration ivdRhs  = Recognizer.GetVariableDeclaration(rhs);
                 InitialiseBackwardTo initRhs = (ivdRhs == null) ? null : context.InputAttributes.Get <InitialiseBackwardTo>(ivdRhs);
                 // Compare against the backward initialiser of the target. 'init' (the forward initialiser)
                 // may be null here and is a different attribute in any case.
                 if (initRhs == null || !initRhs.initialMessagesExpression.Equals(initBack.initialMessagesExpression))
                 {
                     // Do not replace a variable with a unique initialiser
                     return(iae);
                 }
             }
             if (isCopy || isSetTo)
             {
                 // Strip matching suffixes from both sides, then map the remaining lhs prefix to the rhs prefix.
                 RemoveMatchingSuffixes(iae.Target, rhs, condContext, out IExpression lhsPrefix, out IExpression rhsPrefix);
                 copyAttr.copyMap[lhsPrefix] = new CopyOfAttribute.CopyContext {
                     Expression = rhsPrefix, ConditionContext = condContext
                 };
             }
             else if (isSetAllElementsTo)
             {
                 copyAttr.copiedInEveryElementMap[iae.Target] = new CopyOfAttribute.CopyContext {
                     Expression = rhs, ConditionContext = condContext
                 };
             }
             else if (isGetItemsPoint || isGetJaggedItemsPoint || isGetDeepJaggedItemsPoint || isGetItemsFromJaggedPoint || isGetItemsFromDeepJaggedPoint || isGetJaggedItemsFromJaggedPoint)
             {
                 // The assignment target is expected to be an array indexer; the map is keyed by the array being indexed.
                 var target     = ((IArrayIndexerExpression)iae.Target).Target;
                 // NOTE(review): presumably arguments 1..inputDepth are the index arrays and the last two are not indices; confirm against the operator signatures.
                 int inputDepth = imie.Arguments.Count - 3;
                 List <IExpression> indexExprs = new List <IExpression>();
                 for (int i = 0; i < inputDepth; i++)
                 {
                     indexExprs.Add(imie.Arguments[1 + i]);
                 }
                 int outputDepth;
                 if (isGetDeepJaggedItemsPoint)
                 {
                     outputDepth = 3;
                 }
                 else if (isGetJaggedItemsPoint || isGetJaggedItemsFromJaggedPoint)
                 {
                     outputDepth = 2;
                 }
                 else
                 {
                     outputDepth = 1;
                 }
                 copyAttr.copyAtIndexMap[target] = new CopyOfAttribute.CopyContext2
                 {
                     Depth             = outputDepth,
                     ConditionContext  = condContext,
                     // Builds rhs[indexExpr_0[lhsIndices]][indexExpr_1[lhsIndices]]... for the given lhs indices.
                     ExpressionAtIndex = (lhsIndices) =>
                     {
                         return(Builder.JaggedArrayIndex(rhs, indexExprs.ListSelect(indexExpr =>
                                                                                    new[] { Builder.JaggedArrayIndex(indexExpr, lhsIndices) })));
                     }
                 };
             }
             else
             {
                 // Defensive: every flag accepted by the enclosing 'if' is handled by a branch above.
                 throw new NotImplementedException();
             }
         }
     }
     return(iae);
 }
コード例 #30
0
        /// <summary>
        /// Converts a method invocation. After running the base conversion, and only when
        /// that conversion still yields a method invocation, this override:
        /// records attributes requested through <c>Attrib.Var</c> and <c>Attrib.InitialiseTo</c>
        /// (removing those calls from the output), records on the variable information whether
        /// an <c>Infer</c> call asks for <c>QueryTypes.MarginalDividedByPrior</c>, and folds
        /// <c>Factor.And</c>, <c>Factor.Or</c> and <c>Factor.Not</c> calls that have literal arguments.
        /// </summary>
        /// <param name="imie">The method invocation to convert.</param>
        /// <returns>
        /// <c>null</c> for <c>Attrib.Var</c> / <c>Attrib.InitialiseTo</c> calls; the original
        /// <paramref name="imie"/> for <c>Infer</c> calls; a simplified expression for
        /// And/Or/Not calls with literal arguments; otherwise the base conversion result.
        /// </returns>
        protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
        {
            IExpression converted = base.ConvertMethodInvoke(imie);

            if (!(converted is IMethodInvokeExpression mie))
            {
                return(converted);
            }

            if (Recognizer.IsStaticGenericMethod(imie, new Func <PlaceHolder, ICompilerAttribute, PlaceHolder>(Attrib.Var)))
            {
                // Direct cast instead of 'as': a first argument that is not a variable reference
                // fails here with InvalidCastException rather than as a NullReferenceException below.
                IVariableReferenceExpression ivre   = (IVariableReferenceExpression)imie.Arguments[0];
                IVariableDeclaration         target = ivre.Variable.Resolve();
                IExpression expr = imie.Arguments[1];
                AddAttribute(target, expr);
                return(null);
            }
            if (Recognizer.IsStaticMethod(imie, new Action <object, object>(Attrib.InitialiseTo)))
            {
                IVariableReferenceExpression ivre   = (IVariableReferenceExpression)imie.Arguments[0];
                IVariableDeclaration         target = ivre.Variable.Resolve();
                context.OutputAttributes.Set(target, new InitialiseTo(imie.Arguments[1]));
                return(null);
            }
            if (CodeRecognizer.IsInfer(imie))
            {
                // The arguments must not be substituted for their values, so the original
                // expression (not the converted one) is returned.
                IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(imie.Arguments[0]);
                if (ivd != null)
                {
                    var       vi    = VariableInformation.GetVariableInformation(context, ivd);
                    // The query type is the optional third argument; absent means no query type.
                    QueryType query = (imie.Arguments.Count < 3) ? null : (QueryType)evaluator.Evaluate(imie.Arguments[2]);
                    vi.NeedsMarginalDividedByPrior = (query == QueryTypes.MarginalDividedByPrior);
                }
                return(imie);
            }

            // The remaining simplifications only apply when at least one argument is a literal.
            bool anyArgumentIsLiteral = mie.Arguments.Any(arg => arg is ILiteralExpression);
            if (!anyArgumentIsLiteral)
            {
                return(converted);
            }

            if (Recognizer.IsStaticMethod(converted, new Func <bool, bool, bool>(Factors.Factor.And)))
            {
                // A literal false decides an And; literal true arguments can be ignored.
                return(SimplifyLiteralBooleanFactor(mie, false));
            }
            if (Recognizer.IsStaticMethod(converted, new Func <bool, bool, bool>(Factors.Factor.Or)))
            {
                // A literal true decides an Or; literal false arguments can be ignored.
                return(SimplifyLiteralBooleanFactor(mie, true));
            }
            if (Recognizer.IsStaticMethod(converted, new Func <bool, bool>(Factors.Factor.Not)) &&
                mie.Arguments.All(arg => arg is ILiteralExpression))
            {
                // Every argument is a literal, so the call can be evaluated now.
                return(Builder.LiteralExpr(evaluator.Evaluate(mie)));
            }
            return(converted);
        }

        /// <summary>
        /// Folds a boolean factor call (And/Or) that has at least one literal argument.
        /// </summary>
        /// <param name="mie">The converted factor invocation.</param>
        /// <param name="absorbingValue">
        /// The literal value that determines the result by itself: <c>false</c> for And, <c>true</c> for Or.
        /// </param>
        /// <returns>
        /// A literal <paramref name="absorbingValue"/> if any literal argument equals it; otherwise the
        /// single non-literal argument if there is exactly one; otherwise a literal of the opposite value.
        /// </returns>
        private IExpression SimplifyLiteralBooleanFactor(IMethodInvokeExpression mie, bool absorbingValue)
        {
            // object.Equals is null-safe, unlike calling Equals on the literal's value directly.
            if (mie.Arguments.Any(arg => arg is ILiteralExpression ile && object.Equals(ile.Value, absorbingValue)))
            {
                return(Builder.LiteralExpr(absorbingValue));
            }
            // Any remaining literals equal the opposite value and therefore can be ignored.
            // Materialize once so the query is not enumerated twice.
            var nonLiteralArguments = mie.Arguments.Where(arg => !(arg is ILiteralExpression)).ToList();
            if (nonLiteralArguments.Count == 1)
            {
                return(nonLiteralArguments[0]);
            }
            return(Builder.LiteralExpr(!absorbingValue));
        }