Пример #1
0
        /// <summary>
        /// Unrolls a for-loop whose size and initializer are both integer literals and whose size is
        /// below a small threshold. The body is converted once per iteration, with the loop variable
        /// bound to the iteration's literal value.
        /// </summary>
        /// <param name="ifs">The for-loop being converted.</param>
        /// <returns>
        /// <c>null</c> when the loop was unrolled (its statements were added before the current
        /// statement); otherwise the result of the base conversion.
        /// </returns>
        protected override IStatement ConvertFor(IForStatement ifs)
        {
            // Loops with at least this many iterations are not unrolled.
            const int maxUnrolledSize = 20;

            IExpression sizeExpr = Recognizer.LoopSizeExpression(ifs);
            // Pattern matching (instead of a hard unboxing cast) lets non-int literals fall back to the
            // base conversion rather than throwing InvalidCastException.
            if (sizeExpr is ILiteralExpression sizeLiteral && sizeLiteral.Value is int size && size < maxUnrolledSize)
            {
                ConditionBinding binding = GetInitializerBinding(ifs);
                if (binding.rhs is ILiteralExpression startLiteral && startLiteral.Value is int start)
                {
                    for (int i = start; i < size; i++)
                    {
                        iterationContext.Add(new ConditionBinding(binding.lhs, Builder.LiteralExpr(i)));
                        try
                        {
                            IBlockStatement body = ConvertBlock(ifs.Body);
                            context.AddStatementsBeforeCurrent(body.Statements);
                        }
                        finally
                        {
                            // Always pop this iteration's binding so iterationContext stays balanced
                            // even if converting the body throws.
                            iterationContext.RemoveAt(iterationContext.Count - 1);
                        }
                    }
                    return null;
                }
            }
            return base.ConvertFor(ifs);
        }
Пример #2
0
        /// <summary>
        /// Converts a condition statement.
        /// </summary>
        /// <remarks>
        /// If the converted condition is a literal, the statements of the selected branch (Then for
        /// <c>true</c>, Else otherwise) are converted and added before the current statement. In that
        /// case <c>null</c> is returned. Otherwise each branch is converted with the condition's
        /// binding (or its flipped form for the Else branch) pushed onto <c>conditionContext</c>.
        /// </remarks>
        /// <param name="ics">The condition statement to convert.</param>
        /// <returns>The converted statement, or <c>null</c> if nothing remains.</returns>
        protected override IStatement ConvertCondition(IConditionStatement ics)
        {
            IConditionStatement cs = Builder.CondStmt();
            cs.Condition = ConvertExpression(ics.Condition);

            if (cs.Condition is ILiteralExpression literalCondition)
            {
                // Constant condition: splice the chosen branch's statements in place of the 'if'.
                var chosenBranch = (bool)literalCondition.Value ? ics.Then : ics.Else;
                if (chosenBranch != null)
                {
                    foreach (IStatement original in chosenBranch.Statements)
                    {
                        IStatement converted = ConvertStatement(original);
                        if (converted == null)
                        {
                            continue;
                        }
                        context.AddStatementBeforeCurrent(converted);
                    }
                }
                return null;
            }

            context.SetPrimaryOutput(cs);
            ConditionBinding thenBinding = GateTransform.GetConditionBinding(cs.Condition, context, out _);
            int mark = conditionContext.Count;

            conditionContext.Add(thenBinding);
            cs.Then = ConvertBlock(ics.Then);
            if (ics.Else != null)
            {
                conditionContext.RemoveRange(mark, conditionContext.Count - mark);
                conditionContext.Add(thenBinding.FlipCondition());
                cs.Else = ConvertBlock(ics.Else);
            }
            conditionContext.RemoveRange(mark, conditionContext.Count - mark);

            // Drop the statement entirely when both converted branches are empty.
            if (cs.Then.Statements.Count != 0)
            {
                return cs;
            }
            bool elseIsEmpty = cs.Else == null || cs.Else.Statements.Count == 0;
            return elseIsEmpty ? null : cs;
        }
Пример #3
0
        /// <summary>
        /// Two condition bindings are equal when both their <c>lhs</c> and <c>rhs</c> expressions are equal.
        /// </summary>
        /// <param name="obj">The object to compare with.</param>
        /// <returns><c>true</c> if <paramref name="obj"/> is a <c>ConditionBinding</c> with equal sides.</returns>
        public override bool Equals(object obj)
        {
            if (ReferenceEquals(this, obj))
            {
                return true;
            }
            if (!(obj is ConditionBinding that))
            {
                return false;
            }
            // object.Equals is null-safe: Equals must not throw when either side is null.
            return object.Equals(lhs, that.lhs) && object.Equals(rhs, that.rhs);
        }
Пример #4
0
            /// <summary>
            /// Builds a <c>BindingSet</c> holding one dictionary per input group of bindings. Each
            /// dictionary comes from <c>ConditionBinding.ToDictionary(group, true)</c>, and empty
            /// dictionaries are skipped.
            /// </summary>
            /// <param name="bindingSet">Groups of condition bindings.</param>
            /// <returns>The new binding set.</returns>
            internal static BindingSet FromBindings(IEnumerable <ICollection <ConditionBinding> > bindingSet)
            {
                var combined = new BindingSet();
                foreach (ICollection <ConditionBinding> bindingGroup in bindingSet)
                {
                    // Only boolean bindings are kept: they are known to have two cases,
                    // whereas an integer binding could have only one.
                    var booleanBindings = ConditionBinding.ToDictionary(bindingGroup, true);
                    if (booleanBindings.Count == 0)
                    {
                        continue;
                    }
                    combined.set.Add(booleanBindings);
                }
                return combined;
            }
        /// <summary>
        /// Collects a <c>ConditionBinding</c> for every non-stochastic condition statement among
        /// <paramref name="containers"/>, in enumeration order.
        /// </summary>
        /// <param name="context">Transform context used to test whether a condition is stochastic.</param>
        /// <param name="containers">Container statements to scan; non-condition statements are ignored.</param>
        /// <returns>A new list of bindings (possibly empty).</returns>
        internal static List <ConditionBinding> GetBindings(BasicTransformContext context, IEnumerable <IStatement> containers)
        {
            var result = new List<ConditionBinding>();
            foreach (IStatement container in containers)
            {
                if (!(container is IConditionStatement conditional) || CodeRecognizer.IsStochastic(context, conditional.Condition))
                {
                    continue;
                }
                result.Add(new ConditionBinding(conditional.Condition));
            }
            return result;
        }
Пример #6
0
        /// <summary>
        /// Records the condition bindings under which <paramref name="ivd"/> is being defined.
        /// The bindings are the current <c>conditionContext</c> minus those recorded for its declaration.
        /// </summary>
        /// <remarks>
        /// The first call creates an entry in <c>definitionBindings</c>. If that entry is (still) empty
        /// on a later call, the method returns immediately and records nothing more. An entry stays
        /// empty when no non-empty binding set has been added to it.
        /// Assumes <c>declarationBindings</c> already has an entry for <paramref name="ivd"/>.
        /// </remarks>
        /// <param name="ivd">The variable being defined.</param>
        protected void RegisterDefinition(IVariableDeclaration ivd)
        {
            if (definitionBindings.TryGetValue(ivd, out Set <ICollection <ConditionBinding> > existing))
            {
                if (existing.Count == 0)
                {
                    return;
                }
            }
            else
            {
                existing = new Set <ICollection <ConditionBinding> >();
                definitionBindings[ivd] = existing;
            }

            if (conditionContext.Count == 0)
            {
                return;
            }
            Set <ConditionBinding> extraBindings = ConditionBinding.Copy(conditionContext);
            extraBindings.Remove(declarationBindings[ivd]);
            if (extraBindings.Count == 0)
            {
                return;
            }
            existing.Add(extraBindings);
        }
        /// <summary>
        /// Converts both sides of an assignment. When the target is a non-stochastic <c>int</c>
        /// variable that is not a loop variable, the assignment is added to <c>conditionContext</c>
        /// as a binding.
        /// </summary>
        /// <param name="iae">The assignment expression.</param>
        /// <returns>The original assignment expression.</returns>
        protected override IExpression ConvertAssign(IAssignExpression iae)
        {
            // The right-hand side is converted before the target.
            IExpression value = ConvertExpression(iae.Expression);
            IExpression assigned = ConvertExpression(iae.Target);
            IVariableDeclaration targetDecl = Recognizer.GetVariableDeclaration(assigned);
            if (targetDecl == null)
            {
                return iae;
            }

            var targetInfo = VariableInformation.GetVariableInformation(context, targetDecl);
            bool isTrackedInt = !targetInfo.IsStochastic
                && targetInfo.varType.Equals(typeof(int))
                && Recognizer.GetLoopForVariable(context, targetDecl) == null;
            if (!isTrackedInt)
            {
                return iae;
            }

            // Bind a plain variable reference rather than the declaration expression.
            if (assigned is IVariableDeclarationExpression declExpr)
            {
                assigned = Builder.VarRefExpr(declExpr.Variable);
            }
            // The binding is not removed when the lexical scope ends, because locals
            // are not yet scoped correctly.
            conditionContext.Add(new ConditionBinding(assigned, value));
            return iae;
        }
        /// <summary>
        /// For a non-stochastic condition, analyses the Then branch with the condition's binding on
        /// <c>conditionContext</c>, and the Else branch (if any) with the flipped binding. The pushed
        /// bindings are removed afterwards. Stochastic conditions are handled by the base class.
        /// </summary>
        /// <param name="ics">The condition statement.</param>
        /// <returns>The base result for stochastic conditions; otherwise <paramref name="ics"/> itself.</returns>
        protected override IStatement ConvertCondition(IConditionStatement ics)
        {
            bool stochastic = CodeRecognizer.IsStochastic(context, ics.Condition);
            if (!stochastic)
            {
                context.SetPrimaryOutput(ics);
                ConvertExpression(ics.Condition);
                var thenBinding = new ConditionBinding(ics.Condition);
                int mark = conditionContext.Count;

                // Discards everything pushed onto conditionContext since 'mark'.
                void RestoreContext() => conditionContext.RemoveRange(mark, conditionContext.Count - mark);

                conditionContext.Add(thenBinding);
                ConvertBlock(ics.Then);
                if (ics.Else != null)
                {
                    RestoreContext();
                    conditionContext.Add(thenBinding.FlipCondition());
                    ConvertBlock(ics.Else);
                }
                RestoreContext();
                return ics;
            }
            return base.ConvertCondition(ics);
        }
Пример #9
0
        /// <summary>
        /// Returns an expression equivalent to <paramref name="expr"/> that is usable in the current
        /// loop context. It replicates the referenced stochastic variable across each enclosing loop
        /// whose variable is not constant with respect to the expression.
        /// </summary>
        /// <remarks>
        /// The expression is returned unchanged when:
        /// <list type="bullet">
        /// <item>it references no variable;</item>
        /// <item>the variable is not stochastic;</item>
        /// <item>the variable has no <c>LoopContext</c> (an error is reported);</item>
        /// <item>the reference has no extra loops relative to the declaration.</item>
        /// </list>
        /// Otherwise, for each enclosing loop that must be replicated over, this method derives a new
        /// array variable named from the variable name plus "_rep". It assigns that variable from
        /// <c>Clone.Replicate(exprToReplicate, loopSize)</c>, inserts those statements before the
        /// matching ancestor, and indexes the result by the loop variable.
        /// Replicating an expression that is being mutated is reported as an error.
        /// </remarks>
        /// <param name="expr">A (possibly indexed) variable reference.</param>
        /// <returns>The original expression, or an indexed reference into the replicated variable.</returns>
        protected IExpression ConvertWithReplication(IExpression expr)
        {
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
            // Check if this is an index local variable
            if (baseVar == null) return expr;
            // Check if the variable is stochastic
            if (!CodeRecognizer.IsStochastic(context, baseVar)) return expr;

            // Get the loop context for this variable
            LoopContext lc = context.InputAttributes.Get<LoopContext>(baseVar);
            if (lc == null)
            {
                Error("Loop context not found for '" + baseVar.Name + "'.");
                return expr;
            }

            // Get the reference loop context for this expression
            RefLoopContext rlc = lc.GetReferenceLoopContext(context);
            // If the reference is in the same loop context as the declaration, do nothing.
            if (rlc.loops.Count == 0) return expr;

            // the set of loop variables that are constant wrt the expr
            Set<IVariableDeclaration> constantLoopVars = new Set<IVariableDeclaration>();
            constantLoopVars.AddRange(lc.loopVariables);

            // collect set of all loop variable indices in the expression
            Set<int> embeddedLoopIndices = new Set<int>();
            List<IList<IExpression>> brackets = Recognizer.GetIndices(expr);
            foreach (IList<IExpression> bracket in brackets)
            {
                foreach (IExpression index in bracket)
                {
                    IExpression indExpr = index;
                    // For a binary index expression only the left operand is examined
                    // (presumably an offset index such as i+1 — confirm).
                    if (indExpr is IBinaryExpression ibe)
                    {
                        indExpr = ibe.Left;
                    }
                    IVariableDeclaration indVar = Recognizer.GetVariableDeclaration(indExpr);
                    if (indVar != null)
                    {
                        if (!constantLoopVars.Contains(indVar))
                        {
                            int loopIndex = rlc.loopVariables.IndexOf(indVar);
                            if (loopIndex != -1)
                            {
                                // indVar is a loop variable
                                constantLoopVars.Add(rlc.loopVariables[loopIndex]);
                            }
                            else 
                            {
                                // indVar is not a loop variable
                                AddEmbeddedLoopIndices(indVar);
                            }
                        }
                    }
                    else
                    {
                        foreach(var ivd in Recognizer.GetVariables(indExpr))
                        {
                            if (!constantLoopVars.Contains(ivd))
                            {
                                AddEmbeddedLoopIndices(ivd);
                            }
                        }
                    }
                }
            }

            // Find loop variables that must be constant due to condition statements.
            List<IStatement> ancestors = context.FindAncestors<IStatement>();
            foreach (IStatement ancestor in ancestors)
            {
                if (!(ancestor is IConditionStatement ics))
                    continue;
                ConditionBinding binding = new ConditionBinding(ics.Condition);
                IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(binding.lhs);
                IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(binding.rhs);
                int index = rlc.loopVariables.IndexOf(ivd);
                if (index >= 0 && IsConstantWrtLoops(ivd2, constantLoopVars))
                {
                    constantLoopVars.Add(ivd);
                    continue;
                }
                int index2 = rlc.loopVariables.IndexOf(ivd2);
                if (index2 >= 0 && IsConstantWrtLoops(ivd, constantLoopVars))
                {
                    constantLoopVars.Add(ivd2);
                    continue;
                }
            }

            // Determine if this expression is being defined (is on the LHS of an assignment)
            bool isDef = Recognizer.IsBeingMutated(context, expr);

            // NOTE(review): this initial value is overwritten inside the loop before it is ever read.
            Containers containers = context.InputAttributes.Get<Containers>(baseVar);

            IExpression originalExpr = expr;

            for (int currentLoop = 0; currentLoop < rlc.loopVariables.Count; currentLoop++)
            {
                IVariableDeclaration loopVar = rlc.loopVariables[currentLoop];
                if (constantLoopVars.Contains(loopVar))
                    continue;
                IForStatement loop = rlc.loops[currentLoop];
                // must replicate across this loop.
                if (isDef)
                {
                    Error("Cannot re-define a variable in a loop.  Variables on the left hand side of an assignment must be indexed by all containing loops.");
                    continue;
                }
                if (embeddedLoopIndices.Contains(currentLoop))
                {
                    string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
                    Warning(string.Format(warningText, originalExpr, loopVar.Name));
                }
                // split expr into a target and extra indices, where target will be replicated and extra indices will be added later
                var extraIndices = new List<IEnumerable<IExpression>>();
                AddUnreplicatedIndices(rlc.loops[currentLoop], expr, extraIndices, out IExpression exprToReplicate);

                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                IExpression loopSize = Recognizer.LoopSizeExpression(loop);
                IList<IStatement> stmts = Builder.StmtCollection();
                List<IList<IExpression>> inds = Recognizer.GetIndices(exprToReplicate);
                IVariableDeclaration newIndexVar = loopVar;
                // if loopVar is already an indexVar of varInfo, create a new variable
                if (varInfo.HasIndexVar(loopVar))
                {
                    newIndexVar = VariableInformation.GenerateLoopVar(context, "_a");
                    context.InputAttributes.CopyObjectAttributesTo(loopVar, context.OutputAttributes, newIndexVar);
                }
                IVariableDeclaration repVar = varInfo.DeriveArrayVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rep"),
                                                                          loopSize, newIndexVar, inds, useArrays: true);
                if (!context.InputAttributes.Has<DerivedVariable>(repVar))
                    context.OutputAttributes.Set(repVar, new DerivedVariable());
                if (context.InputAttributes.Has<ChannelInfo>(baseVar))
                {
                    VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
                    ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
                    ci.decl = repVar;
                    context.OutputAttributes.Set(repVar, ci);
                }

                // Create replicate factor
                Type returnType = Builder.ToType(repVar.VariableType);
                IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
                    new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate),
                    new Type[] {returnType.GetElementType()}, exprToReplicate, loopSize);

                IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), repMethod);
                // Copy attributes across from variable to replication expression
                context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, repMethod);
                stmts.Add(Builder.ExprStatement(assignExpression));

                // add any containers missing from context.
                containers = new Containers(context);
                // RemoveUnusedLoops will also remove conditionals involving those loop variables.
                // TODO: investigate whether removing these conditionals could cause a problem, e.g. when the condition is a conjunction of many terms.
                containers = Containers.RemoveUnusedLoops(containers, context, repMethod);
                if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar)) containers = Containers.RemoveStochasticConditionals(containers, context);
                //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
                //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
                if (containers.Contains(loop))
                {
                    Error("Internal: invalid containers for replicating " + baseVar);
                    break;
                }
                int ancIndex = containers.GetMatchingAncestorIndex(context);
                Containers missing = containers.GetContainersNotInContext(context, ancIndex);
                stmts = Containers.WrapWithContainers(stmts, missing.inputs);
                context.OutputAttributes.Set(repVar, containers);
                List<IForStatement> loops = context.FindAncestors<IForStatement>(ancIndex);
                foreach (IStatement container in missing.inputs)
                {
                    if (container is IForStatement ifs) loops.Add(ifs);
                }
                context.OutputAttributes.Set(repVar, new LoopContext(loops));
                // must convert the output since it may contain 'if' conditions
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
                baseVar = repVar;
                expr = Builder.ArrayIndex(Builder.VarRefExpr(repVar), Builder.VarRefExpr(loopVar));
                expr = Builder.JaggedArrayIndex(expr, extraIndices);
            }

            return expr;

            // For a variable used inside an index expression, records the positions (in rlc.loopVariables)
            // of the loops in that variable's own LoopContext that are not yet known to be constant.
            // Loop variables missing from rlc are reported as errors.
            void AddEmbeddedLoopIndices(IVariableDeclaration indexVar)
            {
                LoopContext indexLoopContext = context.InputAttributes.Get<LoopContext>(indexVar);
                if (indexLoopContext == null)
                {
                    // Guard against a missing LoopContext, reported the same way as for baseVar above.
                    Error("Loop context not found for '" + indexVar.Name + "'.");
                    return;
                }
                foreach (var outerVar in indexLoopContext.loopVariables)
                {
                    if (!constantLoopVars.Contains(outerVar))
                    {
                        int position = rlc.loopVariables.IndexOf(outerVar);
                        if (position != -1)
                            embeddedLoopIndices.Add(position);
                        else
                            Error($"Index {outerVar} is not in {rlc} for expression {expr}");
                    }
                }
            }
        }
Пример #10
0
        /// <summary>
        /// Records a gate block for this condition statement and analyses its branches. The Then branch
        /// is analysed with the condition's binding pushed onto <c>conditionContext</c>, and the Else
        /// branch (if any) with the flipped binding.
        /// </summary>
        /// <remarks>
        /// If the binding's right-hand side is neither a literal nor a loop variable, an error is
        /// reported and the statement is returned without analysing its branches.
        /// Gate blocks are looked up first by a copy of the current <c>conditionContext</c>. They are
        /// then looked up by a key: the binding's left-hand side when it is stochastic, or the whole
        /// binding expression otherwise.
        /// After the branches are analysed, if another gate block is still on <c>gateBlockContext</c>,
        /// <c>ProcessUse</c> is called again for every definition and use recorded in this block. Each
        /// call uses the current <c>conditionContext</c>, combined via <c>Union</c> with each recorded
        /// binding set.
        /// </remarks>
        /// <param name="ics">The condition statement being visited.</param>
        /// <returns>The input statement, unchanged.</returns>
        protected override IStatement ConvertCondition(IConditionStatement ics)
        {
            context.SetPrimaryOutput(ics);
            ConvertExpression(ics.Condition);
            // The 'loop' value produced here is overwritten by IsLiteralOrLoopVar below.
            ConditionBinding binding   = GateTransform.GetConditionBinding(ics.Condition, context, out IForStatement loop);
            IExpression      caseValue = binding.rhs;

            // Presumably sets 'loop' to the loop whose counter caseValue is, or null for a literal — confirm.
            if (!GateTransform.IsLiteralOrLoopVar(context, caseValue, out loop))
            {
                Error("If statement condition must compare to a literal or loop counter, was: " + ics.Condition);
                return(ics);
            }
            bool        isStochastic = CodeRecognizer.IsStochastic(context, binding.lhs);
            IExpression gateBlockKey;

            if (isStochastic)
            {
                gateBlockKey = binding.lhs;
            }
            else
            {
                // definitions must not be unified across deterministic gate conditions
                gateBlockKey = binding.GetExpression();
            }
            GateBlock gateBlock             = null;
            // Gate blocks are grouped by the set of enclosing condition bindings.
            Set <ConditionBinding> bindings = ConditionBinding.Copy(conditionContext);
            Dictionary <IExpression, GateBlock> blockMap;

            if (!gateBlocks.TryGetValue(bindings, out blockMap))
            {
                // first time seeing these bindings
                blockMap             = new Dictionary <IExpression, GateBlock>();
                gateBlocks[bindings] = blockMap;
            }
            if (!blockMap.TryGetValue(gateBlockKey, out gateBlock))
            {
                // first time seeing this lhs
                gateBlock = new GateBlock();
                blockMap[gateBlockKey] = gateBlock;
            }
            // Loop-counter and literal case values must not be mixed within one gate block.
            if (gateBlock.hasLoopCaseValue && loop == null)
            {
                Error("Cannot compare " + binding.lhs + " to a literal, since it was previously compared to a loop counter.  Put this test inside the loop.");
            }
            if (!gateBlock.hasLoopCaseValue && gateBlock.caseValues.Count > 0 && loop != null)
            {
                Error("Cannot compare " + binding.lhs + " to a loop counter, since it was previously compared to a literal.  Put the literal case inside the loop.");
            }
            gateBlock.caseValues.Add(caseValue);
            if (loop != null)
            {
                gateBlock.hasLoopCaseValue = true;
            }
            gateBlockContext.Add(gateBlock);
            context.OutputAttributes.Set(ics, gateBlock);
            int startIndex = conditionContext.Count;

            // Analyse Then under the binding; Else (if any) under the flipped binding.
            // Everything pushed since startIndex is discarded before and after each branch.
            conditionContext.Add(binding);
            ConvertBlock(ics.Then);
            if (ics.Else != null)
            {
                conditionContext.RemoveRange(startIndex, conditionContext.Count - startIndex);
                binding = binding.FlipCondition();
                conditionContext.Add(binding);
                ConvertBlock(ics.Else);
            }
            conditionContext.RemoveRange(startIndex, conditionContext.Count - startIndex);
            gateBlockContext.RemoveAt(gateBlockContext.Count - 1);
            // remove any uses that match a def
            //RemoveUsesOfDefs(gateBlock);
            if (gateBlockContext.Count > 0)
            {
                // NOTE(review): currentBlock is not referenced below; ProcessUse presumably reads the
                // top of gateBlockContext itself — confirm.
                GateBlock currentBlock = gateBlockContext[gateBlockContext.Count - 1];
                // all variables defined/used in the inner block must be processed by the outer block
                // NOTE(review): this loop casts each element to List<...> while the uses loop below
                // uses ICollection<...>; confirm the element type of eb.Bindings.
                foreach (ExpressionWithBindings eb in gateBlock.variablesDefined.Values)
                {
                    if (eb.Bindings.Count > 0)
                    {
                        foreach (List <ConditionBinding> binding2 in eb.Bindings)
                        {
                            ProcessUse(eb.Expression, true, Union(conditionContext, binding2));
                        }
                    }
                    else
                    {
                        ProcessUse(eb.Expression, true, conditionContext);
                    }
                }
                foreach (List <ExpressionWithBindings> ebs in gateBlock.variablesUsed.Values)
                {
                    foreach (ExpressionWithBindings eb in ebs)
                    {
                        if (eb.Bindings.Count > 0)
                        {
                            foreach (ICollection <ConditionBinding> binding2 in eb.Bindings)
                            {
                                ProcessUse(eb.Expression, false, Union(conditionContext, binding2));
                            }
                        }
                        else
                        {
                            ProcessUse(eb.Expression, false, conditionContext);
                        }
                    }
                }
            }
            return(ics);
        }
Пример #11
0
 /// <summary>
 /// Stores a copy (via <c>ConditionBinding.Copy</c>) of the current <c>conditionContext</c> as the
 /// declaration bindings of <paramref name="ivd"/>. Any previously stored entry is replaced.
 /// </summary>
 /// <param name="ivd">The variable being declared.</param>
 protected void RegisterDeclaration(IVariableDeclaration ivd) =>
     declarationBindings[ivd] = ConditionBinding.Copy(conditionContext);
Пример #12
0
        /// <summary>
        /// Records an <c>IndexInfo</c> for array-indexer expressions over stochastic variables whose
        /// indices are not all loop variables.
        /// </summary>
        /// <remarks>
        /// Repeated occurrences of the same expression do the following to the existing entry:
        /// <list type="bullet">
        /// <item>increment its count;</item>
        /// <item>intersect its containers with the current ones;</item>
        /// <item>accumulate the non-stochastic condition bindings of the current containers;</item>
        /// <item>mark it as assigned to when the expression is being mutated.</item>
        /// </list>
        /// The left-hand side of an array-create assignment is skipped. An error is reported for any
        /// index (of the last bracket) that references a stochastic variable.
        /// </remarks>
        /// <param name="iaie">The array indexer expression.</param>
        /// <returns>The input expression, unchanged.</returns>
        protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
        {
            bool isDef = Recognizer.IsBeingMutated(context, iaie);

            if (isDef)
            {
                // do not clone the lhs of an array create assignment.
                // IsBeingMutated may hold without an enclosing assignment (presumably an out/ref
                // argument), so the ancestor lookup can return null.
                IAssignExpression assignExpr = context.FindAncestor <IAssignExpression>();
                if (assignExpr != null && assignExpr.Expression is IArrayCreateExpression)
                {
                    return(iaie);
                }
            }
            base.ConvertArrayIndexer(iaie);
            IndexInfo info;

            // TODO: Instead of storing an IndexInfo for each distinct expression, we should try to unify expressions, as in GateAnalysisTransform.
            // For example, we could unify a[0,i] and a[0,0] and use the same clone array for both.
            if (indexInfoOf.TryGetValue(iaie, out info))
            {
                Containers containers = new Containers(context);
                // An empty info.bindings means some earlier occurrence had no deterministic condition
                // bindings; once cleared it stays empty because of this guard.
                if (info.bindings.Count > 0)
                {
                    List <ConditionBinding> bindings = GetBindings(context, containers.inputs);
                    if (bindings.Count == 0)
                    {
                        info.bindings.Clear();
                    }
                    else
                    {
                        info.bindings.Add(bindings);
                    }
                }
                info.containers = Containers.Intersect(info.containers, containers);
                info.count++;
                if (isDef)
                {
                    info.IsAssignedTo = true;
                }
                return(iaie);
            }
            CheckIndicesAreNotStochastic(iaie.Indices);
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(iaie);

            // If not an indexed variable reference, skip it (e.g. an indexed argument reference)
            if (baseVar == null)
            {
                return(iaie);
            }
            // If the variable is not stochastic, skip it
            if (!CodeRecognizer.IsStochastic(context, baseVar))
            {
                return(iaie);
            }
            // If the indices are all loop variables, skip it
            var  indices        = Recognizer.GetIndices(iaie);
            bool allLoopIndices = indices.All(bracket => bracket.All(indexExpr =>
                indexExpr is IVariableReferenceExpression ivre && Recognizer.GetLoopForVariable(context, ivre) != null));

            if (allLoopIndices)
            {
                return(iaie);
            }

            info            = new IndexInfo();
            info.containers = new Containers(context);
            List <ConditionBinding> bindings2 = GetBindings(context, info.containers.inputs);

            if (bindings2.Count > 0)
            {
                info.bindings.Add(bindings2);
            }
            info.count        = 1;
            info.IsAssignedTo = isDef;
            indexInfoOf[iaie] = info;
            return(iaie);

            // Collects a ConditionBinding for each non-stochastic condition statement among the containers.
            List <ConditionBinding> GetBindings(BasicTransformContext context, IEnumerable <IStatement> containers)
            {
                List <ConditionBinding> bindings = new List <ConditionBinding>();

                foreach (IStatement st in containers)
                {
                    if (st is IConditionStatement ics)
                    {
                        if (!CodeRecognizer.IsStochastic(context, ics.Condition))
                        {
                            ConditionBinding binding = new ConditionBinding(ics.Condition);
                            bindings.Add(binding);
                        }
                    }
                }
                return(bindings);
            }

            // Reports an error for each index expression that references a stochastic variable.
            void CheckIndicesAreNotStochastic(IList <IExpression> exprs)
            {
                foreach (IExpression index in exprs)
                {
                    foreach (var ivd in Recognizer.GetVariables(index))
                    {
                        if (CodeRecognizer.IsStochastic(context, ivd))
                        {
                            string msg = "Indexing by a random variable '" + ivd.Name + "'.  You must wrap this statement with Variable.Switch(" + index + ")";
                            Error(msg);
                        }
                    }
                }
            }
        }