/// <summary>
/// Adds <paramref name="container"/> to the static <c>Containers</c> collection unless it is already present.
/// </summary>
/// <param name="container">The container to register.</param>
private static void Register(RootDependencyContainer container)
{
    // Already registered: nothing to do.
    if (Containers.Contains(container))
    {
        return;
    }
    Containers.Add(container);
}
Example #2
0
 /// <summary>
 /// Returns true if expr must be evaluated inside the given container, i.e. if any variable
 /// referenced by expr has <paramref name="container"/> among its recorded Containers attribute.
 /// </summary>
 /// <param name="container">The container statement to test for.</param>
 /// <param name="expr">The expression whose referenced variables are inspected.</param>
 /// <returns>
 /// True if at least one variable of expr is recorded as living inside container. A variable with no
 /// Containers attribute is reported through context.Error and treated as not needing the container.
 /// </returns>
 protected bool NeedsContainer(IStatement container, IExpression expr)
 {
     return Recognizer.GetVariables(expr).Any(ivd =>
     {
         Containers c = context.InputAttributes.Get<Containers>(ivd);
         if (c == null)
         {
             // Report the missing attribute instead of failing with a NullReferenceException.
             context.Error($"Containers not found for '{ivd.Name}'.");
             return false;
         }
         return c.Contains(container);
     });
 }
        /// <summary>
        /// Intercepts assignments while inside an operator statement whose target resolves (via
        /// Recognizer.GetArrayDeclaration) to a variable declaration. It inserts the variable's pending increment
        /// statement (from analysis.onUpdate) at most once per set of containers, and tags the first
        /// OperatorStatement ancestor with HasIncrement when analysis.suppressUpdate has an entry for the variable.
        /// The assignment itself is always converted by the base implementation.
        /// </summary>
        /// <param name="iae">The assignment being converted.</param>
        /// <returns>The result of base.ConvertAssign(iae).</returns>
        protected override IExpression ConvertAssign(IAssignExpression iae)
        {
            object     targetDecl = Recognizer.GetArrayDeclaration(iae.Target);
            IStatement increment;

            if (this.isOperatorStatement && targetDecl is IVariableDeclaration)
            {
                IVariableDeclaration ivd = (IVariableDeclaration)targetDecl;
                if (analysis.onUpdate.TryGetValue(targetDecl, out increment))
                {
                    // NOTE(review): assumes every onUpdate entry is an IExpressionStatement (the cast throws otherwise).
                    IExpression incrExpr   = ((IExpressionStatement)increment).Expression;
                    // Containers the increment needs: current context plus those required by its expression,
                    // minus loops it does not use and minus stochastic conditionals.
                    Containers  containers = new Containers(context);
                    containers.AddContainersNeededForExpression(context, incrExpr);
                    containers = Containers.RemoveUnusedLoops(containers, context, incrExpr);
                    containers = Containers.RemoveStochasticConditionals(containers, context);
                    // Per-target history of the container sets in which the increment was already inserted.
                    List <Containers> list;
                    if (!containersOfUpdate.TryGetValue(targetDecl, out list))
                    {
                        list = new List <Containers>();
                        containersOfUpdate[targetDecl] = list;
                    }
                    // have we already performed this update in these containers?
                    bool alreadyDone = false;
                    foreach (Containers prevContainers in list)
                    {
                        if (containers.Contains(prevContainers))
                        {
                            // prevContainers is more general, i.e. has fewer containers than 'containers'
                            alreadyDone = true;
                            break;
                        }
                    }
                    if (!alreadyDone)
                    {
                        list.Add(containers);
                        // must set this attribute before the statement is wrapped
                        context.OutputAttributes.Set(increment, new OperatorStatement());
                        // Wrap the increment in whichever of its containers are not already open at the
                        // matching ancestor, then insert it after that ancestor.
                        int        ancIndex = containers.GetMatchingAncestorIndex(context);
                        Containers missing  = containers.GetContainersNotInContext(context, ancIndex);
                        increment = Containers.WrapWithContainers(increment, missing.outputs);
                        context.AddStatementAfterAncestorIndex(ancIndex, increment);
                    }
                }
                if (analysis.suppressUpdate.ContainsKey(ivd))
                {
                    // Tag only the first ancestor (in FindAncestors order) that carries an OperatorStatement attribute.
                    foreach (IStatement ist in context.FindAncestors <IStatement>())
                    {
                        if (context.InputAttributes.Has <OperatorStatement>(ist))
                        {
                            var attr = analysis.suppressUpdate[ivd];
                            context.OutputAttributes.Set(ist, new HasIncrement(attr));
                            break;
                        }
                    }
                }
            }
            return(base.ConvertAssign(iae));
        }
Example #4
0
                /// <summary>
                /// Returns true when every statement in ConditionContext is among the containers of
                /// <paramref name="context"/>.
                /// </summary>
                /// <param name="context">The transform context whose containers are checked.</param>
                /// <returns>False as soon as one condition is not found; true otherwise.</returns>
                internal bool IsValidContext(BasicTransformContext context)
                {
                    Containers current = new Containers(context);
                    bool allPresent = true;

                    foreach (var condition in ConditionContext)
                    {
                        if (!current.Contains(condition))
                        {
                            allPresent = false;
                            break;
                        }
                    }
                    return allPresent;
                }
Example #5
0
        /// <summary>
        /// Looks up a replacement for <paramref name="expr"/> among the nodes registered for it whose scope is
        /// contained in <paramref name="containers"/> (i.e. at least as general as the current containers).
        /// </summary>
        /// <param name="expr">Expression to find a replacement for.</param>
        /// <param name="containers">The containers in which the expression appears.</param>
        /// <param name="exprContainers">Receives the scope of the node that supplied the replacement, or null.</param>
        /// <returns>The first non-null replacement found, or null if none exists.</returns>
        internal IExpression GetNewExpression(IExpression expr, Containers containers, out Containers exprContainers)
        {
            foreach (var candidate in nodeOf.GetAll(expr))
            {
                // Skip nodes whose scope is not more general than the current containers.
                if (!containers.Contains(candidate.Scope))
                {
                    continue;
                }
                IExpression replacement = newExpression[candidate.Value];
                if (replacement == null)
                {
                    continue;
                }
                exprContainers = candidate.Scope;
                return replacement;
            }

            exprContainers = null;
            return null;
        }
Example #6
0
            /// <summary>
            /// Returns true if any condition in <paramref name="condContext"/> references a variable whose
            /// recorded containers include the loop that declares <paramref name="find"/>.
            /// </summary>
            /// <param name="condContext">The condition statements to inspect.</param>
            /// <param name="find">The loop variable of interest.</param>
            /// <returns>
            /// False when no loop is found for <paramref name="find"/>. Variables lacking a Containers attribute
            /// are reported through context.Error and skipped.
            /// </returns>
            private bool AnyConditionsDependOnLoopVariable(List <IConditionStatement> condContext, IVariableDeclaration find)
            {
                IForStatement loopOfVariable = Recognizer.GetLoopForVariable(context, find);
                if (loopOfVariable == null)
                {
                    return false;
                }

                foreach (IConditionStatement condition in condContext)
                {
                    foreach (var variable in Recognizer.GetVariables(condition.Condition))
                    {
                        Containers variableContainers = context.InputAttributes.Get <Containers>(variable);
                        if (variableContainers == null)
                        {
                            context.Error($"Containers not found for '{variable.Name}'.");
                            continue;
                        }
                        if (variableContainers.Contains(loopOfVariable))
                        {
                            return true;
                        }
                    }
                }
                return false;
            }
Example #7
0
        /// <summary>
        /// Returns the graph node for <paramref name="expr"/> in the current scope, creating and registering it on
        /// first use. The node is then connected to the other nodes of the same expression whose scopes are
        /// nested with this one.
        /// </summary>
        /// <param name="expr">Expression whose node is looked up or created.</param>
        /// <returns>The node index for expr in the current scope.</returns>
        private int CreateNodeAndEdges(IExpression expr)
        {
            var currentScope = new Containers(context);

            int node;
            bool alreadyKnown = nodeOf.TryGetExact(expr, currentScope, out node);
            if (!alreadyKnown)
            {
                node = graph.AddNode();
                nodeOf.Add(expr, node, currentScope);
            }

            // Parameter references and literals count as observed values.
            bool isObservedValue = Recognizer.GetParameterDeclaration(expr) != null || expr is ILiteralExpression;
            if (isObservedValue)
            {
                observedNodes.Add(new KeyValuePair<int, IExpression>(node, expr));
            }

            // Edges always run from the node with the more general scope to the one with the more specific scope.
            foreach (var other in nodeOf.GetAll(expr))
            {
                bool isSameNode = other.Value == node;
                if (isSameNode)
                {
                    continue;
                }

                if (currentScope.Contains(other.Scope))
                {
                    // current scope is more specific than the other one
                    graph.AddEdge(other.Value, node);
                }
                else if (other.Scope.Contains(currentScope))
                {
                    // current scope is more general than the other one
                    graph.AddEdge(node, other.Value);
                }
            }
            return node;
        }
Example #8
0
        /// <summary>
        /// Converts a reference to a stochastic variable so that it is valid in the current loop context. A
        /// Clone.Replicate statement is emitted for every enclosing loop of the reference that the reference is
        /// not indexed by and that is not fixed by an enclosing condition.
        /// </summary>
        /// <param name="expr">The variable reference to convert.</param>
        /// <returns>
        /// <paramref name="expr"/> unchanged when it has no variable declaration, the variable is not stochastic,
        /// its loop context is missing, or it is referenced in the loop context of its declaration. Otherwise,
        /// an index into the last replicated array created, followed by any indices that were not replicated.
        /// </returns>
        protected IExpression ConvertWithReplication(IExpression expr)
        {
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
            // Check if this is an index local variable
            if (baseVar == null) return expr;
            // Check if the variable is stochastic
            if (!CodeRecognizer.IsStochastic(context, baseVar)) return expr;

            // Get the loop context for this variable
            LoopContext lc = context.InputAttributes.Get<LoopContext>(baseVar);
            if (lc == null)
            {
                Error("Loop context not found for '" + baseVar.Name + "'.");
                return expr;
            }

            // Get the reference loop context for this expression
            RefLoopContext rlc = lc.GetReferenceLoopContext(context);
            // If the reference is in the same loop context as the declaration, do nothing.
            if (rlc.loops.Count == 0) return expr;

            // the set of loop variables that are constant wrt the expr
            Set<IVariableDeclaration> constantLoopVars = new Set<IVariableDeclaration>();
            constantLoopVars.AddRange(lc.loopVariables);

            // collect set of all loop variable indices in the expression
            Set<int> embeddedLoopIndices = new Set<int>();
            List<IList<IExpression>> brackets = Recognizer.GetIndices(expr);
            foreach (IList<IExpression> bracket in brackets)
            {
                foreach (IExpression index in bracket)
                {
                    IExpression indExpr = index;
                    if (indExpr is IBinaryExpression ibe)
                    {
                        // only the left operand of a binary index expression is examined
                        indExpr = ibe.Left;
                    }
                    IVariableDeclaration indVar = Recognizer.GetVariableDeclaration(indExpr);
                    if (indVar != null)
                    {
                        if (!constantLoopVars.Contains(indVar))
                        {
                            int loopIndex = rlc.loopVariables.IndexOf(indVar);
                            if (loopIndex != -1)
                            {
                                // indVar is a loop variable
                                constantLoopVars.Add(rlc.loopVariables[loopIndex]);
                            }
                            else
                            {
                                // indVar is not a loop variable
                                AddEmbeddedLoopIndicesOf(indVar, constantLoopVars, rlc, embeddedLoopIndices, expr);
                            }
                        }
                    }
                    else
                    {
                        foreach (var ivd in Recognizer.GetVariables(indExpr))
                        {
                            if (!constantLoopVars.Contains(ivd))
                            {
                                AddEmbeddedLoopIndicesOf(ivd, constantLoopVars, rlc, embeddedLoopIndices, expr);
                            }
                        }
                    }
                }
            }

            // Find loop variables that must be constant due to condition statements.
            List<IStatement> ancestors = context.FindAncestors<IStatement>();
            foreach (IStatement ancestor in ancestors)
            {
                if (!(ancestor is IConditionStatement ics))
                    continue;
                ConditionBinding binding = new ConditionBinding(ics.Condition);
                IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(binding.lhs);
                IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(binding.rhs);
                int index = rlc.loopVariables.IndexOf(ivd);
                if (index >= 0 && IsConstantWrtLoops(ivd2, constantLoopVars))
                {
                    constantLoopVars.Add(ivd);
                    continue;
                }
                int index2 = rlc.loopVariables.IndexOf(ivd2);
                if (index2 >= 0 && IsConstantWrtLoops(ivd, constantLoopVars))
                {
                    constantLoopVars.Add(ivd2);
                    continue;
                }
            }

            // Determine if this expression is being defined (is on the LHS of an assignment)
            bool isDef = Recognizer.IsBeingMutated(context, expr);

            IExpression originalExpr = expr;

            for (int currentLoop = 0; currentLoop < rlc.loopVariables.Count; currentLoop++)
            {
                IVariableDeclaration loopVar = rlc.loopVariables[currentLoop];
                if (constantLoopVars.Contains(loopVar))
                    continue;
                IForStatement loop = rlc.loops[currentLoop];
                // must replicate across this loop.
                if (isDef)
                {
                    Error("Cannot re-define a variable in a loop.  Variables on the left hand side of an assignment must be indexed by all containing loops.");
                    continue;
                }
                if (embeddedLoopIndices.Contains(currentLoop))
                {
                    string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
                    Warning(string.Format(warningText, originalExpr, loopVar.Name));
                }
                // split expr into a target and extra indices, where target will be replicated and extra indices will be added later
                var extraIndices = new List<IEnumerable<IExpression>>();
                AddUnreplicatedIndices(loop, expr, extraIndices, out IExpression exprToReplicate);

                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                IExpression loopSize = Recognizer.LoopSizeExpression(loop);
                IList<IStatement> stmts = Builder.StmtCollection();
                List<IList<IExpression>> inds = Recognizer.GetIndices(exprToReplicate);
                IVariableDeclaration newIndexVar = loopVar;
                // if loopVar is already an indexVar of varInfo, create a new variable
                if (varInfo.HasIndexVar(loopVar))
                {
                    newIndexVar = VariableInformation.GenerateLoopVar(context, "_a");
                    context.InputAttributes.CopyObjectAttributesTo(loopVar, context.OutputAttributes, newIndexVar);
                }
                IVariableDeclaration repVar = varInfo.DeriveArrayVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rep"),
                                                                          loopSize, newIndexVar, inds, useArrays: true);
                if (!context.InputAttributes.Has<DerivedVariable>(repVar))
                    context.OutputAttributes.Set(repVar, new DerivedVariable());
                if (context.InputAttributes.Has<ChannelInfo>(baseVar))
                {
                    VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
                    ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
                    ci.decl = repVar;
                    context.OutputAttributes.Set(repVar, ci);
                }

                // Create replicate factor
                Type returnType = Builder.ToType(repVar.VariableType);
                IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
                    new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate),
                    new Type[] {returnType.GetElementType()}, exprToReplicate, loopSize);

                IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), repMethod);
                // Copy attributes across from variable to replication expression
                context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, repMethod);
                stmts.Add(Builder.ExprStatement(assignExpression));

                // add any containers missing from context.
                Containers containers = new Containers(context);
                // RemoveUnusedLoops will also remove conditionals involving those loop variables.
                // TODO: investigate whether removing these conditionals could cause a problem, e.g. when the condition is a conjunction of many terms.
                containers = Containers.RemoveUnusedLoops(containers, context, repMethod);
                if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar)) containers = Containers.RemoveStochasticConditionals(containers, context);
                // The replicate statement must live outside the loop it replicates over.
                if (containers.Contains(loop))
                {
                    Error("Internal: invalid containers for replicating " + baseVar);
                    break;
                }
                int ancIndex = containers.GetMatchingAncestorIndex(context);
                Containers missing = containers.GetContainersNotInContext(context, ancIndex);
                stmts = Containers.WrapWithContainers(stmts, missing.inputs);
                context.OutputAttributes.Set(repVar, containers);
                List<IForStatement> loops = context.FindAncestors<IForStatement>(ancIndex);
                foreach (IStatement container in missing.inputs)
                {
                    if (container is IForStatement ifs) loops.Add(ifs);
                }
                context.OutputAttributes.Set(repVar, new LoopContext(loops));
                // must convert the output since it may contain 'if' conditions
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
                // Subsequent iterations replicate the already-replicated variable.
                baseVar = repVar;
                expr = Builder.ArrayIndex(Builder.VarRefExpr(repVar), Builder.VarRefExpr(loopVar));
                expr = Builder.JaggedArrayIndex(expr, extraIndices);
            }

            return expr;
        }

        /// <summary>
        /// For an index variable that is not itself a reference loop variable, records into
        /// <paramref name="embeddedLoopIndices"/> the position (within <paramref name="rlc"/>) of every loop
        /// variable in its LoopContext that is not already in <paramref name="constantLoopVars"/>.
        /// Loop variables not found in <paramref name="rlc"/>, and a missing LoopContext, are reported via Error.
        /// </summary>
        /// <param name="indexVar">Variable appearing in an index expression.</param>
        /// <param name="constantLoopVars">Loop variables already known to be constant with respect to the expression.</param>
        /// <param name="rlc">Reference loop context of the expression being converted.</param>
        /// <param name="embeddedLoopIndices">Receives the positions of embedded loop variables.</param>
        /// <param name="expr">The expression being converted, used only in error messages.</param>
        private void AddEmbeddedLoopIndicesOf(IVariableDeclaration indexVar, Set<IVariableDeclaration> constantLoopVars,
                                              RefLoopContext rlc, Set<int> embeddedLoopIndices, IExpression expr)
        {
            LoopContext indexLoopContext = context.InputAttributes.Get<LoopContext>(indexVar);
            if (indexLoopContext == null)
            {
                // Previously this dereferenced null; report it like the missing loop context of the base variable.
                Error("Loop context not found for '" + indexVar.Name + "'.");
                return;
            }
            foreach (var loopVariable in indexLoopContext.loopVariables)
            {
                if (constantLoopVars.Contains(loopVariable))
                {
                    continue;
                }
                int loopIndex = rlc.loopVariables.IndexOf(loopVariable);
                if (loopIndex != -1)
                    embeddedLoopIndices.Add(loopIndex);
                else
                    Error($"Index {loopVariable} is not in {rlc} for expression {expr}");
            }
        }