Example #1
0
        /// <summary>
        /// Records the containers of a newly declared variable and makes sure its marginal prototype is set.
        /// </summary>
        /// <param name="ivde">The variable declaration expression being converted.</param>
        /// <returns>The declaration expression itself, unchanged.</returns>
        /// <remarks>
        /// Non-stochastic variables are passed to ProcessConstant and, if the input attributes carry no
        /// DescriptionAttribute for them, get a "The constant '...'" description in the output attributes.
        /// For stochastic variables the MarginalPrototype attribute is applied to the variable's
        /// VariableInformation; an <see cref="ArgumentException"/> raised there is reported through Error.
        /// </remarks>
        protected override IExpression ConvertVariableDeclExpr(IVariableDeclarationExpression ivde)
        {
            IVariableDeclaration declared = ivde.Variable;
            // Replace any previous container record with one describing the current context.
            context.InputAttributes.Remove<Containers>(declared);
            context.InputAttributes.Set(declared, new Containers(context));

            if (CodeRecognizer.IsStochastic(context, declared))
            {
                VariableInformation info = VariableInformation.GetVariableInformation(context, declared);
                // Ensure the marginal prototype is set.
                MarginalPrototype prototype = Context.InputAttributes.Get<MarginalPrototype>(declared);
                try
                {
                    info.SetMarginalPrototypeFromAttribute(prototype);
                }
                catch (ArgumentException failure)
                {
                    Error(failure.Message);
                }
                return ivde;
            }

            ProcessConstant(declared);
            bool alreadyDescribed = context.InputAttributes.Has<DescriptionAttribute>(declared);
            if (!alreadyDescribed)
            {
                context.OutputAttributes.Set(declared, new DescriptionAttribute("The constant '" + declared.Name + "'"));
            }
            return ivde;
        }
Example #2
0
        /// <summary>
        /// Reorders containers so that if-statements with stochastic conditions come after all other containers.
        /// </summary>
        /// <param name="containers">Containers to reorder; <c>inputs[i]</c> is paired with <c>outputs[i]</c>.</param>
        /// <param name="context">Transform context used to decide whether a condition is stochastic.</param>
        /// <returns>
        /// A new <see cref="Containers"/> with every other container first, followed by the stochastic
        /// conditionals; each group keeps its original relative order.
        /// </returns>
        internal static Containers SortStochasticConditionals(Containers containers, BasicTransformContext context)
        {
            Containers sorted   = new Containers();
            Containers deferred = new Containers();

            for (int index = 0; index < containers.inputs.Count; index++)
            {
                IStatement          statement   = containers.inputs[index];
                IConditionStatement conditional = statement as IConditionStatement;
                bool isStochasticConditional =
                    conditional != null && CodeRecognizer.IsStochastic(context, conditional.Condition);

                // Stochastic conditionals are held back and appended after everything else.
                Containers destination = isStochasticConditional ? deferred : sorted;
                destination.inputs.Add(statement);
                destination.outputs.Add(containers.outputs[index]);
            }

            for (int index = 0; index < deferred.inputs.Count; index++)
            {
                sorted.inputs.Add(deferred.inputs[index]);
                sorted.outputs.Add(deferred.outputs[index]);
            }
            return sorted;
        }
Example #3
0
 /// <summary>
 /// Selects the condition bindings that may be kept for <paramref name="gateBlock"/>.
 /// </summary>
 /// <param name="gateBlock">Gate block passed to ContainsLocalVars for each binding side.</param>
 /// <param name="conditionContext">Candidate bindings; the list itself is not modified.</param>
 /// <returns>
 /// A new list containing, in order, each binding whose left-hand side is not stochastic and for which
 /// ContainsLocalVars reports false on both its left- and right-hand sides.
 /// </returns>
 private List<ConditionBinding> FilterConditionContext(GateBlock gateBlock, List<ConditionBinding> conditionContext)
 {
     List<ConditionBinding> kept = new List<ConditionBinding>();
     foreach (ConditionBinding binding in conditionContext)
     {
         if (CodeRecognizer.IsStochastic(context, binding.lhs))
         {
             continue;
         }
         if (ContainsLocalVars(gateBlock, binding.lhs) || ContainsLocalVars(gateBlock, binding.rhs))
         {
             continue;
         }
         kept.Add(binding);
     }
     return kept;
 }
        /// <summary>
        /// Re-wraps an expression statement in an if statement on <c>parent.Condition</c> when it adds a factor or constraint.
        /// </summary>
        /// <param name="ies">The expression statement being converted.</param>
        /// <returns>
        /// <paramref name="ies"/> unchanged when there is no parent condition or no wrapping is needed; otherwise a
        /// new condition statement whose Then block contains <paramref name="ies"/>.
        /// </returns>
        /// <remarks>
        /// Wrapped: method calls other than Infer; assignments from a method call, unless the target and the
        /// first argument are both DoNotSendEvidence variables; and declarations of stochastic variables that are
        /// not DoNotSendEvidence (including declarations on the left of a non-call assignment).
        /// </remarks>
        protected override IStatement ConvertExpressionStatement(IExpressionStatement ies)
        {
            if (parent == null)
            {
                return ies;
            }

            // Only keep the surrounding if statement when a factor or constraint is being added.
            IExpression             subject    = ies.Expression;
            IMethodInvokeExpression call       = subject as IMethodInvokeExpression;
            IAssignExpression       assignment = subject as IAssignExpression;
            bool wrap = false;

            if (call != null)
            {
                wrap = !CodeRecognizer.IsInfer(subject);
            }
            else if (assignment != null)
            {
                IMethodInvokeExpression rhsCall = assignment.Expression as IMethodInvokeExpression;
                if (rhsCall == null)
                {
                    // Not a factor: only a declaration on the left-hand side can still require wrapping.
                    subject = assignment.Target;
                }
                else
                {
                    wrap = true;
                    if (rhsCall.Arguments.Count > 0)
                    {
                        // Statements that copy evidence variables should not send evidence messages.
                        IVariableDeclaration targetVar = Recognizer.GetVariableDeclaration(assignment.Target);
                        IVariableDeclaration sourceVar = Recognizer.GetVariableDeclaration(rhsCall.Arguments[0]);
                        bool targetIsEvidence = targetVar != null && context.InputAttributes.Has<DoNotSendEvidence>(targetVar);
                        if (targetIsEvidence && sourceVar != null && context.InputAttributes.Has<DoNotSendEvidence>(sourceVar))
                        {
                            wrap = false;
                        }
                    }
                }
            }

            IVariableDeclarationExpression declaration = subject as IVariableDeclarationExpression;
            if (declaration != null)
            {
                IVariableDeclaration declared = declaration.Variable;
                wrap = CodeRecognizer.IsStochastic(context, declared) && !context.InputAttributes.Has<DoNotSendEvidence>(declared);
            }

            if (!wrap)
            {
                return ies;
            }
            IConditionStatement wrapper = Builder.CondStmt(parent.Condition, Builder.BlockStmt());
            wrapper.Then.Statements.Add(ies);
            return wrapper;
        }
Example #5
0
 /// <summary>
 /// Returns true if <paramref name="ivd"/> is constant for a fixed value of the loop variables in
 /// <paramref name="constantLoopVars"/>.
 /// </summary>
 /// <param name="ivd">Variable to test; a null variable counts as constant.</param>
 /// <param name="constantLoopVars">Loop variables whose values are held fixed.</param>
 /// <returns>
 /// False for stochastic variables. Otherwise true when <paramref name="ivd"/> is itself in
 /// <paramref name="constantLoopVars"/>, or when every loop variable of its LoopContext is in it.
 /// </returns>
 /// <remarks>
 /// NOTE(review): the LoopContext attribute is dereferenced without a null check; if Get returns null for a
 /// variable without one, this throws — confirm every non-stochastic variable reaching here has a LoopContext.
 /// </remarks>
 private bool IsConstantWrtLoops(IVariableDeclaration ivd, Set<IVariableDeclaration> constantLoopVars)
 {
     if (ivd == null)
     {
         return true;
     }
     if (CodeRecognizer.IsStochastic(context, ivd))
     {
         return false;
     }
     // Either ivd is a fixed loop variable, or all of its enclosing loops are fixed.
     return constantLoopVars.Contains(ivd)
         || constantLoopVars.ContainsAll(context.InputAttributes.Get<LoopContext>(ivd).loopVariables);
 }
Example #6
0
        /// <summary>
        /// Decides whether a method argument can be treated as a point mass.
        /// </summary>
        /// <param name="arg">The argument expression.</param>
        /// <returns>
        /// True for non-stochastic arguments and for out-arguments (<see cref="IAddressOutExpression"/>);
        /// for any other stochastic argument, true only if it refers to a variable with a ForwardPointMass attribute.
        /// </returns>
        private bool ArgumentIsPointMass(IExpression arg)
        {
            bool isOutArgument = arg is IAddressOutExpression;

            if (!CodeRecognizer.IsStochastic(context, arg) || isOutArgument)
            {
                return true;
            }

            IVariableDeclaration referenced = Recognizer.GetVariableDeclaration(arg);
            if (referenced == null)
            {
                return false;
            }
            return context.InputAttributes.Has<ForwardPointMass>(referenced);
        }
 /// <summary>
 /// Reports an error for each stochastic variable that occurs in any of the given index expressions.
 /// </summary>
 /// <param name="exprs">Index expressions to check.</param>
 /// <remarks>Each offending variable is reported through Error, naming the variable and the index expression.</remarks>
 private void CheckIndicesAreNotStochastic(IList<IExpression> exprs)
 {
     foreach (IExpression index in exprs)
     {
         foreach (var variable in Recognizer.GetVariables(index))
         {
             if (!CodeRecognizer.IsStochastic(context, variable))
             {
                 continue;
             }
             Error("Indexing by a random variable '" + variable.Name + "'.  You must wrap this statement with Variable.Switch(" + index + ")");
         }
     }
 }
        /// <summary>
        /// Builds a <see cref="ConditionBinding"/> for each if-statement with a non-stochastic condition.
        /// </summary>
        /// <param name="context">Transform context used to decide whether a condition is stochastic.</param>
        /// <param name="containers">Statements to scan in order; statements that are not conditions are ignored.</param>
        /// <returns>A new list with one binding per non-stochastic condition, in the order encountered.</returns>
        internal static List<ConditionBinding> GetBindings(BasicTransformContext context, IEnumerable<IStatement> containers)
        {
            var result = new List<ConditionBinding>();

            foreach (IStatement container in containers)
            {
                IConditionStatement conditional = container as IConditionStatement;
                if (conditional == null || CodeRecognizer.IsStochastic(context, conditional.Condition))
                {
                    continue;
                }
                result.Add(new ConditionBinding(conditional.Condition));
            }
            return result;
        }
Example #9
0
        /// <summary>
        /// Records a use of the stochastic variable referenced by <paramref name="expr"/>, plus a definition when
        /// the expression is being mutated but not allocated.
        /// </summary>
        /// <param name="expr">Expression that may reference a variable.</param>
        /// <remarks>
        /// Does nothing if the expression has no variable declaration or the variable is not stochastic.
        /// The use is recorded against the current conditionContext.
        /// </remarks>
        protected void ProcessExpression(IExpression expr)
        {
            bool mutated = Recognizer.IsBeingMutated(context, expr);
            IVariableDeclaration variable = Recognizer.GetVariableDeclaration(expr);
            bool tracked = variable != null && CodeRecognizer.IsStochastic(context, variable);

            if (!tracked)
            {
                return;
            }
            // Allocations are not registered as definitions.
            if (mutated && !Recognizer.IsBeingAllocated(context, expr))
            {
                RegisterDefinition(variable);
            }
            ProcessUse(expr, mutated, conditionContext);
        }
        /// <summary>
        /// Analyses a conditional statement, pushing its condition onto conditionContext while each branch is converted.
        /// </summary>
        /// <param name="ics">The conditional statement.</param>
        /// <returns>
        /// The base transform's result when the condition is stochastic; otherwise <paramref name="ics"/> itself.
        /// </returns>
        /// <remarks>
        /// For a non-stochastic condition, the Then block is converted with a binding for the condition appended to
        /// conditionContext, and the Else block (if present) with the flipped binding. After each branch,
        /// conditionContext is truncated back to its length on entry.
        /// </remarks>
        protected override IStatement ConvertCondition(IConditionStatement ics)
        {
            if (CodeRecognizer.IsStochastic(context, ics.Condition))
            {
                return base.ConvertCondition(ics);
            }

            context.SetPrimaryOutput(ics);
            ConvertExpression(ics.Condition);
            ConditionBinding thenBinding = new ConditionBinding(ics.Condition);
            int savedCount = conditionContext.Count;

            conditionContext.Add(thenBinding);
            ConvertBlock(ics.Then);
            conditionContext.RemoveRange(savedCount, conditionContext.Count - savedCount);

            if (ics.Else != null)
            {
                conditionContext.Add(thenBinding.FlipCondition());
                ConvertBlock(ics.Else);
                conditionContext.RemoveRange(savedCount, conditionContext.Count - savedCount);
            }
            return ics;
        }
Example #11
0
        /// <summary>
        /// Adds a Replicate statement to the output code, if needed.
        /// </summary>
        /// <remarks>
        /// Returns without emitting anything when: the target has no variable declaration; the variable is not
        /// stochastic (it is passed to ProcessConstant instead); SwapIndices is set and the variable's
        /// DoNotSendEvidence status differs from replicateEvidenceVars; the channel analysis has no usage info
        /// for the variable; or the variable has at most one use.
        /// Otherwise, if usesOfVariable has no entry for the variable yet, a uses array is declared before the
        /// matching ancestor. Then, if any extra literal indices in the target are zero, an assignment
        /// <c>x_uses = Replicate(x, useCount)</c> is added. Variables marked GateExitRandomVariable use
        /// Gate.ReplicateExiting instead of Factor.Replicate.
        /// </remarks>
        /// <param name="target">LHS of assignment</param>
        /// <param name="rhs">RHS of assignment</param>
        /// <param name="shouldDelete">True if the original assignment statement should be deleted, i.e. it can be optimized away.
        /// NOTE(review): only assigned on the copy-propagation path, which is currently disabled (copyPropagation is hard-coded false).</param>
        protected void AddReplicateStatement(IExpression target, IExpression rhs, ref bool shouldDelete)
        {
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);

            if (ivd == null)
            {
                return;
            }
            VariableInformation vi = VariableInformation.GetVariableInformation(context, ivd);

            if (!vi.IsStochastic)
            {
                ProcessConstant(ivd);
                return;
            }
            bool isEvidenceVar = context.InputAttributes.Has <DoNotSendEvidence>(ivd);

            // With SwapIndices, only variables whose evidence status matches replicateEvidenceVars are handled by
            // this call (presumably evidence and non-evidence variables are processed separately — confirm).
            if (SwapIndices && replicateEvidenceVars != isEvidenceVar)
            {
                return;
            }
            // target defines a stochastic variable
            ChannelAnalysisTransform.UsageInfo usageInfo;
            if (!analysis.usageInfo.TryGetValue(ivd, out usageInfo))
            {
                return;
            }
            int useCount = usageInfo.NumberOfUses;

            // A variable with zero or one use needs no replication.
            if (useCount <= 1)
            {
                return;
            }

            VariableToChannelInformation vtci;
            bool firstTime   = !usesOfVariable.TryGetValue(ivd, out vtci);
            int  targetDepth = Recognizer.GetIndexingDepth(target);
            // NOTE(review): assumes indexingDepths[0] is the smallest indexing depth of any use, as the name suggests.
            int  minDepth    = usageInfo.indexingDepths[0];
            int  usageDepth  = minDepth;

            // Without SwapIndices, replicate the whole (unindexed) variable.
            if (!SwapIndices)
            {
                usageDepth = 0;
            }
            // Containers of the variable's definition that are not part of the current context; statements inserted
            // relative to ancestor ancIndex are wrapped in these.
            Containers defContainers = context.InputAttributes.Get <Containers>(ivd);
            int        ancIndex      = defContainers.GetMatchingAncestorIndex(context);
            Containers missing       = defContainers.GetContainersNotInContext(context, ancIndex);

            if (firstTime)
            {
                // declaration of uses array
                IList <IStatement> stmts = Builder.StmtCollection();
                vtci  = DeclareUsesArray(stmts, ivd, vi, useCount, usageDepth);
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            // check that extra literal indices in the target are zero.
            // for example, if iae is x[i][0] = (...) then it is safe to add x_uses[i] = Rep(x[i])
            // if iae is x[i][1] = (...) then it is safe to add x_uses[i][1] = Rep(x[i][1])
            // but not x_uses[i] = Rep(x[i]) since this will be a duplicate.
            bool extraLiteralsAreZero = CheckExtraLiteralsAreZero(target, targetDepth, usageDepth);

            if (extraLiteralsAreZero)
            {
                // definition of uses array
                IExpression       defExpr  = Builder.VarRefExpr(ivd);
                IExpression       usesExpr = Builder.VarRefExpr(vtci.usesDecl);
                List <IStatement> loops    = new List <IStatement>();
                if (usageDepth == targetDepth)
                {
                    // The target is indexed exactly to the usage depth: replicate the target expression itself
                    // and index the uses array the same way.
                    defExpr = target;
                    if (defExpr is IVariableDeclarationExpression)
                    {
                        defExpr = Builder.VarRefExpr(ivd);
                    }
                    usesExpr = Builder.ReplaceVariable(defExpr, ivd, vtci.usesDecl);
                }
                else
                {
                    // loops over the indexing brackets 0..usageDepth-1: for each bracket, create a for loop per
                    // index variable and index both defExpr and usesExpr by those variables.
                    for (int d = 0; d < usageDepth; d++)
                    {
                        List <IExpression> indices = new List <IExpression>();
                        for (int i = 0; i < vi.sizes[d].Length; i++)
                        {
                            IVariableDeclaration v    = vi.indexVars[d][i];
                            IStatement           loop = Builder.ForStmt(v, vi.sizes[d][i]);
                            loops.Add(loop);
                            indices.Add(Builder.VarRefExpr(v));
                        }
                        defExpr  = Builder.ArrayIndex(defExpr, indices);
                        usesExpr = Builder.ArrayIndex(usesExpr, indices);
                    }
                }

                if (rhs != null && rhs is IMethodInvokeExpression)
                {
                    IMethodInvokeExpression imie = (IMethodInvokeExpression)rhs;
                    // NOTE(review): copyPropagation is hard-coded false, so the branch below is dead code and
                    // shouldDelete is never set by this method.
                    bool copyPropagation         = false;
                    if (Recognizer.IsStaticGenericMethod(imie, new Func <PlaceHolder, PlaceHolder>(Factor.Copy)) && copyPropagation)
                    {
                        // if a variable is a copy, use the original expression since it will give more precise dependencies.
                        defExpr      = imie.Arguments[0];
                        shouldDelete = true;
                    }
                }

                // Add the statement:
                //   x_uses = Replicate(x, useCount)
                var genArgs = new Type[] { ivd.VariableType.DotNetType };
                IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
                    new Func <PlaceHolder, int, PlaceHolder[]>(Factor.Replicate),
                    genArgs, defExpr, Builder.LiteralExpr(useCount));
                // Variables marked as gate-exit random variables use Gate.ReplicateExiting instead.
                bool isGateExitRandom = context.InputAttributes.Has <Algorithms.VariationalMessagePassing.GateExitRandomVariable>(ivd);
                if (isGateExitRandom)
                {
                    repMethod = Builder.StaticGenericMethod(
                        new Func <PlaceHolder, int, PlaceHolder[]>(Gate.ReplicateExiting),
                        genArgs, defExpr, Builder.LiteralExpr(useCount));
                }
                // Carry the variable's Algorithm, DivideMessages and GivePriorityTo attributes over to the Replicate call.
                context.InputAttributes.CopyObjectAttributesTo <Algorithm>(ivd, context.OutputAttributes, repMethod);
                if (context.InputAttributes.Has <DivideMessages>(ivd))
                {
                    context.InputAttributes.CopyObjectAttributesTo <DivideMessages>(ivd, context.OutputAttributes, repMethod);
                }
                else if (useCount == 2)
                {
                    // division has no benefit for 2 uses, and degrades the schedule
                    context.OutputAttributes.Set(repMethod, new DivideMessages(false));
                }
                context.InputAttributes.CopyObjectAttributesTo <GivePriorityTo>(ivd, context.OutputAttributes, repMethod);
                IStatement repSt = Builder.AssignStmt(usesExpr, repMethod);
                if (usageDepth == targetDepth)
                {
                    if (isEvidenceVar)
                    {
                        // place the Replicate after the current statement, but outside of any evidence conditionals
                        // NOTE(review): the scan stops at the first ancestor with a stochastic condition; assuming
                        // lower ancestor indices are further out, that is the innermost such conditional — confirm
                        // this is sufficient when evidence conditionals are nested.
                        ancIndex = context.FindAncestorIndex <IStatement>();
                        for (int i = ancIndex - 2; i > 0; i--)
                        {
                            object ancestor = context.GetAncestor(i);
                            if (ancestor is IConditionStatement)
                            {
                                IConditionStatement ics = (IConditionStatement)ancestor;
                                if (CodeRecognizer.IsStochastic(context, ics.Condition))
                                {
                                    ancIndex = i;
                                    break;
                                }
                            }
                        }
                        context.AddStatementAfterAncestorIndex(ancIndex, repSt);
                    }
                    else
                    {
                        context.AddStatementAfterCurrent(repSt);
                    }
                }
                else if (firstTime)
                {
                    // place Replicate after assignment but outside of definition loops (can't have any uses there)
                    repSt = Containers.WrapWithContainers(repSt, loops);
                    repSt = Containers.WrapWithContainers(repSt, missing.outputs);
                    context.AddStatementAfterAncestorIndex(ancIndex, repSt);
                }
                // Otherwise (depth mismatch on a later call) no Replicate is added here; presumably the one
                // emitted on the first call already covers this definition.
            }
        }
        /// <summary>
        /// This method does all the work of converting literal indexing expressions.
        /// </summary>
        /// <param name="iaie"></param>
        /// <returns></returns>
        protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
        {
            IndexAnalysisTransform.IndexInfo info;
            if (!analysis.indexInfoOf.TryGetValue(iaie, out info))
            {
                return(base.ConvertArrayIndexer(iaie));
            }
            // Determine if this is a definition i.e. the variable is on the left hand side of an assignment
            // This must be done before base.ConvertArrayIndexer changes the expression!
            bool isDef = Recognizer.IsBeingMutated(context, iaie);

            if (info.clone != null)
            {
                if (isDef)
                {
                    // check that extra literal indices in the target are zero.
                    // for example, if iae is x[i][0] = (...) then it is safe to add x_uses[i] = Rep(x[i])
                    // if iae is x[i][1] = (...) then it is safe to add x_uses[i][1] = Rep(x[i][1])
                    // but not x_uses[i] = Rep(x[i]) since this will be a duplicate.
                    bool   extraLiteralsAreZero = true;
                    int    parentIndex          = context.InputStack.Count - 2;
                    object parent = context.GetAncestor(parentIndex);
                    while (parent is IArrayIndexerExpression)
                    {
                        IArrayIndexerExpression parent_iaie = (IArrayIndexerExpression)parent;
                        foreach (IExpression index in parent_iaie.Indices)
                        {
                            if (index is ILiteralExpression)
                            {
                                int value = (int)((ILiteralExpression)index).Value;
                                if (value != 0)
                                {
                                    extraLiteralsAreZero = false;
                                    break;
                                }
                            }
                        }
                        parentIndex--;
                        parent = context.GetAncestor(parentIndex);
                    }
                    if (false && extraLiteralsAreZero)
                    {
                        // change:
                        //   array[0] = f()
                        // into:
                        //   array_item0 = f()
                        //   array[0] = Copy(array_item0)
                        IExpression copy = Builder.StaticGenericMethod(new Func <PlaceHolder, PlaceHolder>(Clone.Copy <PlaceHolder>), new Type[] { iaie.GetExpressionType() },
                                                                       info.clone);
                        IStatement copySt = Builder.AssignStmt(iaie, copy);
                        context.AddStatementAfterCurrent(copySt);
                    }
                }
                return(info.clone);
            }

            if (isDef)
            {
                // do not clone the lhs of an array create assignment.
                IAssignExpression assignExpr = context.FindAncestor <IAssignExpression>();
                if (assignExpr.Expression is IArrayCreateExpression)
                {
                    return(iaie);
                }
            }

            IVariableDeclaration originalBaseVar = Recognizer.GetVariableDeclaration(iaie);

            // If the variable is not stochastic, return
            if (!CodeRecognizer.IsStochastic(context, originalBaseVar))
            {
                return(iaie);
            }

            IExpression          newExpr    = null;
            IVariableDeclaration baseVar    = originalBaseVar;
            IVariableDeclaration newvd      = null;
            IExpression          rhsExpr    = null;
            Containers           containers = info.containers;
            Type tp = iaie.GetExpressionType();

            if (tp == null)
            {
                Error("Could not determine type of expression: " + iaie);
                return(iaie);
            }
            var stmts      = Builder.StmtCollection();
            var stmtsAfter = Builder.StmtCollection();

            // does the expression have the form array[indices[k]][indices2[k]][indices3[k]]?
            if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression &&
                iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
            {
                IArrayIndexerExpression index3 = (IArrayIndexerExpression)iaie.Indices[0];
                IArrayIndexerExpression iaie2  = (IArrayIndexerExpression)iaie.Target;
                if (index3.Indices.Count == 1 && index3.Indices[0] is IVariableReferenceExpression &&
                    iaie2.Target is IArrayIndexerExpression &&
                    iaie2.Indices.Count == 1 && iaie2.Indices[0] is IArrayIndexerExpression)
                {
                    IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie2.Indices[0];
                    IArrayIndexerExpression iaie3  = (IArrayIndexerExpression)iaie2.Target;
                    if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
                        iaie3.Indices.Count == 1 && iaie3.Indices[0] is IArrayIndexerExpression)
                    {
                        IArrayIndexerExpression      index      = (IArrayIndexerExpression)iaie3.Indices[0];
                        IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
                        IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                        if (index.Indices.Count == 1 && index2.Indices[0].Equals(innerIndex) &&
                            index3.Indices[0].Equals(innerIndex) &&
                            innerLoop != null && AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target))
                        {
                            // expression has the form array[indices[k]][indices2[k]][indices3[k]]
                            if (isDef)
                            {
                                Error("fancy indexing not allowed on left hand side");
                                return(iaie);
                            }
                            WarnIfLocal(index.Target, iaie3.Target, iaie);
                            WarnIfLocal(index2.Target, iaie3.Target, iaie);
                            WarnIfLocal(index3.Target, iaie3.Target, iaie);
                            containers = RemoveReferencesTo(containers, innerIndex);
                            IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
                            var         indices  = Recognizer.GetIndices(iaie);
                            // Build name of replacement variable from index values
                            StringBuilder sb = new StringBuilder("_item");
                            AppendIndexString(sb, iaie3);
                            AppendIndexString(sb, iaie2);
                            AppendIndexString(sb, iaie);
                            string name = ToString(iaie3.Target) + sb.ToString();
                            VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                            newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices);
                            if (!context.InputAttributes.Has <DerivedVariable>(newvd))
                            {
                                context.InputAttributes.Set(newvd, new DerivedVariable());
                            }
                            IExpression getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <IReadOnlyList <PlaceHolder> > >, IReadOnlyList <int>, IReadOnlyList <int>, IReadOnlyList <int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged),
                                                                               new Type[] { tp }, iaie3.Target, index.Target, index2.Target, index3.Target);
                            context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, getItems);
                            stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                            newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
                            rhsExpr = getItems;
                        }
                    }
                }
            }
            // does the expression have the form array[indices[k]][indices2[k]]?
            if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression &&
                iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
            {
                IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie.Indices[0];
                IArrayIndexerExpression target = (IArrayIndexerExpression)iaie.Target;
                if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
                    target.Indices.Count == 1 && target.Indices[0] is IArrayIndexerExpression)
                {
                    IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index2.Indices[0];
                    IArrayIndexerExpression      index      = (IArrayIndexerExpression)target.Indices[0];
                    IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                    if (index.Indices.Count == 1 && index.Indices[0].Equals(innerIndex) &&
                        innerLoop != null && AreLoopsDisjoint(innerLoop, target.Target, index.Target))
                    {
                        // expression has the form array[indices[k]][indices2[k]]
                        if (isDef)
                        {
                            Error("fancy indexing not allowed on left hand side");
                            return(iaie);
                        }
                        var innerLoops = new List <IForStatement>();
                        innerLoops.Add(innerLoop);
                        var indexTarget  = index.Target;
                        var index2Target = index2.Target;
                        // check if the index array is jagged, i.e. array[indices[k][j]]
                        while (indexTarget is IArrayIndexerExpression && index2Target is IArrayIndexerExpression)
                        {
                            IArrayIndexerExpression indexTargetExpr  = (IArrayIndexerExpression)indexTarget;
                            IArrayIndexerExpression index2TargetExpr = (IArrayIndexerExpression)index2Target;
                            if (indexTargetExpr.Indices.Count == 1 && indexTargetExpr.Indices[0] is IVariableReferenceExpression &&
                                index2TargetExpr.Indices.Count == 1 && index2TargetExpr.Indices[0] is IVariableReferenceExpression)
                            {
                                IVariableReferenceExpression innerIndexTarget  = (IVariableReferenceExpression)indexTargetExpr.Indices[0];
                                IVariableReferenceExpression innerIndex2Target = (IVariableReferenceExpression)index2TargetExpr.Indices[0];
                                IForStatement indexTargetLoop = Recognizer.GetLoopForVariable(context, innerIndexTarget);
                                if (indexTargetLoop != null && AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) &&
                                    innerIndexTarget.Equals(innerIndex2Target))
                                {
                                    innerLoops.Add(indexTargetLoop);
                                    indexTarget  = indexTargetExpr.Target;
                                    index2Target = index2TargetExpr.Target;
                                }
                                else
                                {
                                    break;
                                }
                            }
                            else
                            {
                                break;
                            }
                        }
                        WarnIfLocal(indexTarget, target.Target, iaie);
                        WarnIfLocal(index2Target, target.Target, iaie);
                        innerLoops.Reverse();
                        var loopSizes    = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                        var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                        // Build name of replacement variable from index values
                        StringBuilder sb = new StringBuilder("_item");
                        AppendIndexString(sb, target);
                        AppendIndexString(sb, iaie);
                        string name = ToString(target.Target) + sb.ToString();
                        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                        var indices = Recognizer.GetIndices(iaie);
                        newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                        if (!context.InputAttributes.Has <DerivedVariable>(newvd))
                        {
                            context.InputAttributes.Set(newvd, new DerivedVariable());
                        }
                        IExpression getItems;
                        if (innerLoops.Count == 1)
                        {
                            getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <PlaceHolder> >, IReadOnlyList <int>, IReadOnlyList <int>, PlaceHolder[]>(Collection.GetItemsFromJagged),
                                                                   new Type[] { tp }, target.Target, indexTarget, index2Target);
                        }
                        else if (innerLoops.Count == 2)
                        {
                            getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <PlaceHolder> >, IReadOnlyList <IReadOnlyList <int> >, IReadOnlyList <IReadOnlyList <int> >, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged),
                                                                   new Type[] { tp }, target.Target, indexTarget, index2Target);
                        }
                        else
                        {
                            throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                        }
                        context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, getItems);
                        stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                        var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                        newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                        rhsExpr = getItems;
                    }
                    else if (HasAnyCommonLoops(index, index2))
                    {
                        Warning($"This model will consume excess memory due to the indexing expression {iaie} since {index} and {index2} have larger depth than the compiler can handle.");
                    }
                }
            }
            if (newvd == null)
            {
                IArrayIndexerExpression originalExpr = iaie;
                if (UseGetItems)
                {
                    iaie = (IArrayIndexerExpression)base.ConvertArrayIndexer(iaie);
                }
                if (!object.ReferenceEquals(iaie.Target, originalExpr.Target) && false)
                {
                    // TODO: determine if this warning is useful or not
                    string warningText = "This model may consume excess memory due to the jagged indexing expression {0}";
                    Warning(string.Format(warningText, originalExpr));
                }

                // get the baseVar of the new expression.
                baseVar = Recognizer.GetVariableDeclaration(iaie);
                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);

                var indices = Recognizer.GetIndices(iaie);
                // Build name of replacement variable from index values
                StringBuilder sb = new StringBuilder("_item");
                AppendIndexString(sb, iaie);
                string name = ToString(iaie.Target) + sb.ToString();

                // does the expression have the form array[indices[k]]?
                if (UseGetItems && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
                {
                    IArrayIndexerExpression index = (IArrayIndexerExpression)iaie.Indices[0];
                    if (index.Indices.Count == 1 && index.Indices[0] is IVariableReferenceExpression)
                    {
                        // expression has the form array[indices[k]]
                        IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
                        IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                        if (innerLoop != null && AreLoopsDisjoint(innerLoop, iaie.Target, index.Target))
                        {
                            if (isDef)
                            {
                                Error("fancy indexing not allowed on left hand side");
                                return(iaie);
                            }
                            var innerLoops = new List <IForStatement>();
                            innerLoops.Add(innerLoop);
                            var indexTarget = index.Target;
                            // check if the index array is jagged, i.e. array[indices[k][j]]
                            while (indexTarget is IArrayIndexerExpression)
                            {
                                IArrayIndexerExpression index2 = (IArrayIndexerExpression)indexTarget;
                                if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression)
                                {
                                    IVariableReferenceExpression innerIndex2 = (IVariableReferenceExpression)index2.Indices[0];
                                    IForStatement innerLoop2 = Recognizer.GetLoopForVariable(context, innerIndex2);
                                    if (innerLoop2 != null && AreLoopsDisjoint(innerLoop2, iaie.Target, index2.Target))
                                    {
                                        innerLoops.Add(innerLoop2);
                                        indexTarget = index2.Target;
                                        // This limit must match the number of handled cases below.
                                        if (innerLoops.Count == 3)
                                        {
                                            break;
                                        }
                                    }
                                    else
                                    {
                                        break;
                                    }
                                }
                                else
                                {
                                    break;
                                }
                            }
                            WarnIfLocal(indexTarget, iaie.Target, originalExpr);
                            innerLoops.Reverse();
                            var loopSizes    = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                            var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                            newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                            if (!context.InputAttributes.Has <DerivedVariable>(newvd))
                            {
                                context.InputAttributes.Set(newvd, new DerivedVariable());
                            }
                            IExpression getItems;
                            if (innerLoops.Count == 1)
                            {
                                getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <PlaceHolder>, IReadOnlyList <int>, PlaceHolder[]>(Collection.GetItems),
                                                                       new Type[] { tp }, iaie.Target, indexTarget);
                            }
                            else if (innerLoops.Count == 2)
                            {
                                getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <PlaceHolder>, IReadOnlyList <IReadOnlyList <int> >, PlaceHolder[][]>(Collection.GetJaggedItems),
                                                                       new Type[] { tp }, iaie.Target, indexTarget);
                            }
                            else if (innerLoops.Count == 3)
                            {
                                getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <PlaceHolder>, IReadOnlyList <IReadOnlyList <IReadOnlyList <int> > >, PlaceHolder[][][]>(Collection.GetDeepJaggedItems),
                                                                       new Type[] { tp }, iaie.Target, indexTarget);
                            }
                            else
                            {
                                throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                            }
                            context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, getItems);
                            stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                            var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                            newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                            rhsExpr = getItems;
                        }
                    }
                }
                if (newvd == null)
                {
                    if (UseGetItems && info.count < 2)
                    {
                        return(iaie);
                    }
                    try
                    {
                        newvd = varInfo.DeriveIndexedVariable(stmts, context, name, indices, copyInitializer: isDef);
                    }
                    catch (Exception ex)
                    {
                        Error(ex.Message, ex);
                        return(iaie);
                    }
                    context.OutputAttributes.Remove <DerivedVariable>(newvd);
                    newExpr = Builder.VarRefExpr(newvd);
                    rhsExpr = iaie;
                    if (isDef)
                    {
                        // change:
                        //   array[0] = f()
                        // into:
                        //   array_item0 = f()
                        //   array[0] = Copy(array_item0)
                        IExpression copy   = Builder.StaticGenericMethod(new Func <PlaceHolder, PlaceHolder>(Clone.Copy), new Type[] { tp }, newExpr);
                        IStatement  copySt = Builder.AssignStmt(iaie, copy);
                        stmtsAfter.Add(copySt);
                        if (!context.InputAttributes.Has <DerivedVariable>(baseVar))
                        {
                            context.InputAttributes.Set(baseVar, new DerivedVariable());
                        }
                    }
                    else if (!info.IsAssignedTo)
                    {
                        // change:
                        //   x = f(array[0])
                        // into:
                        //   array_item0 = Copy(array[0])
                        //   x = f(array_item0)
                        IExpression copy   = Builder.StaticGenericMethod(new Func <PlaceHolder, PlaceHolder>(Clone.Copy), new Type[] { tp }, iaie);
                        IStatement  copySt = Builder.AssignStmt(Builder.VarRefExpr(newvd), copy);
                        //if (attr != null) context.OutputAttributes.Set(copySt, attr);
                        stmts.Add(copySt);
                        context.InputAttributes.Set(newvd, new DerivedVariable());
                    }
                }
            }

            // Reduce memory consumption by declaring the clone outside of unnecessary loops.
            // This way, the item is cloned outside the loop and then replicated, instead of replicating the entire array and cloning the item.
            containers = Containers.RemoveUnusedLoops(containers, context, rhsExpr);
            if (context.InputAttributes.Has <DoNotSendEvidence>(originalBaseVar))
            {
                containers = Containers.RemoveStochasticConditionals(containers, context);
            }
            if (true)
            {
                IStatement st = GetBindingSetContainer(FilterBindingSet(info.bindings,
                                                                        binding => Containers.ContainsExpression(containers.inputs, context, binding.GetExpression())));
                if (st != null)
                {
                    containers.Add(st);
                }
            }
            // To put the declaration in the desired containers, we find an ancestor which includes as many of the containers as possible,
            // then wrap the declaration with the remaining containers.
            int        ancIndex = containers.GetMatchingAncestorIndex(context);
            Containers missing  = containers.GetContainersNotInContext(context, ancIndex);

            stmts = Containers.WrapWithContainers(stmts, missing.outputs);
            context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            stmtsAfter = Containers.WrapWithContainers(stmtsAfter, missing.outputs);
            context.AddStatementsAfterAncestorIndex(ancIndex, stmtsAfter);
            context.InputAttributes.Set(newvd, containers);
            info.clone = newExpr;
            return(newExpr);
        }
Exemple #13
0
        /// <summary>
        /// Rewrites a reference to a stochastic variable so that it is replicated across every reference loop
        /// whose loop variable the reference is not already constant with respect to.
        /// </summary>
        /// <param name="expr">A variable reference or (possibly jagged) array-indexer expression.</param>
        /// <returns>
        /// <paramref name="expr"/> unchanged when it has no base variable declaration, the base variable is not stochastic,
        /// the base variable has no <see cref="LoopContext"/> (an error is reported), or there are no reference loops.
        /// Otherwise, for each reference loop that requires replication, the expression is replaced by an index into a
        /// newly declared "_rep" array (indexed by that loop's variable, followed by any unreplicated extra indices);
        /// if no loop requires replication, the original expression is returned.
        /// </returns>
        /// <remarks>
        /// Problems are reported through <c>Error</c> rather than thrown: e.g. when the expression is on the left-hand side
        /// of an assignment inside a loop it is not indexed by. An indexing expression that varies with a reference loop
        /// through a non-loop index variable produces an excess-memory warning.
        /// </remarks>
        protected IExpression ConvertWithReplication(IExpression expr)
        {
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
            // Check if this is an index local variable
            if (baseVar == null) return expr;
            // Check if the variable is stochastic
            if (!CodeRecognizer.IsStochastic(context, baseVar)) return expr;

            // Get the loop context for this variable
            LoopContext lc = context.InputAttributes.Get<LoopContext>(baseVar);
            if (lc == null)
            {
                Error("Loop context not found for '" + baseVar.Name + "'.");
                return expr;
            }

            // Get the reference loop context for this expression
            RefLoopContext rlc = lc.GetReferenceLoopContext(context);
            // If the reference is in the same loop context as the declaration, do nothing.
            if (rlc.loops.Count == 0) return expr;

            // the set of loop variables that are constant wrt the expr
            Set<IVariableDeclaration> constantLoopVars = new Set<IVariableDeclaration>();
            constantLoopVars.AddRange(lc.loopVariables);

            // collect set of all loop variable indices in the expression
            Set<int> embeddedLoopIndices = new Set<int>();
            List<IList<IExpression>> brackets = Recognizer.GetIndices(expr);
            foreach (IList<IExpression> bracket in brackets)
            {
                foreach (IExpression index in bracket)
                {
                    IExpression indExpr = index;
                    // For a binary index expression (e.g. i+1), only the left operand is examined.
                    if (indExpr is IBinaryExpression ibe)
                    {
                        indExpr = ibe.Left;
                    }
                    IVariableDeclaration indVar = Recognizer.GetVariableDeclaration(indExpr);
                    if (indVar != null)
                    {
                        if (!constantLoopVars.Contains(indVar))
                        {
                            int loopIndex = rlc.loopVariables.IndexOf(indVar);
                            if (loopIndex != -1)
                            {
                                // indVar is a loop variable
                                constantLoopVars.Add(rlc.loopVariables[loopIndex]);
                            }
                            else
                            {
                                // indVar is not a loop variable
                                CollectEmbeddedLoopIndices(indVar, constantLoopVars, rlc, embeddedLoopIndices, expr);
                            }
                        }
                    }
                    else
                    {
                        // The index is not a plain variable reference: examine every variable it mentions.
                        foreach (var ivd in Recognizer.GetVariables(indExpr))
                        {
                            if (!constantLoopVars.Contains(ivd))
                            {
                                CollectEmbeddedLoopIndices(ivd, constantLoopVars, rlc, embeddedLoopIndices, expr);
                            }
                        }
                    }
                }
            }

            // Find loop variables that must be constant due to condition statements.
            List<IStatement> ancestors = context.FindAncestors<IStatement>();
            foreach (IStatement ancestor in ancestors)
            {
                if (!(ancestor is IConditionStatement ics))
                    continue;
                ConditionBinding binding = new ConditionBinding(ics.Condition);
                IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(binding.lhs);
                IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(binding.rhs);
                int index = rlc.loopVariables.IndexOf(ivd);
                if (index >= 0 && IsConstantWrtLoops(ivd2, constantLoopVars))
                {
                    constantLoopVars.Add(ivd);
                    continue;
                }
                int index2 = rlc.loopVariables.IndexOf(ivd2);
                if (index2 >= 0 && IsConstantWrtLoops(ivd, constantLoopVars))
                {
                    constantLoopVars.Add(ivd2);
                    continue;
                }
            }

            // Determine if this expression is being defined (is on the LHS of an assignment)
            bool isDef = Recognizer.IsBeingMutated(context, expr);

            Containers containers = context.InputAttributes.Get<Containers>(baseVar);

            IExpression originalExpr = expr;

            for (int currentLoop = 0; currentLoop < rlc.loopVariables.Count; currentLoop++)
            {
                IVariableDeclaration loopVar = rlc.loopVariables[currentLoop];
                if (constantLoopVars.Contains(loopVar))
                    continue;
                IForStatement loop = rlc.loops[currentLoop];
                // must replicate across this loop.
                if (isDef)
                {
                    Error("Cannot re-define a variable in a loop.  Variables on the left hand side of an assignment must be indexed by all containing loops.");
                    continue;
                }
                if (embeddedLoopIndices.Contains(currentLoop))
                {
                    string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
                    Warning(string.Format(warningText, originalExpr, loopVar.Name));
                }
                // split expr into a target and extra indices, where target will be replicated and extra indices will be added later
                var extraIndices = new List<IEnumerable<IExpression>>();
                AddUnreplicatedIndices(loop, expr, extraIndices, out IExpression exprToReplicate);

                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                IExpression loopSize = Recognizer.LoopSizeExpression(loop);
                IList<IStatement> stmts = Builder.StmtCollection();
                List<IList<IExpression>> inds = Recognizer.GetIndices(exprToReplicate);
                IVariableDeclaration newIndexVar = loopVar;
                // if loopVar is already an indexVar of varInfo, create a new variable
                if (varInfo.HasIndexVar(loopVar))
                {
                    newIndexVar = VariableInformation.GenerateLoopVar(context, "_a");
                    context.InputAttributes.CopyObjectAttributesTo(loopVar, context.OutputAttributes, newIndexVar);
                }
                IVariableDeclaration repVar = varInfo.DeriveArrayVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rep"),
                                                                          loopSize, newIndexVar, inds, useArrays: true);
                if (!context.InputAttributes.Has<DerivedVariable>(repVar))
                    context.OutputAttributes.Set(repVar, new DerivedVariable());
                if (context.InputAttributes.Has<ChannelInfo>(baseVar))
                {
                    VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
                    ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
                    ci.decl = repVar;
                    context.OutputAttributes.Set(repVar, ci);
                }

                // Create replicate factor
                Type returnType = Builder.ToType(repVar.VariableType);
                IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
                    new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate),
                    new Type[] {returnType.GetElementType()}, exprToReplicate, loopSize);

                IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), repMethod);
                // Copy attributes across from variable to replication expression
                context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, repMethod);
                stmts.Add(Builder.ExprStatement(assignExpression));

                // add any containers missing from context.
                containers = new Containers(context);
                // RemoveUnusedLoops will also remove conditionals involving those loop variables.
                // TODO: investigate whether removing these conditionals could cause a problem, e.g. when the condition is a conjunction of many terms.
                containers = Containers.RemoveUnusedLoops(containers, context, repMethod);
                if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar)) containers = Containers.RemoveStochasticConditionals(containers, context);
                //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
                //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
                if (containers.Contains(loop))
                {
                    Error("Internal: invalid containers for replicating " + baseVar);
                    break;
                }
                int ancIndex = containers.GetMatchingAncestorIndex(context);
                Containers missing = containers.GetContainersNotInContext(context, ancIndex);
                stmts = Containers.WrapWithContainers(stmts, missing.inputs);
                context.OutputAttributes.Set(repVar, containers);
                List<IForStatement> loops = context.FindAncestors<IForStatement>(ancIndex);
                foreach (IStatement container in missing.inputs)
                {
                    if (container is IForStatement ifs) loops.Add(ifs);
                }
                context.OutputAttributes.Set(repVar, new LoopContext(loops));
                // must convert the output since it may contain 'if' conditions
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
                // Later iterations replicate the already-replicated variable, so the next loop's replication
                // is derived from repVar rather than the original base variable.
                baseVar = repVar;
                expr = Builder.ArrayIndex(Builder.VarRefExpr(repVar), Builder.VarRefExpr(loopVar));
                expr = Builder.JaggedArrayIndex(expr, extraIndices);
            }

            return expr;
        }

        /// <summary>
        /// For an index variable that is not itself a reference loop variable, records (in <paramref name="embeddedLoopIndices"/>)
        /// the position within <paramref name="rlc"/> of every variable in the index variable's <see cref="LoopContext"/>
        /// that is not already in <paramref name="constantLoopVars"/>.
        /// </summary>
        /// <param name="indexVar">Variable appearing in an index of <paramref name="expr"/>.</param>
        /// <param name="constantLoopVars">Loop variables already known to be constant with respect to the expression; not modified.</param>
        /// <param name="rlc">Reference loop context of the expression.</param>
        /// <param name="embeddedLoopIndices">Receives positions in <c>rlc.loopVariables</c>.</param>
        /// <param name="expr">The expression being converted; used only in error messages.</param>
        private void CollectEmbeddedLoopIndices(IVariableDeclaration indexVar, Set<IVariableDeclaration> constantLoopVars, RefLoopContext rlc,
                                                Set<int> embeddedLoopIndices, IExpression expr)
        {
            LoopContext indexLoopContext = context.InputAttributes.Get<LoopContext>(indexVar);
            if (indexLoopContext == null)
            {
                // Report a missing loop context the same way as for the base variable, instead of dereferencing null.
                Error("Loop context not found for '" + indexVar.Name + "'.");
                return;
            }
            foreach (var loopVar in indexLoopContext.loopVariables)
            {
                if (constantLoopVars.Contains(loopVar))
                    continue;
                int loopIndex = rlc.loopVariables.IndexOf(loopVar);
                if (loopIndex != -1)
                    embeddedLoopIndices.Add(loopIndex);
                else
                    Error($"Index {loopVar} is not in {rlc} for expression {expr}");
            }
        }
Exemple #14
0
        protected void ProcessAssign(IExpression target, IExpression rhs, ref bool shouldDelete)
        {
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);

            if (ivd == null)
            {
                return;
            }
            if (rhs is IArrayCreateExpression)
            {
                IArrayCreateExpression iace = (IArrayCreateExpression)rhs;
                bool zeroLength             = iace.Dimensions.All(dimExpr =>
                                                                  (dimExpr is ILiteralExpression) && ((ILiteralExpression)dimExpr).Value.Equals(0));
                if (!zeroLength && iace.Initializer == null)
                {
                    return; // variable will have assignments to elements
                }
            }
            bool firstTime = !variablesAssigned.Contains(ivd);

            variablesAssigned.Add(ivd);
            bool isInferred   = context.InputAttributes.Has <IsInferred>(ivd);
            bool isStochastic = CodeRecognizer.IsStochastic(context, ivd);

            if (!isStochastic)
            {
                return;
            }
            VariableInformation vi            = VariableInformation.GetVariableInformation(context, ivd);
            Containers          defContainers = context.InputAttributes.Get <Containers>(ivd);
            int        ancIndex = defContainers.GetMatchingAncestorIndex(context);
            Containers missing  = defContainers.GetContainersNotInContext(context, ancIndex);
            // definition of a stochastic variable
            IExpression lhs = target;

            if (lhs is IVariableDeclarationExpression)
            {
                lhs = Builder.VarRefExpr(ivd);
            }
            IExpression defExpr = lhs;

            if (firstTime && isStochastic)
            {
                // Create a ChannelInfo attribute for use by later transforms, e.g. MessageTransform
                ChannelInfo defChannel = ChannelInfo.DefChannel(vi);
                defChannel.decl = ivd;
                context.OutputAttributes.Set(ivd, defChannel);
            }
            bool       isDerived = context.InputAttributes.Has <DerivedVariable>(ivd);
            IAlgorithm algorithm = this.algorithmDefault;
            Algorithm  algAttr   = context.InputAttributes.Get <Algorithm>(ivd);

            if (algAttr != null)
            {
                algorithm = algAttr.algorithm;
            }
            if (algorithm is VariationalMessagePassing && ((VariationalMessagePassing)algorithm).UseDerivMessages && isDerived && firstTime)
            {
                vi.DefineAllIndexVars(context);
                IList <IStatement>   stmts     = Builder.StmtCollection();
                IVariableDeclaration derivDecl = vi.DeriveIndexedVariable(stmts, context, ivd.Name + "_deriv");
                context.OutputAttributes.Set(ivd, new DerivMessage(derivDecl));
                ChannelInfo derivChannel = ChannelInfo.DefChannel(vi);
                derivChannel.decl = derivDecl;
                context.OutputAttributes.Set(derivChannel.decl, derivChannel);
                context.OutputAttributes.Set(derivChannel.decl, new DescriptionAttribute("deriv of '" + ivd.Name + "'"));
                // Add the declarations
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            bool isPointEstimate = context.InputAttributes.Has <PointEstimate>(ivd);

            if (this.analysis.variablesExcludingVariableFactor.Contains(ivd))
            {
                this.variablesLackingVariableFactor.Add(ivd);
                // ivd will get a marginal channel in ConvertMethodInvoke
                useOfVariable[ivd] = ivd;
                return;
            }
            if (isDerived && !isInferred && !isPointEstimate)
            {
                return;
            }

            IExpression useExpr2 = null;

            if (firstTime)
            {
                // create marginal and use channels
                vi.DefineAllIndexVars(context);
                IList <IStatement> stmts = Builder.StmtCollection();

                CreateMarginalChannel(ivd, vi, stmts);
                if (isStochastic)
                {
                    CreateUseChannel(ivd, vi, stmts);
                    context.InputAttributes.Set(useOfVariable[ivd], defContainers);
                }

                // Add the declarations
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            if (isStochastic && !useOfVariable.ContainsKey(ivd))
            {
                Error("cannot find use channel of " + ivd);
                return;
            }
            IExpression  marginalExpr = Builder.ReplaceVariable(lhs, ivd, marginalOfVariable[ivd]);
            IExpression  useExpr      = isStochastic ? Builder.ReplaceVariable(lhs, ivd, useOfVariable[ivd]) : marginalExpr;
            InitialiseTo it           = context.InputAttributes.Get <InitialiseTo>(ivd);

            Type[] genArgs = new Type[] { defExpr.GetExpressionType() };
            if (rhs is IMethodInvokeExpression)
            {
                IMethodInvokeExpression imie = (IMethodInvokeExpression)rhs;
                if (Recognizer.IsStaticGenericMethod(imie, new Func <PlaceHolder, PlaceHolder>(Clone.Copy)) && ancIndex < context.InputStack.Count - 2)
                {
                    IExpression          arg  = imie.Arguments[0];
                    IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(arg);
                    if (ivd2 != null && context.InputAttributes.Get <MarginalPrototype>(ivd) == context.InputAttributes.Get <MarginalPrototype>(ivd2))
                    {
                        // if a variable is a copy, use the original expression since it will give more precise dependencies.
                        defExpr      = arg;
                        shouldDelete = true;
                        bool makeClone = false;
                        if (makeClone)
                        {
                            VariableInformation         vi2      = VariableInformation.GetVariableInformation(context, ivd2);
                            IList <IStatement>          stmts    = Builder.StmtCollection();
                            List <IList <IExpression> > indices  = Recognizer.GetIndices(defExpr);
                            IVariableDeclaration        useDecl2 = vi2.DeriveIndexedVariable(stmts, context, ivd2.Name + "_use", indices);
                            useExpr2 = Builder.VarRefExpr(useDecl2);
                            Containers defContainers2 = context.InputAttributes.Get <Containers>(ivd2);
                            int        ancIndex2      = defContainers2.GetMatchingAncestorIndex(context);
                            Containers missing2       = defContainers2.GetContainersNotInContext(context, ancIndex2);
                            stmts = Containers.WrapWithContainers(stmts, missing2.outputs);
                            context.AddStatementsBeforeAncestorIndex(ancIndex2, stmts);
                            context.InputAttributes.Set(useDecl2, defContainers2);

                            // TODO: call CreateUseChannel
                            ChannelInfo usageChannel = ChannelInfo.UseChannel(vi2);
                            usageChannel.decl = useDecl2;
                            context.InputAttributes.CopyObjectAttributesTo <InitialiseTo>(vi.declaration, context.OutputAttributes, useDecl2);
                            context.InputAttributes.CopyObjectAttributesTo <DerivMessage>(vi.declaration, context.OutputAttributes, useDecl2);
                            context.OutputAttributes.Set(useDecl2, usageChannel);
                            //context.OutputAttributes.Set(useDecl2, new DescriptionAttribute("use of '" + ivd.Name + "'"));
                            context.OutputAttributes.Remove <InitialiseTo>(vi.declaration);

                            IExpression copyExpr = Builder.StaticGenericMethod(
                                new Func <PlaceHolder, PlaceHolder>(Clone.Copy), genArgs, useExpr2);
                            var copyStmt = Builder.AssignStmt(useExpr, copyExpr);
                            context.AddStatementAfterCurrent(copyStmt);
                        }
                    }
                }
            }

            // Add the variable factor
            IExpression variableFactorExpr;
            bool        isGateExitRandom = context.InputAttributes.Has <VariationalMessagePassing.GateExitRandomVariable>(ivd);

            if (isGateExitRandom)
            {
                variableFactorExpr = Builder.StaticGenericMethod(
                    new Models.FuncOut <PlaceHolder, PlaceHolder, PlaceHolder>(Gate.ExitingVariable),
                    genArgs, defExpr, marginalExpr);
            }
            else
            {
                Delegate d = algorithm.GetVariableFactor(isDerived, it != null);
                if (isPointEstimate)
                {
                    d = new Models.FuncOut <PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint);
                }
                if (it == null)
                {
                    variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, marginalExpr);
                }
                else
                {
                    IExpression initExpr = Builder.ReplaceExpression(lhs, Builder.VarRefExpr(ivd), it.initialMessagesExpression);
                    variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, initExpr, marginalExpr);
                }
            }
            context.InputAttributes.CopyObjectAttributesTo <GivePriorityTo>(ivd, context.OutputAttributes, variableFactorExpr);
            context.InputAttributes.CopyObjectAttributesTo <Algorithm>(ivd, context.OutputAttributes, variableFactorExpr);
            if (isStochastic)
            {
                context.OutputAttributes.Set(variableFactorExpr, new IsVariableFactor());
            }
            var assignStmt = Builder.AssignStmt(useExpr2 == null ? useExpr : useExpr2, variableFactorExpr);

            context.AddStatementAfterCurrent(assignStmt);
        }
Exemple #15
0
        /// <summary>
        /// This method does all the work of converting literal indexing expressions.
        /// </summary>
        /// <param name="iaie"></param>
        /// <returns></returns>
        protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
        {
            IndexAnalysisTransform.IndexInfo info;
            if (!analysis.indexInfoOf.TryGetValue(iaie, out info))
            {
                return(base.ConvertArrayIndexer(iaie));
            }
            // Determine if this is a definition i.e. the variable is on the left hand side of an assignment
            // This must be done before base.ConvertArrayIndexer changes the expression!
            bool isDef = Recognizer.IsBeingMutated(context, iaie);

            if (info.clone != null)
            {
                if (isDef)
                {
                    // check that extra literal indices in the target are zero.
                    // for example, if iae is x[i][0] = (...) then it is safe to add x_uses[i] = Rep(x[i])
                    // if iae is x[i][1] = (...) then it is safe to add x_uses[i][1] = Rep(x[i][1])
                    // but not x_uses[i] = Rep(x[i]) since this will be a duplicate.
                    bool   extraLiteralsAreZero = true;
                    int    parentIndex          = context.InputStack.Count - 2;
                    object parent = context.GetAncestor(parentIndex);
                    while (parent is IArrayIndexerExpression parent_iaie)
                    {
                        foreach (IExpression index in parent_iaie.Indices)
                        {
                            if (index is ILiteralExpression ile)
                            {
                                int value = (int)ile.Value;
                                if (value != 0)
                                {
                                    extraLiteralsAreZero = false;
                                    break;
                                }
                            }
                        }
                        parentIndex--;
                        parent = context.GetAncestor(parentIndex);
                    }
                    if (false && extraLiteralsAreZero)
                    {
                        // change:
                        //   array[0] = f()
                        // into:
                        //   array_item0 = f()
                        //   array[0] = Copy(array_item0)
                        IExpression copy = Builder.StaticGenericMethod(new Func <PlaceHolder, PlaceHolder>(Clone.Copy <PlaceHolder>), new Type[] { iaie.GetExpressionType() },
                                                                       info.clone);
                        IStatement copySt = Builder.AssignStmt(iaie, copy);
                        context.AddStatementAfterCurrent(copySt);
                    }
                }
                return(info.clone);
            }

            if (isDef)
            {
                // do not clone the lhs of an array create assignment.
                IAssignExpression assignExpr = context.FindAncestor <IAssignExpression>();
                if (assignExpr.Expression is IArrayCreateExpression)
                {
                    return(iaie);
                }
            }

            IVariableDeclaration originalBaseVar = Recognizer.GetVariableDeclaration(iaie);

            // If the variable is not stochastic, return
            if (!CodeRecognizer.IsStochastic(context, originalBaseVar))
            {
                return(iaie);
            }

            IExpression          newExpr    = null;
            IVariableDeclaration baseVar    = originalBaseVar;
            IVariableDeclaration newvd      = null;
            IExpression          rhsExpr    = null;
            Containers           containers = info.containers;
            Type tp = iaie.GetExpressionType();

            if (tp == null)
            {
                Error("Could not determine type of expression: " + iaie);
                return(iaie);
            }
            var stmts      = Builder.StmtCollection();
            var stmtsAfter = Builder.StmtCollection();

            // does the expression have the form array[indices[k]][indices2[k]][indices3[k]]?
            if (newvd == null && UseGetItems && iaie.Indices.Count == 1)
            {
                if (iaie.Target is IArrayIndexerExpression iaie2 &&
                    iaie.Indices[0] is IArrayIndexerExpression index3 &&
                    index3.Indices.Count == 1 &&
                    index3.Indices[0] is IVariableReferenceExpression innerIndex3 &&
                    iaie2.Target is IArrayIndexerExpression iaie3 &&
                    iaie2.Indices.Count == 1 &&
                    iaie2.Indices[0] is IArrayIndexerExpression index2 &&
                    index2.Indices.Count == 1 &&
                    index2.Indices[0] is IVariableReferenceExpression innerIndex2 &&
                    innerIndex2.Equals(innerIndex3) &&
                    iaie3.Indices.Count == 1 &&
                    iaie3.Indices[0] is IArrayIndexerExpression index &&
                    index.Indices.Count == 1 &&
                    index.Indices[0] is IVariableReferenceExpression innerIndex &&
                    innerIndex.Equals(innerIndex2))
                {
                    IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                    if (innerLoop != null &&
                        AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target))
                    {
                        // expression has the form array[indices[k]][indices2[k]][indices3[k]]
                        if (isDef)
                        {
                            Error("fancy indexing not allowed on left hand side");
                            return(iaie);
                        }
                        WarnIfLocal(index.Target, iaie3.Target, iaie);
                        WarnIfLocal(index2.Target, iaie3.Target, iaie);
                        WarnIfLocal(index3.Target, iaie3.Target, iaie);
                        containers = RemoveReferencesTo(containers, innerIndex);
                        IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
                        var         indices  = Recognizer.GetIndices(iaie);
                        // Build name of replacement variable from index values
                        StringBuilder sb = new StringBuilder("_item");
                        AppendIndexString(sb, iaie3);
                        AppendIndexString(sb, iaie2);
                        AppendIndexString(sb, iaie);
                        string name = ToString(iaie3.Target) + sb.ToString();
                        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                        newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices);
                        if (!context.InputAttributes.Has <DerivedVariable>(newvd))
                        {
                            context.InputAttributes.Set(newvd, new DerivedVariable());
                        }
                        IExpression getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <IReadOnlyList <PlaceHolder> > >, IReadOnlyList <int>, IReadOnlyList <int>, IReadOnlyList <int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged),
                                                                           new Type[] { tp }, iaie3.Target, index.Target, index2.Target, index3.Target);
                        context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, getItems);
                        stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                        newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
                        rhsExpr = getItems;
                    }
                }
            }
            // does the expression have the form array[indices[k]][indices2[k]]?
            if (newvd == null && UseGetItems && iaie.Indices.Count == 1)
            {
                if (iaie.Target is IArrayIndexerExpression target &&
                    iaie.Indices[0] is IArrayIndexerExpression index2 &&
                    index2.Indices.Count == 1 &&
                    index2.Indices[0] is IVariableReferenceExpression innerIndex &&
                    target.Indices.Count == 1 &&
                    target.Indices[0] is IArrayIndexerExpression index)
                {
                    IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                    if (index.Indices.Count == 1 &&
                        index.Indices[0].Equals(innerIndex) &&
                        innerLoop != null &&
                        AreLoopsDisjoint(innerLoop, target.Target, index.Target))
                    {
                        // expression has the form array[indices[k]][indices2[k]]
                        if (isDef)
                        {
                            Error("fancy indexing not allowed on left hand side");
                            return(iaie);
                        }
                        var innerLoops = new List <IForStatement>();
                        innerLoops.Add(innerLoop);
                        var indexTarget  = index.Target;
                        var index2Target = index2.Target;
                        // check if the index array is jagged, i.e. array[indices[k][j]]
                        while (indexTarget is IArrayIndexerExpression indexTargetExpr &&
                               index2Target is IArrayIndexerExpression index2TargetExpr)
                        {
                            if (indexTargetExpr.Indices.Count == 1 &&
                                indexTargetExpr.Indices[0] is IVariableReferenceExpression innerIndexTarget &&
                                index2TargetExpr.Indices.Count == 1 &&
                                index2TargetExpr.Indices[0] is IVariableReferenceExpression innerIndex2Target)
                            {
                                IForStatement indexTargetLoop = Recognizer.GetLoopForVariable(context, innerIndexTarget);
                                if (indexTargetLoop != null &&
                                    AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) &&
                                    innerIndexTarget.Equals(innerIndex2Target))
                                {
                                    innerLoops.Add(indexTargetLoop);
                                    indexTarget  = indexTargetExpr.Target;
                                    index2Target = index2TargetExpr.Target;
                                }
                                else
                                {
                                    break;
                                }
                            }
                            else
                            {
                                break;
                            }
                        }
                        WarnIfLocal(indexTarget, target.Target, iaie);
                        WarnIfLocal(index2Target, target.Target, iaie);
                        innerLoops.Reverse();
                        var loopSizes    = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                        var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                        // Build name of replacement variable from index values
                        StringBuilder sb = new StringBuilder("_item");
                        AppendIndexString(sb, target);
                        AppendIndexString(sb, iaie);
                        string name = ToString(target.Target) + sb.ToString();
                        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                        var indices = Recognizer.GetIndices(iaie);
                        newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                        if (!context.InputAttributes.Has <DerivedVariable>(newvd))
                        {
                            context.InputAttributes.Set(newvd, new DerivedVariable());
                        }
                        IExpression getItems;
                        if (innerLoops.Count == 1)
                        {
                            getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <PlaceHolder> >, IReadOnlyList <int>, IReadOnlyList <int>, PlaceHolder[]>(Collection.GetItemsFromJagged),
                                                                   new Type[] { tp }, target.Target, indexTarget, index2Target);
                        }
                        else if (innerLoops.Count == 2)
                        {
                            getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <PlaceHolder> >, IReadOnlyList <IReadOnlyList <int> >, IReadOnlyList <IReadOnlyList <int> >, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged),
                                                                   new Type[] { tp }, target.Target, indexTarget, index2Target);
                        }
                        else
                        {
                            throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                        }
                        context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, getItems);
                        stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                        var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                        newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                        rhsExpr = getItems;
                    }
                    else if (HasAnyCommonLoops(index, index2))
                    {
                        Warning($"This model will consume excess memory due to the indexing expression {iaie} since {index} and {index2} have larger depth than the compiler can handle.");
                    }
                }
Exemple #16
0
        /// <summary>
        /// Registers a definition or a use of a stochastic variable with the innermost gate block.
        /// Definitions are stored in <c>variablesDefined</c>, merged with any earlier definition of the
        /// same variable via <c>GetCommonParent</c>. Uses are stored in <c>variablesUsed</c>, where every
        /// existing use that could overlap the new one is collapsed into their common parent.
        /// Nothing is recorded for non-variables, deterministic variables, code outside any gate block,
        /// or variables declared inside the innermost gate block itself.
        /// </summary>
        /// <param name="expr">The (possibly indexed) variable expression being defined or read.</param>
        /// <param name="isDef">True when <paramref name="expr"/> is being defined, false when it is being read.</param>
        /// <param name="conditionContext">The condition bindings in effect at this point.</param>
        protected void ProcessUse(IExpression expr, bool isDef, List<ConditionBinding> conditionContext)
        {
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(expr);
            if (ivd == null || !CodeRecognizer.IsStochastic(context, ivd) || gateBlockContext.Count == 0)
            {
                return;
            }
            GateBlock innermost = gateBlockContext[gateBlockContext.Count - 1];
            if (context.InputAttributes.Get<GateBlock>(ivd) == innermost)
            {
                // local variable of the gate block: nothing to record
                return;
            }

            var access = new ExpressionWithBindings
            {
                Expression = ReplaceLocalIndices(innermost, expr)
            };
            List<ConditionBinding> relevantBindings = FilterConditionContext(innermost, conditionContext);
            if (relevantBindings.Count > 0)
            {
                access.Bindings.Add(relevantBindings);
            }
            //eb.Containers = Containers.InsideOf(context, GetAncestorIndexOfGateBlock(currentBlock));

            if (isDef)
            {
                ExpressionWithBindings merged = access;
                if (innermost.variablesDefined.TryGetValue(ivd, out ExpressionWithBindings earlierDef))
                {
                    // all definitions of the same variable must have a common parent
                    merged = GetCommonParent(access, earlierDef);
                }
                innermost.variablesDefined[ivd] = merged;
                return;
            }

            if (!innermost.variablesUsed.TryGetValue(ivd, out List<ExpressionWithBindings> uses))
            {
                innermost.variablesUsed[ivd] = new List<ExpressionWithBindings> { access };
                return;
            }

            // Fixed-point merge: absorb every recorded use that could overlap the new one into their
            // common parent, and repeat until a full pass absorbs nothing.
            var disjointUses = new List<ExpressionWithBindings>();
            bool absorbedAny;
            do
            {
                foreach (ExpressionWithBindings existing in uses)
                {
                    ExpressionWithBindings parent = GetCommonParent(access, existing);
                    if (CouldOverlap(access, existing))
                    {
                        access = parent;
                    }
                    else
                    {
                        disjointUses.Add(existing);
                    }
                }
                absorbedAny = disjointUses.Count != uses.Count;
                if (absorbedAny)
                {
                    // the merged use has grown, so re-check the remaining uses against it
                    uses.Clear();
                    uses.AddRange(disjointUses);
                    disjointUses.Clear();
                }
            }
            while (absorbedAny);
            uses.Add(access);
            innermost.variablesUsed[ivd] = uses;
        }
Exemple #17
0
        /// <summary>
        /// Records the indexing depth at which a stochastic variable is defined or used.
        /// For a definition, only <c>definitionDepth</c> is updated. For a use, this method:
        /// increments <c>useCount</c>; lowers <c>minDepth</c>; and, for this depth, intersects the
        /// recorded containers with the current ones and keeps the smallest number of leading
        /// all-literal index brackets.
        /// </summary>
        /// <param name="expr">The variable expression being visited.</param>
        /// <param name="isDef">True when <paramref name="expr"/> is being defined.</param>
        private void RegisterDepth(IExpression expr, bool isDef)
        {
            if (Recognizer.IsBeingIndexed(context))
            {
                // an enclosing indexer expression will be registered instead
                return;
            }
            int depth = Recognizer.GetIndexingDepth(expr);
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
            // Skip non-variable references (e.g. an indexed argument reference) and deterministic variables.
            if (baseVar == null || !CodeRecognizer.IsStochastic(context, baseVar))
            {
                return;
            }
            ChannelInfo channel = context.InputAttributes.Get<ChannelInfo>(baseVar);
            if (channel?.IsMarginal == true)
            {
                return;
            }

            if (!depthInfos.TryGetValue(baseVar, out DepthInfo depthInfo))
            {
                depthInfo = new DepthInfo();
                depthInfos[baseVar] = depthInfo;
            }
            if (isDef)
            {
                depthInfo.definitionDepth = depth;
                return;
            }

            depthInfo.useCount++;
            depthInfo.minDepth = System.Math.Min(depthInfo.minDepth, depth);

            // number of leading index brackets made up entirely of literals
            int literalIndexingDepth = Recognizer.GetIndices(expr)
                .TakeWhile(bracket => bracket.All(index => index is ILiteralExpression))
                .Count();

            Containers current = new Containers(context);
            if (depthInfo.indexInfoOfDepth.TryGetValue(depth, out IndexInfo info))
            {
                info.containers = Containers.Intersect(info.containers, current);
                info.literalIndexingDepth = System.Math.Min(info.literalIndexingDepth, literalIndexingDepth);
            }
            else
            {
                depthInfo.indexInfoOfDepth[depth] = new IndexInfo
                {
                    containers = current,
                    literalIndexingDepth = literalIndexingDepth
                };
            }
        }
Exemple #18
0
        /// <summary>
        /// Associates a condition statement with a <c>GateBlock</c>, converts its branches, and then
        /// re-registers everything the block defined or used with the enclosing gate block, if there is one.
        /// The gate block is shared by conditions on the same stochastic lhs under the same enclosing
        /// condition bindings. Each branch is converted with the condition's binding, or its negation for
        /// the Else branch, pushed onto <c>conditionContext</c>.
        /// </summary>
        /// <param name="ics">The condition statement to analyse.</param>
        /// <returns><paramref name="ics"/> unchanged; results are recorded in output attributes and fields.</returns>
        protected override IStatement ConvertCondition(IConditionStatement ics)
        {
            context.SetPrimaryOutput(ics);
            ConvertExpression(ics.Condition);
            ConditionBinding binding = GateTransform.GetConditionBinding(ics.Condition, context, out IForStatement loop);
            IExpression caseValue = binding.rhs;

            if (!GateTransform.IsLiteralOrLoopVar(context, caseValue, out loop))
            {
                Error("If statement condition must compare to a literal or loop counter, was: " + ics.Condition);
                return(ics);
            }
            // Conditions on the same stochastic lhs share one gate block. Definitions must not be unified
            // across deterministic gate conditions, so those are keyed by the whole condition expression.
            IExpression gateBlockKey = CodeRecognizer.IsStochastic(context, binding.lhs)
                ? binding.lhs
                : binding.GetExpression();

            Set<ConditionBinding> bindings = ConditionBinding.Copy(conditionContext);
            if (!gateBlocks.TryGetValue(bindings, out Dictionary<IExpression, GateBlock> blockMap))
            {
                // first time seeing these bindings
                blockMap = new Dictionary<IExpression, GateBlock>();
                gateBlocks[bindings] = blockMap;
            }
            if (!blockMap.TryGetValue(gateBlockKey, out GateBlock gateBlock))
            {
                // first time seeing this lhs
                gateBlock = new GateBlock();
                blockMap[gateBlockKey] = gateBlock;
            }
            // A gate block's case values must be all literals or all loop counters.
            // These errors are reported, but the analysis continues.
            if (gateBlock.hasLoopCaseValue && loop == null)
            {
                Error("Cannot compare " + binding.lhs + " to a literal, since it was previously compared to a loop counter.  Put this test inside the loop.");
            }
            if (!gateBlock.hasLoopCaseValue && gateBlock.caseValues.Count > 0 && loop != null)
            {
                Error("Cannot compare " + binding.lhs + " to a loop counter, since it was previously compared to a literal.  Put the literal case inside the loop.");
            }
            gateBlock.caseValues.Add(caseValue);
            if (loop != null)
            {
                gateBlock.hasLoopCaseValue = true;
            }
            gateBlockContext.Add(gateBlock);
            context.OutputAttributes.Set(ics, gateBlock);

            // Push the condition (or its negation for the Else branch) while converting each branch.
            int startIndex = conditionContext.Count;
            conditionContext.Add(binding);
            ConvertBlock(ics.Then);
            if (ics.Else != null)
            {
                conditionContext.RemoveRange(startIndex, conditionContext.Count - startIndex);
                conditionContext.Add(binding.FlipCondition());
                ConvertBlock(ics.Else);
            }
            conditionContext.RemoveRange(startIndex, conditionContext.Count - startIndex);
            gateBlockContext.RemoveAt(gateBlockContext.Count - 1);
            // remove any uses that match a def
            //RemoveUsesOfDefs(gateBlock);
            if (gateBlockContext.Count > 0)
            {
                // All variables defined/used in the inner block must be processed by the outer block.
                // ProcessUse looks up the outer (now innermost) block from gateBlockContext itself.
                foreach (ExpressionWithBindings definition in gateBlock.variablesDefined.Values)
                {
                    ForwardToEnclosingBlock(definition, true);
                }
                foreach (List<ExpressionWithBindings> uses in gateBlock.variablesUsed.Values)
                {
                    foreach (ExpressionWithBindings use in uses)
                    {
                        ForwardToEnclosingBlock(use, false);
                    }
                }
            }
            return(ics);

            // Re-registers one recorded expression with the enclosing gate block. It is registered once per
            // binding list it carries (combined with the current condition context), or once with the
            // current context alone when it carries none.
            void ForwardToEnclosingBlock(ExpressionWithBindings recorded, bool isDefinition)
            {
                if (recorded.Bindings.Count > 0)
                {
                    foreach (ICollection<ConditionBinding> extraBindings in recorded.Bindings)
                    {
                        ProcessUse(recorded.Expression, isDefinition, Union(conditionContext, extraBindings));
                    }
                }
                else
                {
                    ProcessUse(recorded.Expression, isDefinition, conditionContext);
                }
            }
        }
        /// <summary>
        /// Records whether the definition <paramref name="expr"/> of <paramref name="targetVar"/> produces a point mass.
        /// A definition is point-mass when it is <c>Clone.VariablePoint</c>, or when it is a deterministic factor
        /// and either all of its arguments, or its "returned in all elements" argument, are point masses.
        /// <list type="bullet">
        /// <item>Point-mass definitions immediately mark the target with <c>ForwardPointMass</c>, and
        /// replicate/item-access factors are queued in <c>variablesDefinedPointMass</c> for later processing.</item>
        /// <item>Any other definition, except an array creation, marks the target as non-point-mass and revokes
        /// an earlier point-mass registration.</item>
        /// </list>
        /// </summary>
        /// <param name="expr">The right-hand side defining <paramref name="targetVar"/>.</param>
        /// <param name="targetVar">The variable being defined.</param>
        /// <param name="isLhs">Not read by this method.</param>
        protected void ProcessDefinition(IExpression expr, IVariableDeclaration targetVar, bool isLhs)
        {
            bool definesPointMass = false;
            if (expr is IMethodInvokeExpression methodCall)
            {
                definesPointMass = IsPointMassFactorCall(methodCall);
                if (definesPointMass)
                {
                    MarkPointMassDefinition(methodCall);
                }
            }
            if (!definesPointMass && !(expr is IArrayCreateExpression))
            {
                variablesDefinedNonPointMass.Add(targetVar);
                if (variablesDefinedPointMass.ContainsKey(targetVar))
                {
                    variablesDefinedPointMass.Remove(targetVar);
                    context.OutputAttributes.Remove<ForwardPointMass>(targetVar);
                }
            }

            // True when the call is Clone.VariablePoint, or a deterministic factor whose relevant arguments are point masses.
            bool IsPointMassFactorCall(IMethodInvokeExpression invoke)
            {
                // TODO: consider using a method attribute for this
                if (Recognizer.IsStaticGenericMethod(invoke, new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint)))
                {
                    return true;
                }
                FactorManager.FactorInfo factorInfo = CodeRecognizer.GetFactorInfo(context, invoke);
                if (!factorInfo.IsDeterministicFactor)
                {
                    return false;
                }
                int allElementsIndex = factorInfo.ReturnedInAllElementsParameterIndex;
                bool allElementsArgIsPointMass = allElementsIndex != -1 && IsPointMassArgument(invoke.Arguments[allElementsIndex]);
                return allElementsArgIsPointMass || invoke.Arguments.All(IsPointMassArgument);
            }

            // Flags targetVar as ForwardPointMass and queues replicate/item-access factors for later processing.
            void MarkPointMassDefinition(IMethodInvokeExpression invoke)
            {
                // do this immediately so all uses are updated
                if (!context.InputAttributes.Has<ForwardPointMass>(targetVar))
                {
                    context.OutputAttributes.Set(targetVar, new ForwardPointMass());
                }
                // the rest is done later
                if (!variablesDefinedPointMass.TryGetValue(targetVar, out List<IMethodInvokeExpression> pending))
                {
                    pending = new List<IMethodInvokeExpression>();
                    variablesDefinedPointMass.Add(targetVar, pending);
                }
                // this code needs to be synchronized with MessageTransform.ConvertMethodInvoke
                bool isItemAccessFactor =
                    Recognizer.IsStaticGenericMethod(invoke, new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate)) ||
                    Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems)) ||
                    Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems)) ||
                    Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems)) ||
                    Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromJagged)) ||
                    Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged)) ||
                    Recognizer.IsStaticGenericMethod(invoke, new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<IReadOnlyList<int>>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged));
                if (isItemAccessFactor)
                {
                    pending.Add(invoke);
                }
            }

            // Deterministic and out-arguments count as point masses; stochastic ones only if already ForwardPointMass.
            bool IsPointMassArgument(IExpression argument)
            {
                if (!CodeRecognizer.IsStochastic(context, argument) || argument is IAddressOutExpression)
                {
                    return true;
                }
                IVariableDeclaration argumentVar = Recognizer.GetVariableDeclaration(argument);
                return argumentVar != null && context.InputAttributes.Has<ForwardPointMass>(argumentVar);
            }
        }
Exemple #20
0
        /// <summary>
        /// Records <paramref name="expr"/> as a definition or a use of its underlying variable in the innermost
        /// open gate block, i.e. in <c>variablesDefined</c> or <c>variablesUsed</c> of the last entry of
        /// <c>gateBlockContext</c>.  Nothing is recorded when the expression has no variable declaration, the
        /// variable is not stochastic, no gate block is open, or the variable is local to the innermost gate block.
        /// </summary>
        /// <param name="expr">The expression being written or read.</param>
        /// <param name="isDef">True when <paramref name="expr"/> is being defined, false when it is being used.</param>
        /// <param name="conditionContext">
        /// Enclosing condition bindings.  Only bindings whose lhs is not stochastic and whose lhs/rhs contain no
        /// variables local to the gate block are attached to the recorded expression.
        /// </param>
        protected void ProcessUse(IExpression expr, bool isDef, List <ConditionBinding> conditionContext)
        {
            IVariableDeclaration variable = Recognizer.GetVariableDeclaration(expr);
            bool irrelevant = (variable == null) ||
                              !CodeRecognizer.IsStochastic(context, variable) ||
                              (gateBlockContext.Count == 0);
            if (irrelevant)
            {
                return;
            }
            GateBlock innermost = gateBlockContext[gateBlockContext.Count - 1];
            if (context.InputAttributes.Get <GateBlock>(variable) == innermost)
            {
                // the variable is declared inside this gate block, so it is not tracked as an input/output of it
                return;
            }

            IExpression localised = ReplaceLocalIndices(innermost, expr);
            List <ConditionBinding> bindings = conditionContext
                                               .Where(binding => !CodeRecognizer.IsStochastic(context, binding.lhs) &&
                                                      !ContainsLocalVars(innermost, binding.lhs) &&
                                                      !ContainsLocalVars(innermost, binding.rhs))
                                               .ToList();
            ExpressionWithBindings current = new ExpressionWithBindings(localised, bindings);

            if (isDef)
            {
                // every definition of one variable must collapse onto a single common parent expression
                ExpressionWithBindings merged = current;
                if (innermost.variablesDefined.TryGetValue(variable, out ExpressionWithBindings earlier))
                {
                    merged = GetCommonParent(current, earlier);
                }
                innermost.variablesDefined[variable] = merged;
                return;
            }

            if (!innermost.variablesUsed.TryGetValue(variable, out List <ExpressionWithBindings> recorded))
            {
                innermost.variablesUsed[variable] = new List <ExpressionWithBindings> { current };
                return;
            }

            // Fold every recorded use that could overlap the new one into their common parent.
            // Folding can widen 'current', which may create new overlaps, so repeat until a pass folds nothing.
            List <ExpressionWithBindings> disjoint = new List <ExpressionWithBindings>();
            bool foldedAny;
            do
            {
                disjoint.Clear();
                foreach (ExpressionWithBindings other in recorded)
                {
                    if (CouldOverlap(current, other, ignoreBindings: GateTransform.DeterministicEnterExit))
                    {
                        current = GetCommonParent(current, other);
                    }
                    else
                    {
                        disjoint.Add(other);
                    }
                }
                foldedAny = (disjoint.Count != recorded.Count);
                if (foldedAny)
                {
                    recorded.Clear();
                    recorded.AddRange(disjoint);
                }
            }
            while (foldedAny);
            recorded.Add(current);
            innermost.variablesUsed[variable] = recorded;
        }
Exemple #21
0
        /// <summary>
        /// Determines the message types flowing out of a factor in one direction from the types of the messages
        /// (or constant values) flowing into it, and records the resulting message variables.
        /// </summary>
        /// <param name="factor">The factor expression; its parameters and arguments come from <c>GetNodeInfo</c>.</param>
        /// <param name="direction">
        /// <c>Forwards</c> processes the factor's return/out arguments (children); <c>Backwards</c> processes the
        /// remaining arguments (parents).
        /// </param>
        /// <remarks>
        /// Forwards: the type of each stochastic child's message is taken from the return type of the operator's
        /// "Init" method (falling back to methodSuffix + "Init", then to methodSuffix itself), converted with
        /// <c>MessageTransform.GetDistributionType</c>. A call to that method is stored in <c>messageInitExprs</c>
        /// and the message variable is stored in <c>fwdMessageVars</c>.
        /// Backwards: no operator lookup is done; each stochastic parent gets a backward message built by
        /// <c>CreateBackwardMessageFromForward</c>.
        /// The method returns early if a required incoming message cannot be found or created.
        /// </remarks>
        protected void ProcessFactor(IExpression factor, MessageDirection direction)
        {
            NodeInfo info = GetNodeInfo(factor);
            // fill in argumentTypes: the type of the incoming message (or constant value) for each factor parameter
            Dictionary <string, Type>        argumentTypes = new Dictionary <string, Type>();
            Dictionary <string, IExpression> arguments     = new Dictionary <string, IExpression>();

            for (int i = 0; i < info.info.ParameterNames.Count; i++)
            {
                string      parameterName = info.info.ParameterNames[i];
                bool        isChild       = info.isReturnOrOut[i];
                IExpression arg           = info.arguments[i];
                bool        isConstant    = !CodeRecognizer.IsStochastic(context, arg);
                if (isConstant)
                {
                    // constants are passed to the operator as-is
                    arguments[parameterName] = arg;
                    Type inwardType = arg.GetExpressionType();
                    argumentTypes[parameterName] = inwardType;
                }
                else if (!isChild)
                {
                    // a stochastic parent contributes its forward message
                    IExpression msgExpr = GetMessageExpression(arg, fwdMessageVars);
                    if (msgExpr == null)
                    {
                        return;
                    }
                    arguments[parameterName] = msgExpr;
                    Type inwardType = msgExpr.GetExpressionType();
                    if (inwardType == null)
                    {
                        Error("inferred an incorrect message type for " + arg);
                        return;
                    }
                    argumentTypes[parameterName] = inwardType;
                }
                else if (direction == MessageDirection.Backwards)
                {
                    // a stochastic child contributes its backward message, created from the forward one if needed
                    IExpression msgExpr = GetMessageExpression(arg, bckMessageVars);
                    if (msgExpr == null)
                    {
                        //Console.WriteLine("creating backward message for "+arg);
                        CreateBackwardMessageFromForward(arg, null);
                        msgExpr = GetMessageExpression(arg, bckMessageVars);
                        if (msgExpr == null)
                        {
                            return;
                        }
                    }
                    arguments[parameterName] = msgExpr;
                    Type inwardType = msgExpr.GetExpressionType();
                    if (inwardType == null)
                    {
                        Error("inferred an incorrect message type for " + arg);
                        return;
                    }
                    argumentTypes[parameterName] = inwardType;
                }
            }
            IAlgorithm alg     = algorithm;
            Algorithm  algAttr = context.InputAttributes.Get <Algorithm>(info.imie);

            if (algAttr != null)
            {
                alg = algAttr.algorithm;
            }
            List <ICompilerAttribute> factorAttributes = context.InputAttributes.GetAll <ICompilerAttribute>(info.imie);
            string methodSuffix = alg.GetOperatorMethodSuffix(factorAttributes);

            // infer types of the outgoing messages: children when going forwards, parents when going backwards
            for (int i = 0; i < info.info.ParameterNames.Count; i++)
            {
                string parameterName = info.info.ParameterNames[i];
                bool   isChild       = info.isReturnOrOut[i];
                if (isChild != (direction == MessageDirection.Forwards))
                {
                    continue;
                }
                IExpression target     = info.arguments[i];
                bool        isConstant = !CodeRecognizer.IsStochastic(context, target);
                if (isConstant)
                {
                    continue;
                }
                IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);
                if (ivd == null)
                {
                    continue;
                }
                Type           targetType = null;
                MessageFcnInfo fcninfo    = null;
                // fcninfo is only ever set in the forwards direction, so every error path below is forwards-only
                if (direction == MessageDirection.Forwards)
                {
                    try
                    {
                        fcninfo = GetMessageFcnInfo(info.info, "Init", parameterName, argumentTypes);
                    }
                    catch (Exception)
                    {
                        try
                        {
                            fcninfo = GetMessageFcnInfo(info.info, methodSuffix + "Init", parameterName, argumentTypes);
                        }
                        catch (Exception ex)
                        {
                            try
                            {
                                // last resort: the ordinary operator method, provided it needs no result/resultIndex
                                fcninfo = GetMessageFcnInfo(info.info, methodSuffix, parameterName, argumentTypes);
                                if (fcninfo.PassResult)
                                {
                                    throw new MissingMethodException(StringUtil.MethodFullNameToString(fcninfo.Method) +
                                                                     " is not suitable for initialization since it takes a result parameter.  Please provide a separate Init method.");
                                }
                                if (fcninfo.PassResultIndex)
                                {
                                    throw new MissingMethodException(StringUtil.MethodFullNameToString(fcninfo.Method) +
                                                                     " is not suitable for initialization since it takes a resultIndex parameter.  Please provide a separate Init method.");
                                }
                            }
                            catch (Exception ex2)
                            {
                                Error("could not determine " + direction + " message type of " + ivd.Name + ": " + ex.Message, ex2);
                                continue;
                            }
                        }
                    }
                    if (fcninfo != null)
                    {
                        targetType = fcninfo.Method.ReturnType;
                        if (targetType.IsGenericParameter)
                        {
                            Error("could not determine " + direction + " message type of " + ivd.Name + " in " + StringUtil.MethodFullNameToString(fcninfo.Method));
                            continue;
                        }
                        // NOTE(review): vi is unused (its only use is in the commented-out line further down); the call is
                        // kept in case GetVariableInformation has side effects — confirm before removing.
                        VariableInformation vi = VariableInformation.GetVariableInformation(context, ivd);
                        try
                        {
                            targetType = MessageTransform.GetDistributionType(ivd.VariableType.DotNetType, target.GetExpressionType(), targetType, true);
                        }
                        catch (Exception ex)
                        {
                            Error("could not determine " + direction + " message type of " + ivd.Name, ex);
                            continue;
                        }
                    }
                }
                Dictionary <IVariableDeclaration, IVariableDeclaration> messageVars = (direction == MessageDirection.Forwards) ? fwdMessageVars : bckMessageVars;
                if (fcninfo != null)
                {
                    string name = ivd.Name + (direction == MessageDirection.Forwards ? "_F" : "_B");
                    IVariableDeclaration msgVar;
                    if (!messageVars.TryGetValue(ivd, out msgVar))
                    {
                        msgVar = Builder.VarDecl(name, targetType);
                    }
                    // construct the init expression: a call to the chosen operator method with the collected arguments
                    List <IExpression> args       = new List <IExpression>();
                    ParameterInfo[]    parameters = fcninfo.Method.GetParameters();
                    foreach (ParameterInfo parameter in parameters)
                    {
                        if (IsFactoryType(parameter.ParameterType))
                        {
                            IVariableDeclaration factoryVar = GetFactoryVariable(parameter.ParameterType);
                            args.Add(Builder.VarRefExpr(factoryVar));
                        }
                        else
                        {
                            FactorEdge factorEdge          = fcninfo.factorEdgeOfParameter[parameter.Name];
                            string     factorParameterName = factorEdge.ParameterName;
                            if (!arguments.TryGetValue(factorParameterName, out IExpression arg))
                            {
                                Error(StringUtil.MethodFullNameToString(fcninfo.Method) + " is not suitable for initialization since it requires '" + parameter.Name +
                                      "'.  Please provide a separate Init method.");
                                fcninfo = null;
                                break;
                            }
                            args.Add(arg);
                        }
                    }
                    if (fcninfo != null)
                    {
                        IMethodInvokeExpression imie = Builder.StaticMethod(fcninfo.Method, args.ToArray());
                        //IExpression initExpr = MessageTransform.GetDistributionArrayCreateExpression(ivd.VariableType.DotNetType, target.GetExpressionType(), imie, vi);
                        IExpression initExpr = imie;
                        KeyValuePair <IVariableDeclaration, IExpression> key = new KeyValuePair <IVariableDeclaration, IExpression>(msgVar, factor);
                        messageInitExprs[key] = initExpr;
                        messageVars[ivd] = msgVar;
                    }
                }
                if (fcninfo == null)
                {
                    if (direction == MessageDirection.Forwards)
                    {
                        continue;
                    }
                    //Console.WriteLine("creating backward message for "+target);
                    CreateBackwardMessageFromForward(target, factor);
                }
                IExpression msgExpr = GetMessageExpression(target, messageVars);
                if (msgExpr == null)
                {
                    // no message could be found or created for this target (mirrors the null check in the first loop);
                    // previously this dereferenced null
                    continue;
                }
                arguments[parameterName]     = msgExpr;
                argumentTypes[parameterName] = msgExpr.GetExpressionType();
            }
        }
Exemple #22
0
#pragma warning restore 162
#endif

        /// <summary>
        /// Builds a dependency graph over <c>factorExprs</c> with an edge from each factor that mutates
        /// (returns/outputs) a stochastic variable to each factor that reads that variable. It then calls
        /// <c>ProcessFactor</c> on every factor forwards in depth-first finish order, and backwards in the reverse order.
        /// </summary>
        private void PostProcess()
        {
            // create a dependency graph between MethodInvokes
            Dictionary <IVariableDeclaration, List <NodeIndex> > mutationsOfVariable = new Dictionary <IVariableDeclaration, List <NodeIndex> >();
            IndexedGraph g = new IndexedGraph(factorExprs.Count);

            foreach (NodeIndex node in g.Nodes)
            {
                IExpression factor = factorExprs[node];
                NodeInfo    info   = GetNodeInfo(factor);
                for (int i = 0; i < info.arguments.Count; i++)
                {
                    if (info.isReturnOrOut[i])
                    {
                        // this is a mutation.  add to mutationsOfVariable.
                        IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(info.arguments[i]);
                        if (ivd != null && CodeRecognizer.IsStochastic(context, ivd))
                        {
                            if (!mutationsOfVariable.TryGetValue(ivd, out List <NodeIndex> nodes))
                            {
                                nodes = new List <NodeIndex>();
                                mutationsOfVariable[ivd] = nodes;
                            }
                            nodes.Add(node);
                        }
                    }
                }
            }
            foreach (NodeIndex node in g.Nodes)
            {
                IExpression factor = factorExprs[node];
                NodeInfo    info   = GetNodeInfo(factor);
                for (int i = 0; i < info.arguments.Count; i++)
                {
                    if (!info.isReturnOrOut[i])
                    {
                        // not a mutation.  create a dependency on all mutations.
                        IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(info.arguments[i]);
                        // A stochastic variable may be read without being mutated by any factor in factorExprs
                        // (presumably when it is defined elsewhere); it then has no dependencies here. The previous
                        // indexer lookup threw KeyNotFoundException in that case.
                        if (ivd != null && CodeRecognizer.IsStochastic(context, ivd) &&
                            mutationsOfVariable.TryGetValue(ivd, out List <NodeIndex> sources))
                        {
                            foreach (NodeIndex source in sources)
                            {
                                g.AddEdge(source, node);
                            }
                        }
                    }
                }
            }
            List <NodeIndex>             topo_nodes = new List <NodeIndex>();
            DepthFirstSearch <NodeIndex> dfs        = new DepthFirstSearch <NodeIndex>(g.SourcesOf, g);

            // a node finishes only after all of its sources (the factors it depends on) have finished
            dfs.FinishNode += delegate(NodeIndex node)
            {
                IExpression factor = factorExprs[node];
                ProcessFactor(factor, MessageDirection.Forwards);
                topo_nodes.Add(node);
            };
            // process nodes forward
            dfs.SearchFrom(g.Nodes);
            // process nodes backward
            for (int i = topo_nodes.Count - 1; i >= 0; i--)
            {
                NodeIndex   node   = topo_nodes[i];
                IExpression factor = factorExprs[node];
                ProcessFactor(factor, MessageDirection.Backwards);
            }
        }
Exemple #23
0
        /// <summary>
        /// Replicates a reference to a stochastic variable across the repeat blocks returned by
        /// <c>RepeatContext.GetReferenceRepeatContext</c> (presumably those that enclose the reference but not the
        /// declaration). For each such repeat block in turn, a new variable is defined as
        /// <c>PowerPlate.Enter(expr, repeatCount)</c>. That definition is inserted before the matching ancestor
        /// statement, and the new variable becomes the input for the next repeat block.
        /// </summary>
        /// <param name="expr">The expression to convert.</param>
        /// <returns>
        /// A reference to the last replicated variable. The input <paramref name="expr"/> is returned unchanged when
        /// it has no variable declaration, the variable is deterministic, is in <c>cutVariables</c>, has no
        /// <c>RepeatContext</c> (an error is reported), or no additional repeat blocks apply.
        /// </returns>
        protected IExpression ConvertWithReplication(IExpression expr)
        {
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);

            // Check if this is an index local variable
            if (baseVar == null)
            {
                return(expr);
            }
            // Check if the variable is stochastic
            if (!CodeRecognizer.IsStochastic(context, baseVar))
            {
                return(expr);
            }
            if (cutVariables.Contains(baseVar))
            {
                return(expr);
            }

            // Get the repeat context for this variable
            RepeatContext lc = context.InputAttributes.Get <RepeatContext>(baseVar);

            if (lc == null)
            {
                Error("Repeat context not found for '" + baseVar.Name + "'.");
                return(expr);
            }

            // Get the reference loop context for this expression
            var rlc = lc.GetReferenceRepeatContext(context);

            // If the reference is in the same loop context as the declaration, do nothing.
            if (rlc.repeats.Count == 0)
            {
                return(expr);
            }

            // Determine if this expression is being defined (is on the LHS of an assignment)
            bool isDef = Recognizer.IsBeingMutated(context, expr);

            for (int currentRepeat = 0; currentRepeat < rlc.repeatCounts.Count; currentRepeat++)
            {
                IExpression      repeatCount = rlc.repeatCounts[currentRepeat];
                IRepeatStatement repeat      = rlc.repeats[currentRepeat];

                // must replicate across this loop.
                if (isDef)
                {
                    Error("Cannot re-define a variable in a repeat block.");
                    continue;
                }
                // are we replicating the argument of Gate.Cases?
                IMethodInvokeExpression imie = context.FindAncestor <IMethodInvokeExpression>();
                if (imie != null)
                {
                    if (Recognizer.IsStaticMethod(imie, typeof(Gate), "Cases"))
                    {
                        Error("'if(" + expr + ")' should be placed outside 'repeat(" + repeatCount + ")'");
                    }
                    if (Recognizer.IsStaticMethod(imie, typeof(Gate), "CasesInt"))
                    {
                        Error("'case(" + expr + ")' or 'switch(" + expr + ")' should be placed outside 'repeat(" + repeatCount + ")'");
                    }
                }

                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                IList <IStatement>  stmts   = Builder.StmtCollection();

                List <IList <IExpression> > inds   = Recognizer.GetIndices(expr);
                IVariableDeclaration        repVar = varInfo.DeriveIndexedVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rpt"), inds);
                if (!context.InputAttributes.Has <DerivedVariable>(repVar))
                {
                    context.OutputAttributes.Set(repVar, new DerivedVariable());
                }
                if (context.InputAttributes.Has <ChannelInfo>(baseVar))
                {
                    VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
                    ChannelInfo         ci         = ChannelInfo.UseChannel(repVarInfo);
                    ci.decl = repVar;
                    context.OutputAttributes.Set(repVar, ci);
                }
                // set the RepeatContext of repVar to include all repeats up to this one (so that it doesn't get Entered again)
                List <IRepeatStatement> repeats = new List <IRepeatStatement>(lc.repeats);
                for (int i = 0; i <= currentRepeat; i++)
                {
                    repeats.Add(rlc.repeats[i]);
                }
                context.OutputAttributes.Remove <RepeatContext>(repVar);
                context.OutputAttributes.Set(repVar, new RepeatContext(repeats));

                // Create replicate factor
                Type returnType = Builder.ToType(repVar.VariableType);
                IMethodInvokeExpression powerPlateMethod = Builder.StaticGenericMethod(
                    new Func <PlaceHolder, double, PlaceHolder>(PowerPlate.Enter),
                    new Type[] { returnType }, expr, repeatCount);

                IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), powerPlateMethod);
                // Copy attributes across from variable to replication expression
                context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, powerPlateMethod);
                context.InputAttributes.CopyObjectAttributesTo <DivideMessages>(baseVar, context.OutputAttributes, powerPlateMethod);
                context.InputAttributes.CopyObjectAttributesTo <GivePriorityTo>(baseVar, context.OutputAttributes, powerPlateMethod);
                stmts.Add(Builder.ExprStatement(assignExpression));

                // add any containers missing from context.
                // (the Containers attribute of baseVar was previously fetched here but never read; the current
                // context is what determines where the new statements go)
                Containers containers = new Containers(context);
                // remove inner repeats
                for (int i = currentRepeat + 1; i < rlc.repeatCounts.Count; i++)
                {
                    containers = containers.RemoveOneRepeat(rlc.repeats[i]);
                }
                context.OutputAttributes.Set(repVar, containers);
                containers = containers.RemoveOneRepeat(repeat);
                containers = Containers.RemoveUnusedLoops(containers, context, powerPlateMethod);
                if (context.InputAttributes.Has <DoNotSendEvidence>(baseVar))
                {
                    containers = Containers.RemoveStochasticConditionals(containers, context);
                }
                int        ancIndex = containers.GetMatchingAncestorIndex(context);
                Containers missing  = containers.GetContainersNotInContext(context, ancIndex);
                stmts = Containers.WrapWithContainers(stmts, missing.inputs);
                // must convert the output since it may contain 'if' conditions
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
                // the replicated variable is the input to the next (inner) repeat
                baseVar = repVar;
                expr    = Builder.VarRefExpr(repVar);
            }

            return(expr);
        }
        // copied from ReplicationTransform
        /// <summary>
        /// Tests whether <paramref name="expr"/> refers to a declared variable that is stochastic in the current context.
        /// </summary>
        /// <param name="expr">The expression to inspect.</param>
        /// <returns>
        /// False when <paramref name="expr"/> has no variable declaration; otherwise whether that variable is stochastic.
        /// </returns>
        protected bool IsStochasticVariableReference(IExpression expr)
        {
            var declaration = Recognizer.GetVariableDeclaration(expr);
            if (declaration == null)
            {
                return false;
            }
            return CodeRecognizer.IsStochastic(context, declaration);
        }
Exemple #25
0
 protected override IExpression ConvertAssign(IAssignExpression iae)
 {
     foreach (IStatement stmt in context.FindAncestors <IStatement>())
     {
         // an initializer statement may perform a copy, but it is not valid to replace the lhs
         // in that case.
         if (context.InputAttributes.Has <Initializer>(stmt))
         {
             return(iae);
         }
     }
     // Look for assignments where the right hand side is a SetTo call
     if (iae.Expression is IMethodInvokeExpression imie)
     {
         bool isCopy                          = Recognizer.IsStaticGenericMethod(imie, new Func <PlaceHolder, PlaceHolder>(Clone.Copy));
         bool isSetTo                         = Recognizer.IsStaticGenericMethod(imie, typeof(ArrayHelper), "SetTo");
         bool isSetAllElementsTo              = Recognizer.IsStaticGenericMethod(imie, typeof(ArrayHelper), "SetAllElementsTo");
         bool isGetItemsPoint                 = Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsPointOp <>), "ItemsAverageConditional");
         bool isGetJaggedItemsPoint           = Recognizer.IsStaticGenericMethod(imie, typeof(GetJaggedItemsPointOp <>), "ItemsAverageConditional");
         bool isGetDeepJaggedItemsPoint       = Recognizer.IsStaticGenericMethod(imie, typeof(GetDeepJaggedItemsPointOp <>), "ItemsAverageConditional");
         bool isGetItemsFromJaggedPoint       = Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsFromJaggedPointOp <>), "ItemsAverageConditional");
         bool isGetItemsFromDeepJaggedPoint   = Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsFromDeepJaggedPointOp <>), "ItemsAverageConditional");
         bool isGetJaggedItemsFromJaggedPoint = Recognizer.IsStaticGenericMethod(imie, typeof(GetJaggedItemsFromJaggedPointOp <>), "ItemsAverageConditional");
         if (isCopy || isSetTo || isSetAllElementsTo || isGetItemsPoint ||
             isGetJaggedItemsPoint || isGetDeepJaggedItemsPoint || isGetJaggedItemsFromJaggedPoint ||
             isGetItemsFromJaggedPoint || isGetItemsFromDeepJaggedPoint)
         {
             IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(iae.Target);
             // Find the condition context
             var ifs         = context.FindAncestors <IConditionStatement>();
             var condContext = new List <IConditionStatement>();
             foreach (var ifSt in ifs)
             {
                 if (!CodeRecognizer.IsStochastic(context, ifSt.Condition))
                 {
                     condContext.Add(ifSt);
                 }
             }
             var         copyAttr = context.InputAttributes.GetOrCreate <CopyOfAttribute>(ivd, () => new CopyOfAttribute());
             IExpression rhs;
             if (isSetTo || isSetAllElementsTo)
             {
                 // Mark as copy of the second argument
                 rhs = imie.Arguments[1];
             }
             else
             {
                 // Mark as copy of the first argument
                 rhs = imie.Arguments[0];
             }
             InitialiseTo init = context.InputAttributes.Get <InitialiseTo>(ivd);
             if (init != null)
             {
                 IVariableDeclaration ivdRhs  = Recognizer.GetVariableDeclaration(rhs);
                 InitialiseTo         initRhs = (ivdRhs == null) ? null : context.InputAttributes.Get <InitialiseTo>(ivdRhs);
                 if (initRhs == null || !initRhs.initialMessagesExpression.Equals(init.initialMessagesExpression))
                 {
                     // Do not replace a variable with a unique initialiser
                     return(iae);
                 }
             }
             var initBack = context.InputAttributes.Get <InitialiseBackwardTo>(ivd);
             if (initBack != null && !(initBack.initialMessagesExpression is IArrayCreateExpression))
             {
                 IVariableDeclaration ivdRhs  = Recognizer.GetVariableDeclaration(rhs);
                 InitialiseBackwardTo initRhs = (ivdRhs == null) ? null : context.InputAttributes.Get <InitialiseBackwardTo>(ivdRhs);
                 if (initRhs == null || !initRhs.initialMessagesExpression.Equals(init.initialMessagesExpression))
                 {
                     // Do not replace a variable with a unique initialiser
                     return(iae);
                 }
             }
             if (isCopy || isSetTo)
             {
                 RemoveMatchingSuffixes(iae.Target, rhs, condContext, out IExpression lhsPrefix, out IExpression rhsPrefix);
                 copyAttr.copyMap[lhsPrefix] = new CopyOfAttribute.CopyContext {
                     Expression = rhsPrefix, ConditionContext = condContext
                 };
             }
             else if (isSetAllElementsTo)
             {
                 copyAttr.copiedInEveryElementMap[iae.Target] = new CopyOfAttribute.CopyContext {
                     Expression = rhs, ConditionContext = condContext
                 };
             }
             else if (isGetItemsPoint || isGetJaggedItemsPoint || isGetDeepJaggedItemsPoint || isGetItemsFromJaggedPoint || isGetItemsFromDeepJaggedPoint || isGetJaggedItemsFromJaggedPoint)
             {
                 var target     = ((IArrayIndexerExpression)iae.Target).Target;
                 int inputDepth = imie.Arguments.Count - 3;
                 List <IExpression> indexExprs = new List <IExpression>();
                 for (int i = 0; i < inputDepth; i++)
                 {
                     indexExprs.Add(imie.Arguments[1 + i]);
                 }
                 int outputDepth;
                 if (isGetDeepJaggedItemsPoint)
                 {
                     outputDepth = 3;
                 }
                 else if (isGetJaggedItemsPoint || isGetJaggedItemsFromJaggedPoint)
                 {
                     outputDepth = 2;
                 }
                 else
                 {
                     outputDepth = 1;
                 }
                 copyAttr.copyAtIndexMap[target] = new CopyOfAttribute.CopyContext2
                 {
                     Depth             = outputDepth,
                     ConditionContext  = condContext,
                     ExpressionAtIndex = (lhsIndices) =>
                     {
                         return(Builder.JaggedArrayIndex(rhs, indexExprs.ListSelect(indexExpr =>
                                                                                    new[] { Builder.JaggedArrayIndex(indexExpr, lhsIndices) })));
                     }
                 };
             }
             else
             {
                 throw new NotImplementedException();
             }
         }
     }
     return(iae);
 }
        /// <summary>
        /// Collects an <see cref="IndexInfo"/> for array indexer expressions whose base variable is stochastic
        /// and whose indices are not all loop variables. Repeated occurrences of an already-recorded expression
        /// update its containers, condition bindings, occurrence count and assigned-to flag.
        /// </summary>
        /// <param name="iaie">The array indexer expression being visited.</param>
        /// <returns>Always <paramref name="iaie"/> itself; this override records information rather than rewriting.</returns>
        protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
        {
            bool isDef = Recognizer.IsBeingMutated(context, iaie);

            if (isDef)
            {
                // Do not clone the lhs of an array create assignment.
                // The null-conditional guards against FindAncestor finding no enclosing assignment
                // (NOTE(review): presumably possible when the mutation is e.g. a ref/out argument — confirm).
                IAssignExpression assignExpr = context.FindAncestor <IAssignExpression>();
                if (assignExpr?.Expression is IArrayCreateExpression)
                {
                    return(iaie);
                }
            }
            // Visit sub-expressions. The value returned by the base call is not used: this method
            // returns iaie on every path.
            base.ConvertArrayIndexer(iaie);
            IndexInfo info;

            // TODO: Instead of storing an IndexInfo for each distinct expression, we should try to unify expressions, as in GateAnalysisTransform.
            // For example, we could unify a[0,i] and a[0,0] and use the same clone array for both.
            if (indexInfoOf.TryGetValue(iaie, out info))
            {
                // Already recorded: merge this occurrence into the existing info.
                Containers containers = new Containers(context);
                // info.bindings holds one binding list per occurrence, but only while every occurrence
                // so far has had bindings. The first occurrence without bindings clears the list, and
                // it then stays empty (Count == 0 skips this block from then on).
                if (info.bindings.Count > 0)
                {
                    List <ConditionBinding> bindings = GetBindings(context, containers.inputs);
                    if (bindings.Count == 0)
                    {
                        info.bindings.Clear();
                    }
                    else
                    {
                        info.bindings.Add(bindings);
                    }
                }
                // Keep only the containers shared by every occurrence.
                info.containers = Containers.Intersect(info.containers, containers);
                info.count++;
                if (isDef)
                {
                    info.IsAssignedTo = true;
                }
                return(iaie);
            }
            CheckIndicesAreNotStochastic(iaie.Indices);
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(iaie);

            // If not an indexed variable reference, skip it (e.g. an indexed argument reference)
            if (baseVar == null)
            {
                return(iaie);
            }
            // If the variable is not stochastic, skip it
            if (!CodeRecognizer.IsStochastic(context, baseVar))
            {
                return(iaie);
            }
            // If the indices (across every bracket) are all loop variables, skip it
            var  indices        = Recognizer.GetIndices(iaie);
            bool allLoopIndices = indices.All(bracket => bracket.All(indexExpr =>
                indexExpr is IVariableReferenceExpression ivre && Recognizer.GetLoopForVariable(context, ivre) != null));

            if (allLoopIndices)
            {
                return(iaie);
            }

            // First occurrence: record where it appears and under which condition bindings.
            info            = new IndexInfo();
            info.containers = new Containers(context);
            List <ConditionBinding> bindings2 = GetBindings(context, info.containers.inputs);

            if (bindings2.Count > 0)
            {
                info.bindings.Add(bindings2);
            }
            info.count        = 1;
            info.IsAssignedTo = isDef;
            indexInfoOf[iaie] = info;
            return(iaie);
        }