Exemple #1
0
        /// <summary>
        /// Handles an array creation that is the right-hand side of an assignment to a variable carrying
        /// <c>VariableToChannelInformation</c>. At the last level of indexing, the pending UsesEqualDef
        /// statements (if any) are wrapped in the missing containers and inserted into the output, at most once.
        /// </summary>
        /// <param name="iace">The array creation expression being converted.</param>
        /// <returns>The expression <paramref name="iace"/>, always unchanged.</returns>
        protected override IExpression ConvertArrayCreate(IArrayCreateExpression iace)
        {
            IAssignExpression assignment = context.FindAncestor<IAssignExpression>();

            // Only an array creation that is itself the assigned value is of interest.
            if (assignment == null || assignment.Expression != iace)
            {
                return iace;
            }

            IVariableDeclaration targetDecl = Recognizer.GetVariableDeclaration(assignment.Target);
            VariableToChannelInformation channelInfo = context.InputAttributes.Get<VariableToChannelInformation>(targetDecl);
            if (channelInfo == null)
            {
                // not a stochastic variable
                return iace;
            }

            // Anything other than the last level of indexing, or a variable without
            // pending UsesEqualDef statements, needs no further work.
            bool deeperLevelsRemain = iace.Type is IArrayType;
            if (deeperLevelsRemain || channelInfo.usesEqualDefsStatements == null)
            {
                return iace;
            }

            if (channelInfo.IsUsesEqualDefsStatementInserted)
            {
                // The statements were already inserted; a repeated allocation is ignored.
                //Error("Duplicate array allocation.");
                return iace;
            }

            // Insert the UsesEqualDef statements after the array is fully allocated.
            // Note the array elements will not have been defined yet.
            LoopContext declLoopContext = context.InputAttributes.Get<LoopContext>(targetDecl);
            RefLoopContext refLoopContext = declLoopContext.GetReferenceLoopContext(context);

            // IMPORTANT TODO: add this statement at the right level!
            // The anchor is the enclosing statement unless reference loops exist, in which case it is the first of them.
            IStatement anchor = context.FindAncestor<IStatement>();
            if (refLoopContext.loops.Count > 0)
            {
                anchor = refLoopContext.loops[0];
            }

            int anchorIndex = context.GetAncestorIndex(anchor);
            Containers declContainers = context.InputAttributes.Get<Containers>(targetDecl);
            Containers absentContainers = declContainers.GetContainersNotInContext(context, anchorIndex);
            channelInfo.usesEqualDefsStatements = Containers.WrapWithContainers(channelInfo.usesEqualDefsStatements, absentContainers.outputs);
            context.AddStatementsAfter(anchor, channelInfo.usesEqualDefsStatements);
            channelInfo.IsUsesEqualDefsStatementInserted = true;
            return iace;
        }
Exemple #2
0
        /// <summary>
        /// Converts a method invocation. Infer calls are handed to ConvertInfer. The expressions behind
        /// any 'out' arguments are added to targetsOfCurrentAssignment. A call to Factor.JaggedSubarray whose
        /// array argument is a plain variable reference that lacks a variable factor and has no marginal yet
        /// is rewritten into Factor.JaggedSubarrayWithMarginal, after the marginal channel has been created.
        /// Every other invocation falls through to the base implementation.
        /// </summary>
        /// <param name="imie">The method invocation being converted.</param>
        /// <returns>The converted expression.</returns>
        protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
        {
            if (CodeRecognizer.IsInfer(imie))
            {
                return ConvertInfer(imie);
            }

            // Remember what every 'out' argument writes to.
            foreach (IExpression argument in imie.Arguments)
            {
                IAddressOutExpression outArgument = argument as IAddressOutExpression;
                if (outArgument != null)
                {
                    targetsOfCurrentAssignment.Add(outArgument.Expression);
                }
            }

            bool isJaggedSubarray = Recognizer.IsStaticGenericMethod(
                imie,
                new Func<IList<PlaceHolder>, int[][], PlaceHolder[][]>(Factor.JaggedSubarray));
            if (!isJaggedSubarray)
            {
                return base.ConvertMethodInvoke(imie);
            }

            IExpression arrayExpr = imie.Arguments[0];
            IVariableDeclaration arrayDecl = Recognizer.GetVariableDeclaration(arrayExpr);
            bool needsMarginal = arrayDecl != null
                                 && arrayExpr is IVariableReferenceExpression
                                 && this.variablesLackingVariableFactor.Contains(arrayDecl)
                                 && !marginalOfVariable.ContainsKey(arrayDecl);
            if (!needsMarginal)
            {
                return base.ConvertMethodInvoke(imie);
            }

            // Create the marginal channel and emit its statements at the ancestor level
            // matching the variable's containers.
            VariableInformation arrayInfo = VariableInformation.GetVariableInformation(context, arrayDecl);
            IList<IStatement> channelStmts = Builder.StmtCollection();
            CreateMarginalChannel(arrayDecl, arrayInfo, channelStmts);

            Containers declContainers = context.InputAttributes.Get<Containers>(arrayDecl);
            int declAncestorIndex = declContainers.GetMatchingAncestorIndex(context);
            Containers absentContainers = declContainers.GetContainersNotInContext(context, declAncestorIndex);
            channelStmts = Containers.WrapWithContainers(channelStmts, absentContainers.outputs);
            context.AddStatementsBeforeAncestorIndex(declAncestorIndex, channelStmts);

            // none of the arguments should need to be transformed
            IExpression indicesExpr = imie.Arguments[1];
            IExpression marginalExpr = Builder.VarRefExpr(marginalOfVariable[arrayDecl]);
            Type[] elementTypeArgs = new Type[] { Utilities.Util.GetElementType(arrayExpr.GetExpressionType()) };
            IMethodInvokeExpression withMarginal = Builder.StaticGenericMethod(
                new Models.FuncOut<IList<PlaceHolder>, int[][], IList<PlaceHolder>, PlaceHolder[][]>(Factor.JaggedSubarrayWithMarginal),
                elementTypeArgs,
                arrayExpr, indicesExpr, marginalExpr);
            return withMarginal;
        }
Exemple #3
0
        /// <summary>
        /// Adds, when needed, the statement that defines the uses array of the assigned variable as a
        /// Replicate (or Gate.ReplicateExiting) of that variable.
        /// </summary>
        /// <remarks>
        /// Nothing is added when the target has no variable declaration, when the variable is not stochastic
        /// (ProcessConstant is called instead), when SwapIndices is on and the evidence filter rejects the
        /// variable, when no usage information is recorded, or when the variable has at most one use.
        /// The uses array itself is declared the first time the variable is encountered.
        /// </remarks>
        /// <param name="target">LHS of assignment</param>
        /// <param name="rhs">RHS of assignment (may be null)</param>
        /// <param name="shouldDelete">Set to true if the original assignment statement should be deleted, i.e. it can be optimized away.
        /// Only the copy-propagation path sets it, and that path is currently disabled by a local flag.</param>
        protected void AddReplicateStatement(IExpression target, IExpression rhs, ref bool shouldDelete)
        {
            IVariableDeclaration targetDecl = Recognizer.GetVariableDeclaration(target);
            if (targetDecl == null)
            {
                return;
            }

            VariableInformation varInfo = VariableInformation.GetVariableInformation(context, targetDecl);
            if (!varInfo.IsStochastic)
            {
                ProcessConstant(targetDecl);
                return;
            }

            bool isEvidenceVar = context.InputAttributes.Has<DoNotSendEvidence>(targetDecl);
            if (SwapIndices && replicateEvidenceVars != isEvidenceVar)
            {
                return;
            }

            // From here on the assignment defines a stochastic variable.
            ChannelAnalysisTransform.UsageInfo usage;
            if (!analysis.usageInfo.TryGetValue(targetDecl, out usage))
            {
                return;
            }
            int useCount = usage.NumberOfUses;
            if (useCount <= 1)
            {
                return;
            }

            VariableToChannelInformation channelInfo;
            bool firstTime = !usesOfVariable.TryGetValue(targetDecl, out channelInfo);
            int targetDepth = Recognizer.GetIndexingDepth(target);
            int minDepth = usage.indexingDepths[0];
            // Without index swapping the usage depth is always zero.
            int usageDepth = SwapIndices ? minDepth : 0;

            Containers declContainers = context.InputAttributes.Get<Containers>(targetDecl);
            int declAncestorIndex = declContainers.GetMatchingAncestorIndex(context);
            Containers absentContainers = declContainers.GetContainersNotInContext(context, declAncestorIndex);

            if (firstTime)
            {
                // declaration of uses array
                IList<IStatement> declStmts = Builder.StmtCollection();
                channelInfo = DeclareUsesArray(declStmts, targetDecl, varInfo, useCount, usageDepth);
                declStmts = Containers.WrapWithContainers(declStmts, absentContainers.outputs);
                context.AddStatementsBeforeAncestorIndex(declAncestorIndex, declStmts);
            }

            // Extra literal indices in the target must be zero.
            // For example, x[i][0] = (...) makes it safe to add x_uses[i] = Rep(x[i]);
            // x[i][1] = (...) makes it safe to add x_uses[i][1] = Rep(x[i][1]),
            // but not x_uses[i] = Rep(x[i]), which would be a duplicate.
            if (!CheckExtraLiteralsAreZero(target, targetDepth, usageDepth))
            {
                return;
            }

            // definition of uses array
            bool depthsMatch = usageDepth == targetDepth;
            IExpression replicatedExpr = Builder.VarRefExpr(targetDecl);
            IExpression usesExpr = Builder.VarRefExpr(channelInfo.usesDecl);
            List<IStatement> indexLoops = new List<IStatement>();
            if (depthsMatch)
            {
                // Mirror the target expression itself, substituting the uses array for the variable.
                replicatedExpr = target;
                if (replicatedExpr is IVariableDeclarationExpression)
                {
                    replicatedExpr = Builder.VarRefExpr(targetDecl);
                }
                usesExpr = Builder.ReplaceVariable(replicatedExpr, targetDecl, channelInfo.usesDecl);
            }
            else
            {
                // Build one loop per index variable for each of the first usageDepth brackets,
                // indexing both expressions by those loop variables.
                for (int depth = 0; depth < usageDepth; depth++)
                {
                    List<IExpression> bracketIndices = new List<IExpression>();
                    for (int k = 0; k < varInfo.sizes[depth].Length; k++)
                    {
                        IVariableDeclaration indexVar = varInfo.indexVars[depth][k];
                        indexLoops.Add(Builder.ForStmt(indexVar, varInfo.sizes[depth][k]));
                        bracketIndices.Add(Builder.VarRefExpr(indexVar));
                    }
                    replicatedExpr = Builder.ArrayIndex(replicatedExpr, bracketIndices);
                    usesExpr = Builder.ArrayIndex(usesExpr, bracketIndices);
                }
            }

            IMethodInvokeExpression rhsInvoke = rhs as IMethodInvokeExpression;
            if (rhsInvoke != null)
            {
                // Copy propagation is switched off here; the branch below is kept for when it is enabled.
                bool copyPropagation = false;
                if (Recognizer.IsStaticGenericMethod(rhsInvoke, new Func<PlaceHolder, PlaceHolder>(Factor.Copy)) && copyPropagation)
                {
                    // if a variable is a copy, use the original expression since it will give more precise dependencies.
                    replicatedExpr = rhsInvoke.Arguments[0];
                    shouldDelete = true;
                }
            }

            // Build the statement:
            //   x_uses = Replicate(x, useCount)
            Type[] typeArgs = new Type[] { targetDecl.VariableType.DotNetType };
            IMethodInvokeExpression replicateCall = Builder.StaticGenericMethod(
                new Func<PlaceHolder, int, PlaceHolder[]>(Factor.Replicate),
                typeArgs, replicatedExpr, Builder.LiteralExpr(useCount));
            if (context.InputAttributes.Has<Algorithms.VariationalMessagePassing.GateExitRandomVariable>(targetDecl))
            {
                // Variables marked GateExitRandomVariable use Gate.ReplicateExiting instead.
                replicateCall = Builder.StaticGenericMethod(
                    new Func<PlaceHolder, int, PlaceHolder[]>(Gate.ReplicateExiting),
                    typeArgs, replicatedExpr, Builder.LiteralExpr(useCount));
            }

            // Carry the relevant attributes of the variable over to the replicate call.
            context.InputAttributes.CopyObjectAttributesTo<Algorithm>(targetDecl, context.OutputAttributes, replicateCall);
            if (context.InputAttributes.Has<DivideMessages>(targetDecl))
            {
                context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(targetDecl, context.OutputAttributes, replicateCall);
            }
            else if (useCount == 2)
            {
                // division has no benefit for 2 uses, and degrades the schedule
                context.OutputAttributes.Set(replicateCall, new DivideMessages(false));
            }
            context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(targetDecl, context.OutputAttributes, replicateCall);

            IStatement replicateStmt = Builder.AssignStmt(usesExpr, replicateCall);
            if (depthsMatch)
            {
                if (!isEvidenceVar)
                {
                    context.AddStatementAfterCurrent(replicateStmt);
                    return;
                }

                // place the Replicate after the current statement, but outside of any evidence conditionals
                int statementIndex = context.FindAncestorIndex<IStatement>();
                int insertionIndex = statementIndex;
                for (int i = statementIndex - 2; i > 0; i--)
                {
                    IConditionStatement condition = context.GetAncestor(i) as IConditionStatement;
                    if (condition != null && CodeRecognizer.IsStochastic(context, condition.Condition))
                    {
                        insertionIndex = i;
                        break;
                    }
                }
                context.AddStatementAfterAncestorIndex(insertionIndex, replicateStmt);
            }
            else if (firstTime)
            {
                // place Replicate after assignment but outside of definition loops (can't have any uses there)
                replicateStmt = Containers.WrapWithContainers(replicateStmt, indexLoops);
                replicateStmt = Containers.WrapWithContainers(replicateStmt, absentContainers.outputs);
                context.AddStatementAfterAncestorIndex(declAncestorIndex, replicateStmt);
            }
        }
        /// <summary>
        /// This method does all the work of converting literal indexing expressions.
        /// </summary>
        /// <param name="iaie"></param>
        /// <returns></returns>
        protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
        {
            IndexAnalysisTransform.IndexInfo info;
            if (!analysis.indexInfoOf.TryGetValue(iaie, out info))
            {
                return(base.ConvertArrayIndexer(iaie));
            }
            // Determine if this is a definition i.e. the variable is on the left hand side of an assignment
            // This must be done before base.ConvertArrayIndexer changes the expression!
            bool isDef = Recognizer.IsBeingMutated(context, iaie);

            if (info.clone != null)
            {
                if (isDef)
                {
                    // check that extra literal indices in the target are zero.
                    // for example, if iae is x[i][0] = (...) then it is safe to add x_uses[i] = Rep(x[i])
                    // if iae is x[i][1] = (...) then it is safe to add x_uses[i][1] = Rep(x[i][1])
                    // but not x_uses[i] = Rep(x[i]) since this will be a duplicate.
                    bool   extraLiteralsAreZero = true;
                    int    parentIndex          = context.InputStack.Count - 2;
                    object parent = context.GetAncestor(parentIndex);
                    while (parent is IArrayIndexerExpression)
                    {
                        IArrayIndexerExpression parent_iaie = (IArrayIndexerExpression)parent;
                        foreach (IExpression index in parent_iaie.Indices)
                        {
                            if (index is ILiteralExpression)
                            {
                                int value = (int)((ILiteralExpression)index).Value;
                                if (value != 0)
                                {
                                    extraLiteralsAreZero = false;
                                    break;
                                }
                            }
                        }
                        parentIndex--;
                        parent = context.GetAncestor(parentIndex);
                    }
                    if (false && extraLiteralsAreZero)
                    {
                        // change:
                        //   array[0] = f()
                        // into:
                        //   array_item0 = f()
                        //   array[0] = Copy(array_item0)
                        IExpression copy = Builder.StaticGenericMethod(new Func <PlaceHolder, PlaceHolder>(Clone.Copy <PlaceHolder>), new Type[] { iaie.GetExpressionType() },
                                                                       info.clone);
                        IStatement copySt = Builder.AssignStmt(iaie, copy);
                        context.AddStatementAfterCurrent(copySt);
                    }
                }
                return(info.clone);
            }

            if (isDef)
            {
                // do not clone the lhs of an array create assignment.
                IAssignExpression assignExpr = context.FindAncestor <IAssignExpression>();
                if (assignExpr.Expression is IArrayCreateExpression)
                {
                    return(iaie);
                }
            }

            IVariableDeclaration originalBaseVar = Recognizer.GetVariableDeclaration(iaie);

            // If the variable is not stochastic, return
            if (!CodeRecognizer.IsStochastic(context, originalBaseVar))
            {
                return(iaie);
            }

            IExpression          newExpr    = null;
            IVariableDeclaration baseVar    = originalBaseVar;
            IVariableDeclaration newvd      = null;
            IExpression          rhsExpr    = null;
            Containers           containers = info.containers;
            Type tp = iaie.GetExpressionType();

            if (tp == null)
            {
                Error("Could not determine type of expression: " + iaie);
                return(iaie);
            }
            var stmts      = Builder.StmtCollection();
            var stmtsAfter = Builder.StmtCollection();

            // does the expression have the form array[indices[k]][indices2[k]][indices3[k]]?
            if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression &&
                iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
            {
                IArrayIndexerExpression index3 = (IArrayIndexerExpression)iaie.Indices[0];
                IArrayIndexerExpression iaie2  = (IArrayIndexerExpression)iaie.Target;
                if (index3.Indices.Count == 1 && index3.Indices[0] is IVariableReferenceExpression &&
                    iaie2.Target is IArrayIndexerExpression &&
                    iaie2.Indices.Count == 1 && iaie2.Indices[0] is IArrayIndexerExpression)
                {
                    IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie2.Indices[0];
                    IArrayIndexerExpression iaie3  = (IArrayIndexerExpression)iaie2.Target;
                    if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
                        iaie3.Indices.Count == 1 && iaie3.Indices[0] is IArrayIndexerExpression)
                    {
                        IArrayIndexerExpression      index      = (IArrayIndexerExpression)iaie3.Indices[0];
                        IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
                        IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                        if (index.Indices.Count == 1 && index2.Indices[0].Equals(innerIndex) &&
                            index3.Indices[0].Equals(innerIndex) &&
                            innerLoop != null && AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target))
                        {
                            // expression has the form array[indices[k]][indices2[k]][indices3[k]]
                            if (isDef)
                            {
                                Error("fancy indexing not allowed on left hand side");
                                return(iaie);
                            }
                            WarnIfLocal(index.Target, iaie3.Target, iaie);
                            WarnIfLocal(index2.Target, iaie3.Target, iaie);
                            WarnIfLocal(index3.Target, iaie3.Target, iaie);
                            containers = RemoveReferencesTo(containers, innerIndex);
                            IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
                            var         indices  = Recognizer.GetIndices(iaie);
                            // Build name of replacement variable from index values
                            StringBuilder sb = new StringBuilder("_item");
                            AppendIndexString(sb, iaie3);
                            AppendIndexString(sb, iaie2);
                            AppendIndexString(sb, iaie);
                            string name = ToString(iaie3.Target) + sb.ToString();
                            VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                            newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices);
                            if (!context.InputAttributes.Has <DerivedVariable>(newvd))
                            {
                                context.InputAttributes.Set(newvd, new DerivedVariable());
                            }
                            IExpression getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <IReadOnlyList <PlaceHolder> > >, IReadOnlyList <int>, IReadOnlyList <int>, IReadOnlyList <int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged),
                                                                               new Type[] { tp }, iaie3.Target, index.Target, index2.Target, index3.Target);
                            context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, getItems);
                            stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                            newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
                            rhsExpr = getItems;
                        }
                    }
                }
            }
            // does the expression have the form array[indices[k]][indices2[k]]?
            if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression &&
                iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
            {
                IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie.Indices[0];
                IArrayIndexerExpression target = (IArrayIndexerExpression)iaie.Target;
                if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
                    target.Indices.Count == 1 && target.Indices[0] is IArrayIndexerExpression)
                {
                    IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index2.Indices[0];
                    IArrayIndexerExpression      index      = (IArrayIndexerExpression)target.Indices[0];
                    IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                    if (index.Indices.Count == 1 && index.Indices[0].Equals(innerIndex) &&
                        innerLoop != null && AreLoopsDisjoint(innerLoop, target.Target, index.Target))
                    {
                        // expression has the form array[indices[k]][indices2[k]]
                        if (isDef)
                        {
                            Error("fancy indexing not allowed on left hand side");
                            return(iaie);
                        }
                        var innerLoops = new List <IForStatement>();
                        innerLoops.Add(innerLoop);
                        var indexTarget  = index.Target;
                        var index2Target = index2.Target;
                        // check if the index array is jagged, i.e. array[indices[k][j]]
                        while (indexTarget is IArrayIndexerExpression && index2Target is IArrayIndexerExpression)
                        {
                            IArrayIndexerExpression indexTargetExpr  = (IArrayIndexerExpression)indexTarget;
                            IArrayIndexerExpression index2TargetExpr = (IArrayIndexerExpression)index2Target;
                            if (indexTargetExpr.Indices.Count == 1 && indexTargetExpr.Indices[0] is IVariableReferenceExpression &&
                                index2TargetExpr.Indices.Count == 1 && index2TargetExpr.Indices[0] is IVariableReferenceExpression)
                            {
                                IVariableReferenceExpression innerIndexTarget  = (IVariableReferenceExpression)indexTargetExpr.Indices[0];
                                IVariableReferenceExpression innerIndex2Target = (IVariableReferenceExpression)index2TargetExpr.Indices[0];
                                IForStatement indexTargetLoop = Recognizer.GetLoopForVariable(context, innerIndexTarget);
                                if (indexTargetLoop != null && AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) &&
                                    innerIndexTarget.Equals(innerIndex2Target))
                                {
                                    innerLoops.Add(indexTargetLoop);
                                    indexTarget  = indexTargetExpr.Target;
                                    index2Target = index2TargetExpr.Target;
                                }
                                else
                                {
                                    break;
                                }
                            }
                            else
                            {
                                break;
                            }
                        }
                        WarnIfLocal(indexTarget, target.Target, iaie);
                        WarnIfLocal(index2Target, target.Target, iaie);
                        innerLoops.Reverse();
                        var loopSizes    = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                        var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                        // Build name of replacement variable from index values
                        StringBuilder sb = new StringBuilder("_item");
                        AppendIndexString(sb, target);
                        AppendIndexString(sb, iaie);
                        string name = ToString(target.Target) + sb.ToString();
                        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                        var indices = Recognizer.GetIndices(iaie);
                        newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                        if (!context.InputAttributes.Has <DerivedVariable>(newvd))
                        {
                            context.InputAttributes.Set(newvd, new DerivedVariable());
                        }
                        IExpression getItems;
                        if (innerLoops.Count == 1)
                        {
                            getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <PlaceHolder> >, IReadOnlyList <int>, IReadOnlyList <int>, PlaceHolder[]>(Collection.GetItemsFromJagged),
                                                                   new Type[] { tp }, target.Target, indexTarget, index2Target);
                        }
                        else if (innerLoops.Count == 2)
                        {
                            getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <PlaceHolder> >, IReadOnlyList <IReadOnlyList <int> >, IReadOnlyList <IReadOnlyList <int> >, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged),
                                                                   new Type[] { tp }, target.Target, indexTarget, index2Target);
                        }
                        else
                        {
                            throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                        }
                        context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, getItems);
                        stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                        var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                        newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                        rhsExpr = getItems;
                    }
                    else if (HasAnyCommonLoops(index, index2))
                    {
                        Warning($"This model will consume excess memory due to the indexing expression {iaie} since {index} and {index2} have larger depth than the compiler can handle.");
                    }
                }
            }
            if (newvd == null)
            {
                IArrayIndexerExpression originalExpr = iaie;
                if (UseGetItems)
                {
                    iaie = (IArrayIndexerExpression)base.ConvertArrayIndexer(iaie);
                }
                if (!object.ReferenceEquals(iaie.Target, originalExpr.Target) && false)
                {
                    // TODO: determine if this warning is useful or not
                    string warningText = "This model may consume excess memory due to the jagged indexing expression {0}";
                    Warning(string.Format(warningText, originalExpr));
                }

                // get the baseVar of the new expression.
                baseVar = Recognizer.GetVariableDeclaration(iaie);
                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);

                var indices = Recognizer.GetIndices(iaie);
                // Build name of replacement variable from index values
                StringBuilder sb = new StringBuilder("_item");
                AppendIndexString(sb, iaie);
                string name = ToString(iaie.Target) + sb.ToString();

                // does the expression have the form array[indices[k]]?
                if (UseGetItems && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
                {
                    IArrayIndexerExpression index = (IArrayIndexerExpression)iaie.Indices[0];
                    if (index.Indices.Count == 1 && index.Indices[0] is IVariableReferenceExpression)
                    {
                        // expression has the form array[indices[k]]
                        IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
                        IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                        if (innerLoop != null && AreLoopsDisjoint(innerLoop, iaie.Target, index.Target))
                        {
                            if (isDef)
                            {
                                Error("fancy indexing not allowed on left hand side");
                                return(iaie);
                            }
                            var innerLoops = new List <IForStatement>();
                            innerLoops.Add(innerLoop);
                            var indexTarget = index.Target;
                            // check if the index array is jagged, i.e. array[indices[k][j]]
                            while (indexTarget is IArrayIndexerExpression)
                            {
                                IArrayIndexerExpression index2 = (IArrayIndexerExpression)indexTarget;
                                if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression)
                                {
                                    IVariableReferenceExpression innerIndex2 = (IVariableReferenceExpression)index2.Indices[0];
                                    IForStatement innerLoop2 = Recognizer.GetLoopForVariable(context, innerIndex2);
                                    if (innerLoop2 != null && AreLoopsDisjoint(innerLoop2, iaie.Target, index2.Target))
                                    {
                                        innerLoops.Add(innerLoop2);
                                        indexTarget = index2.Target;
                                        // This limit must match the number of handled cases below.
                                        if (innerLoops.Count == 3)
                                        {
                                            break;
                                        }
                                    }
                                    else
                                    {
                                        break;
                                    }
                                }
                                else
                                {
                                    break;
                                }
                            }
                            WarnIfLocal(indexTarget, iaie.Target, originalExpr);
                            innerLoops.Reverse();
                            var loopSizes    = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                            var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                            newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                            if (!context.InputAttributes.Has <DerivedVariable>(newvd))
                            {
                                context.InputAttributes.Set(newvd, new DerivedVariable());
                            }
                            IExpression getItems;
                            if (innerLoops.Count == 1)
                            {
                                getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <PlaceHolder>, IReadOnlyList <int>, PlaceHolder[]>(Collection.GetItems),
                                                                       new Type[] { tp }, iaie.Target, indexTarget);
                            }
                            else if (innerLoops.Count == 2)
                            {
                                getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <PlaceHolder>, IReadOnlyList <IReadOnlyList <int> >, PlaceHolder[][]>(Collection.GetJaggedItems),
                                                                       new Type[] { tp }, iaie.Target, indexTarget);
                            }
                            else if (innerLoops.Count == 3)
                            {
                                getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <PlaceHolder>, IReadOnlyList <IReadOnlyList <IReadOnlyList <int> > >, PlaceHolder[][][]>(Collection.GetDeepJaggedItems),
                                                                       new Type[] { tp }, iaie.Target, indexTarget);
                            }
                            else
                            {
                                throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                            }
                            context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, getItems);
                            stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                            var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                            newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                            rhsExpr = getItems;
                        }
                    }
                }
                if (newvd == null)
                {
                    if (UseGetItems && info.count < 2)
                    {
                        return(iaie);
                    }
                    try
                    {
                        newvd = varInfo.DeriveIndexedVariable(stmts, context, name, indices, copyInitializer: isDef);
                    }
                    catch (Exception ex)
                    {
                        Error(ex.Message, ex);
                        return(iaie);
                    }
                    context.OutputAttributes.Remove <DerivedVariable>(newvd);
                    newExpr = Builder.VarRefExpr(newvd);
                    rhsExpr = iaie;
                    if (isDef)
                    {
                        // change:
                        //   array[0] = f()
                        // into:
                        //   array_item0 = f()
                        //   array[0] = Copy(array_item0)
                        IExpression copy   = Builder.StaticGenericMethod(new Func <PlaceHolder, PlaceHolder>(Clone.Copy), new Type[] { tp }, newExpr);
                        IStatement  copySt = Builder.AssignStmt(iaie, copy);
                        stmtsAfter.Add(copySt);
                        if (!context.InputAttributes.Has <DerivedVariable>(baseVar))
                        {
                            context.InputAttributes.Set(baseVar, new DerivedVariable());
                        }
                    }
                    else if (!info.IsAssignedTo)
                    {
                        // change:
                        //   x = f(array[0])
                        // into:
                        //   array_item0 = Copy(array[0])
                        //   x = f(array_item0)
                        IExpression copy   = Builder.StaticGenericMethod(new Func <PlaceHolder, PlaceHolder>(Clone.Copy), new Type[] { tp }, iaie);
                        IStatement  copySt = Builder.AssignStmt(Builder.VarRefExpr(newvd), copy);
                        //if (attr != null) context.OutputAttributes.Set(copySt, attr);
                        stmts.Add(copySt);
                        context.InputAttributes.Set(newvd, new DerivedVariable());
                    }
                }
            }

            // Reduce memory consumption by declaring the clone outside of unnecessary loops.
            // This way, the item is cloned outside the loop and then replicated, instead of replicating the entire array and cloning the item.
            containers = Containers.RemoveUnusedLoops(containers, context, rhsExpr);
            if (context.InputAttributes.Has <DoNotSendEvidence>(originalBaseVar))
            {
                containers = Containers.RemoveStochasticConditionals(containers, context);
            }
            if (true)
            {
                IStatement st = GetBindingSetContainer(FilterBindingSet(info.bindings,
                                                                        binding => Containers.ContainsExpression(containers.inputs, context, binding.GetExpression())));
                if (st != null)
                {
                    containers.Add(st);
                }
            }
            // To put the declaration in the desired containers, we find an ancestor which includes as many of the containers as possible,
            // then wrap the declaration with the remaining containers.
            int        ancIndex = containers.GetMatchingAncestorIndex(context);
            Containers missing  = containers.GetContainersNotInContext(context, ancIndex);

            stmts = Containers.WrapWithContainers(stmts, missing.outputs);
            context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            stmtsAfter = Containers.WrapWithContainers(stmtsAfter, missing.outputs);
            context.AddStatementsAfterAncestorIndex(ancIndex, stmtsAfter);
            context.InputAttributes.Set(newvd, containers);
            info.clone = newExpr;
            return(newExpr);
        }
        /// <summary>
        /// Converts a method invocation and, when <c>UseJaggedSubarray</c> is set, rewrites
        /// <c>Subarray(arrayExpr, indices[i])</c> into <c>Copy(arrayExpr_indices[i])</c>, where
        /// <c>arrayExpr_indices = JaggedSubarray(arrayExpr, indices)</c> is a newly derived variable
        /// whose declaring statements are added before an ancestor of the current statement.
        /// </summary>
        /// <param name="imie">The method invocation to convert.</param>
        /// <returns>
        /// The rewritten expression when the pattern applies; otherwise the result of the base conversion.
        /// The rewrite is skipped when <c>arrayExpr</c> or <c>indices</c> needs the loop over <c>i</c>,
        /// when no loop can be found for <c>i</c>, or when <c>arrayExpr</c> is not a stochastic variable.
        /// </returns>
        protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
        {
            IExpression result = base.ConvertMethodInvoke(imie);
            if (!(result is IMethodInvokeExpression convertedInvoke))
            {
                return result;
            }
            imie = convertedInvoke;

            if (!UseJaggedSubarray ||
                !Recognizer.IsStaticGenericMethod(imie, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, IReadOnlyList<PlaceHolder>>(Collection.Subarray)))
            {
                return imie;
            }

            // check for the form Subarray(arrayExpr, indices[i]) where arrayExpr does not depend on i
            IExpression arrayExpr = imie.Arguments[0];
            IExpression arg1 = imie.Arguments[1];
            if (!(arg1 is IArrayIndexerExpression index) ||
                index.Indices.Count != 1 ||
                !(index.Indices[0] is IVariableReferenceExpression innerIndex))
            {
                return imie;
            }

            // index has the form indices[i]
            List<IStatement> targetLoops = Containers.GetLoopsNeededForExpression(context, arrayExpr, -1, false);
            List<IStatement> indexLoops = Containers.GetLoopsNeededForExpression(context, index.Target, -1, false);
            Set<IStatement> parentLoops = new Set<IStatement>();
            parentLoops.AddRange(targetLoops);
            parentLoops.AddRange(indexLoops);
            IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
            if (innerLoop == null)
            {
                // No loop was found for i. The size of the new array is taken from that loop,
                // so the rewrite cannot proceed; leave the call unchanged.
                return imie;
            }
            foreach (IStatement loop in parentLoops)
            {
                if (Containers.ContainersAreEqual(loop, innerLoop))
                {
                    // arrayExpr depends on i
                    return imie;
                }
            }

            IVariableDeclaration arrayVar = Recognizer.GetVariableDeclaration(arrayExpr);
            // If the variable is not stochastic, return
            if (arrayVar == null)
            {
                return imie;
            }
            VariableInformation arrayInfo = VariableInformation.GetVariableInformation(context, arrayVar);
            if (!arrayInfo.IsStochastic)
            {
                return imie;
            }

            // The size and index variable of the result come from the bracket of 'indices' that follows indices[i].
            object indexVar = Recognizer.GetDeclaration(index);
            VariableInformation indexInfo = VariableInformation.GetVariableInformation(context, indexVar);
            int depth = Recognizer.GetIndexingDepth(index);
            IExpression resultSize = indexInfo.sizes[depth][0];
            var indices = Recognizer.GetIndices(index);
            int replaceCount = 0;
            resultSize = indexInfo.ReplaceIndexVars(context, resultSize, indices, null, ref replaceCount);
            indexInfo.DefineIndexVarsUpToDepth(context, depth + 1);
            IVariableDeclaration resultIndex = indexInfo.indexVars[depth][0];
            Type arrayType = arrayExpr.GetExpressionType();
            Type elementType = Util.GetElementType(arrayType);

            // create a new variable arrayExpr_indices = JaggedSubarray(arrayExpr, indices)
            string name = ToString(arrayExpr) + "_" + ToString(index.Target);

            var stmts = Builder.StmtCollection();
            var arrayIndices = Recognizer.GetIndices(arrayExpr);
            var bracket = Builder.ExprCollection();
            bracket.Add(Builder.ArrayIndex(index, Builder.VarRefExpr(resultIndex)));
            arrayIndices.Add(bracket);
            IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
            // 'temp' is only an intermediate for deriving the jagged variable; the statements
            // produced while deriving it are discarded before deriving the variable actually used.
            IVariableDeclaration temp = arrayInfo.DeriveArrayVariable(stmts, context, name, resultSize, resultIndex, arrayIndices);
            VariableInformation tempInfo = VariableInformation.GetVariableInformation(context, temp);
            stmts.Clear();
            IVariableDeclaration newvd = tempInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex));
            if (!context.InputAttributes.Has<DerivedVariable>(newvd))
            {
                context.InputAttributes.Set(newvd, new DerivedVariable());
            }
            IExpression rhs = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, int[][], PlaceHolder[][]>(Collection.JaggedSubarray),
                                                          new Type[] { elementType }, arrayExpr, index.Target);
            context.InputAttributes.CopyObjectAttributesTo<Algorithm>(newvd, context.OutputAttributes, rhs);
            stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), rhs));

            // Reduce memory consumption by declaring the clone outside of unnecessary loops.
            // This way, the item is cloned outside the loop and then replicated, instead of replicating the entire array and cloning the item.
            Containers containers = new Containers(context);
            containers = RemoveReferencesTo(containers, innerIndex);
            containers = Containers.RemoveUnusedLoops(containers, context, rhs);
            if (context.InputAttributes.Has<DoNotSendEvidence>(arrayVar))
            {
                containers = Containers.RemoveStochasticConditionals(containers, context);
            }
            // To put the declaration in the desired containers, we find an ancestor which includes as many of the containers as possible,
            // then wrap the declaration with the remaining containers.
            int ancIndex = containers.GetMatchingAncestorIndex(context);
            Containers missing = containers.GetContainersNotInContext(context, ancIndex);
            stmts = Containers.WrapWithContainers(stmts, missing.outputs);
            context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            context.InputAttributes.Set(newvd, containers);

            // convert into arrayExpr_indices[i]
            IExpression newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
            newExpr = Builder.StaticGenericMethod(new Func<PlaceHolder, PlaceHolder>(Clone.Copy<PlaceHolder>), new Type[] { newExpr.GetExpressionType() }, newExpr);
            return newExpr;
        }
Exemple #6
0
        /// <summary>
        /// Replaces a reference to a stochastic variable with a reference to a replicated copy.
        /// For every loop in the reference loop context whose loop variable is not treated as constant
        /// with respect to the expression, a new array variable is declared, assigned
        /// <c>Replicate(expr, loopSize)</c>, and the expression is re-written to index that array by the loop variable.
        /// </summary>
        /// <param name="expr">The expression to convert.</param>
        /// <returns>
        /// The expression indexing the replicated variable(s), or <paramref name="expr"/> unchanged when it does not
        /// refer to a stochastic variable, when its loop context cannot be found (an error is reported),
        /// or when the reference loop context contains no loops.
        /// </returns>
        protected IExpression ConvertWithReplication(IExpression expr)
        {
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
            // Check if this is an index local variable
            if (baseVar == null) return expr;
            // Check if the variable is stochastic
            if (!CodeRecognizer.IsStochastic(context, baseVar)) return expr;

            // Get the loop context for this variable
            LoopContext lc = context.InputAttributes.Get<LoopContext>(baseVar);
            if (lc == null)
            {
                Error("Loop context not found for '" + baseVar.Name + "'.");
                return expr;
            }

            // Get the reference loop context for this expression
            RefLoopContext rlc = lc.GetReferenceLoopContext(context);
            // If the reference is in the same loop context as the declaration, do nothing.
            if (rlc.loops.Count == 0) return expr;

            // the set of loop variables that are constant wrt the expr
            Set<IVariableDeclaration> constantLoopVars = new Set<IVariableDeclaration>();
            constantLoopVars.AddRange(lc.loopVariables);

            // collect set of all loop variable indices in the expression
            Set<int> embeddedLoopIndices = new Set<int>();

            // For an index variable that is not itself a loop variable of rlc, records the positions (within rlc)
            // of the loop variables in that variable's own loop context, skipping those already known to be constant.
            void AddEmbeddedLoopIndices(IVariableDeclaration indexVariable)
            {
                LoopContext indexLoopContext = context.InputAttributes.Get<LoopContext>(indexVariable);
                if (indexLoopContext == null)
                {
                    // Report the missing attribute the same way as for the base variable, instead of dereferencing null.
                    Error("Loop context not found for '" + indexVariable.Name + "'.");
                    return;
                }
                foreach (var candidate in indexLoopContext.loopVariables)
                {
                    if (constantLoopVars.Contains(candidate)) continue;
                    int position = rlc.loopVariables.IndexOf(candidate);
                    if (position != -1)
                        embeddedLoopIndices.Add(position);
                    else
                        Error($"Index {candidate} is not in {rlc} for expression {expr}");
                }
            }

            List<IList<IExpression>> brackets = Recognizer.GetIndices(expr);
            foreach (IList<IExpression> bracket in brackets)
            {
                foreach (IExpression index in bracket)
                {
                    IExpression indExpr = index;
                    // for a binary index expression, only the left operand is examined
                    if (indExpr is IBinaryExpression ibe)
                    {
                        indExpr = ibe.Left;
                    }
                    IVariableDeclaration indVar = Recognizer.GetVariableDeclaration(indExpr);
                    if (indVar != null)
                    {
                        if (constantLoopVars.Contains(indVar)) continue;
                        int loopIndex = rlc.loopVariables.IndexOf(indVar);
                        if (loopIndex != -1)
                        {
                            // indVar is a loop variable
                            constantLoopVars.Add(rlc.loopVariables[loopIndex]);
                        }
                        else
                        {
                            // indVar is not a loop variable
                            AddEmbeddedLoopIndices(indVar);
                        }
                    }
                    else
                    {
                        // the index does not resolve to a single declaration, so inspect every variable it mentions
                        foreach (var ivd in Recognizer.GetVariables(indExpr))
                        {
                            if (!constantLoopVars.Contains(ivd))
                            {
                                AddEmbeddedLoopIndices(ivd);
                            }
                        }
                    }
                }
            }

            // Find loop variables that must be constant due to condition statements.
            List<IStatement> ancestors = context.FindAncestors<IStatement>();
            foreach (IStatement ancestor in ancestors)
            {
                if (!(ancestor is IConditionStatement ics))
                    continue;
                ConditionBinding binding = new ConditionBinding(ics.Condition);
                IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(binding.lhs);
                IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(binding.rhs);
                int index = rlc.loopVariables.IndexOf(ivd);
                if (index >= 0 && IsConstantWrtLoops(ivd2, constantLoopVars))
                {
                    constantLoopVars.Add(ivd);
                    continue;
                }
                int index2 = rlc.loopVariables.IndexOf(ivd2);
                if (index2 >= 0 && IsConstantWrtLoops(ivd, constantLoopVars))
                {
                    constantLoopVars.Add(ivd2);
                }
            }

            // Determine if this expression is being defined (is on the LHS of an assignment)
            bool isDef = Recognizer.IsBeingMutated(context, expr);

            IExpression originalExpr = expr;

            for (int currentLoop = 0; currentLoop < rlc.loopVariables.Count; currentLoop++)
            {
                IVariableDeclaration loopVar = rlc.loopVariables[currentLoop];
                if (constantLoopVars.Contains(loopVar))
                    continue;
                IForStatement loop = rlc.loops[currentLoop];
                // must replicate across this loop.
                if (isDef)
                {
                    Error("Cannot re-define a variable in a loop.  Variables on the left hand side of an assignment must be indexed by all containing loops.");
                    continue;
                }
                if (embeddedLoopIndices.Contains(currentLoop))
                {
                    Warning($"This model will consume excess memory due to the indexing expression {originalExpr} inside of a loop over {loopVar.Name}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.");
                }
                // split expr into a target and extra indices, where target will be replicated and extra indices will be added later
                var extraIndices = new List<IEnumerable<IExpression>>();
                AddUnreplicatedIndices(loop, expr, extraIndices, out IExpression exprToReplicate);

                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                IExpression loopSize = Recognizer.LoopSizeExpression(loop);
                IList<IStatement> stmts = Builder.StmtCollection();
                List<IList<IExpression>> inds = Recognizer.GetIndices(exprToReplicate);
                IVariableDeclaration newIndexVar = loopVar;
                // if loopVar is already an indexVar of varInfo, create a new variable
                if (varInfo.HasIndexVar(loopVar))
                {
                    newIndexVar = VariableInformation.GenerateLoopVar(context, "_a");
                    context.InputAttributes.CopyObjectAttributesTo(loopVar, context.OutputAttributes, newIndexVar);
                }
                IVariableDeclaration repVar = varInfo.DeriveArrayVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rep"),
                                                                          loopSize, newIndexVar, inds, useArrays: true);
                if (!context.InputAttributes.Has<DerivedVariable>(repVar))
                    context.OutputAttributes.Set(repVar, new DerivedVariable());
                if (context.InputAttributes.Has<ChannelInfo>(baseVar))
                {
                    VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
                    ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
                    ci.decl = repVar;
                    context.OutputAttributes.Set(repVar, ci);
                }

                // Create replicate factor
                Type returnType = Builder.ToType(repVar.VariableType);
                IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
                    new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate),
                    new Type[] { returnType.GetElementType() }, exprToReplicate, loopSize);

                IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), repMethod);
                // Copy attributes across from variable to replication expression
                context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, repMethod);
                context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, repMethod);
                stmts.Add(Builder.ExprStatement(assignExpression));

                // add any containers missing from context.
                Containers containers = new Containers(context);
                // RemoveUnusedLoops will also remove conditionals involving those loop variables.
                // TODO: investigate whether removing these conditionals could cause a problem, e.g. when the condition is a conjunction of many terms.
                containers = Containers.RemoveUnusedLoops(containers, context, repMethod);
                if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar)) containers = Containers.RemoveStochasticConditionals(containers, context);
                //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
                //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
                if (containers.Contains(loop))
                {
                    // the loop being replicated over must not survive in the containers of the replicate statement
                    Error("Internal: invalid containers for replicating " + baseVar);
                    break;
                }
                int ancIndex = containers.GetMatchingAncestorIndex(context);
                Containers missing = containers.GetContainersNotInContext(context, ancIndex);
                stmts = Containers.WrapWithContainers(stmts, missing.inputs);
                context.OutputAttributes.Set(repVar, containers);
                List<IForStatement> loops = context.FindAncestors<IForStatement>(ancIndex);
                foreach (IStatement container in missing.inputs)
                {
                    if (container is IForStatement ifs) loops.Add(ifs);
                }
                context.OutputAttributes.Set(repVar, new LoopContext(loops));
                // must convert the output since it may contain 'if' conditions
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
                // the next loop (if any) replicates the variable just created
                baseVar = repVar;
                expr = Builder.ArrayIndex(Builder.VarRefExpr(repVar), Builder.VarRefExpr(loopVar));
                expr = Builder.JaggedArrayIndex(expr, extraIndices);
            }

            return expr;
        }
Exemple #7
0
        protected void ProcessAssign(IExpression target, IExpression rhs, ref bool shouldDelete)
        {
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);

            if (ivd == null)
            {
                return;
            }
            if (rhs is IArrayCreateExpression)
            {
                IArrayCreateExpression iace = (IArrayCreateExpression)rhs;
                bool zeroLength             = iace.Dimensions.All(dimExpr =>
                                                                  (dimExpr is ILiteralExpression) && ((ILiteralExpression)dimExpr).Value.Equals(0));
                if (!zeroLength && iace.Initializer == null)
                {
                    return; // variable will have assignments to elements
                }
            }
            bool firstTime = !variablesAssigned.Contains(ivd);

            variablesAssigned.Add(ivd);
            bool isInferred   = context.InputAttributes.Has <IsInferred>(ivd);
            bool isStochastic = CodeRecognizer.IsStochastic(context, ivd);

            if (!isStochastic)
            {
                return;
            }
            VariableInformation vi            = VariableInformation.GetVariableInformation(context, ivd);
            Containers          defContainers = context.InputAttributes.Get <Containers>(ivd);
            int        ancIndex = defContainers.GetMatchingAncestorIndex(context);
            Containers missing  = defContainers.GetContainersNotInContext(context, ancIndex);
            // definition of a stochastic variable
            IExpression lhs = target;

            if (lhs is IVariableDeclarationExpression)
            {
                lhs = Builder.VarRefExpr(ivd);
            }
            IExpression defExpr = lhs;

            if (firstTime && isStochastic)
            {
                // Create a ChannelInfo attribute for use by later transforms, e.g. MessageTransform
                ChannelInfo defChannel = ChannelInfo.DefChannel(vi);
                defChannel.decl = ivd;
                context.OutputAttributes.Set(ivd, defChannel);
            }
            bool       isDerived = context.InputAttributes.Has <DerivedVariable>(ivd);
            IAlgorithm algorithm = this.algorithmDefault;
            Algorithm  algAttr   = context.InputAttributes.Get <Algorithm>(ivd);

            if (algAttr != null)
            {
                algorithm = algAttr.algorithm;
            }
            if (algorithm is VariationalMessagePassing && ((VariationalMessagePassing)algorithm).UseDerivMessages && isDerived && firstTime)
            {
                vi.DefineAllIndexVars(context);
                IList <IStatement>   stmts     = Builder.StmtCollection();
                IVariableDeclaration derivDecl = vi.DeriveIndexedVariable(stmts, context, ivd.Name + "_deriv");
                context.OutputAttributes.Set(ivd, new DerivMessage(derivDecl));
                ChannelInfo derivChannel = ChannelInfo.DefChannel(vi);
                derivChannel.decl = derivDecl;
                context.OutputAttributes.Set(derivChannel.decl, derivChannel);
                context.OutputAttributes.Set(derivChannel.decl, new DescriptionAttribute("deriv of '" + ivd.Name + "'"));
                // Add the declarations
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            bool isPointEstimate = context.InputAttributes.Has <PointEstimate>(ivd);

            if (this.analysis.variablesExcludingVariableFactor.Contains(ivd))
            {
                this.variablesLackingVariableFactor.Add(ivd);
                // ivd will get a marginal channel in ConvertMethodInvoke
                useOfVariable[ivd] = ivd;
                return;
            }
            if (isDerived && !isInferred && !isPointEstimate)
            {
                return;
            }

            IExpression useExpr2 = null;

            if (firstTime)
            {
                // create marginal and use channels
                vi.DefineAllIndexVars(context);
                IList <IStatement> stmts = Builder.StmtCollection();

                CreateMarginalChannel(ivd, vi, stmts);
                if (isStochastic)
                {
                    CreateUseChannel(ivd, vi, stmts);
                    context.InputAttributes.Set(useOfVariable[ivd], defContainers);
                }

                // Add the declarations
                stmts = Containers.WrapWithContainers(stmts, missing.outputs);
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
            }
            if (isStochastic && !useOfVariable.ContainsKey(ivd))
            {
                Error("cannot find use channel of " + ivd);
                return;
            }
            IExpression  marginalExpr = Builder.ReplaceVariable(lhs, ivd, marginalOfVariable[ivd]);
            IExpression  useExpr      = isStochastic ? Builder.ReplaceVariable(lhs, ivd, useOfVariable[ivd]) : marginalExpr;
            InitialiseTo it           = context.InputAttributes.Get <InitialiseTo>(ivd);

            Type[] genArgs = new Type[] { defExpr.GetExpressionType() };
            if (rhs is IMethodInvokeExpression)
            {
                IMethodInvokeExpression imie = (IMethodInvokeExpression)rhs;
                if (Recognizer.IsStaticGenericMethod(imie, new Func <PlaceHolder, PlaceHolder>(Clone.Copy)) && ancIndex < context.InputStack.Count - 2)
                {
                    IExpression          arg  = imie.Arguments[0];
                    IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(arg);
                    if (ivd2 != null && context.InputAttributes.Get <MarginalPrototype>(ivd) == context.InputAttributes.Get <MarginalPrototype>(ivd2))
                    {
                        // if a variable is a copy, use the original expression since it will give more precise dependencies.
                        defExpr      = arg;
                        shouldDelete = true;
                        bool makeClone = false;
                        if (makeClone)
                        {
                            VariableInformation         vi2      = VariableInformation.GetVariableInformation(context, ivd2);
                            IList <IStatement>          stmts    = Builder.StmtCollection();
                            List <IList <IExpression> > indices  = Recognizer.GetIndices(defExpr);
                            IVariableDeclaration        useDecl2 = vi2.DeriveIndexedVariable(stmts, context, ivd2.Name + "_use", indices);
                            useExpr2 = Builder.VarRefExpr(useDecl2);
                            Containers defContainers2 = context.InputAttributes.Get <Containers>(ivd2);
                            int        ancIndex2      = defContainers2.GetMatchingAncestorIndex(context);
                            Containers missing2       = defContainers2.GetContainersNotInContext(context, ancIndex2);
                            stmts = Containers.WrapWithContainers(stmts, missing2.outputs);
                            context.AddStatementsBeforeAncestorIndex(ancIndex2, stmts);
                            context.InputAttributes.Set(useDecl2, defContainers2);

                            // TODO: call CreateUseChannel
                            ChannelInfo usageChannel = ChannelInfo.UseChannel(vi2);
                            usageChannel.decl = useDecl2;
                            context.InputAttributes.CopyObjectAttributesTo <InitialiseTo>(vi.declaration, context.OutputAttributes, useDecl2);
                            context.InputAttributes.CopyObjectAttributesTo <DerivMessage>(vi.declaration, context.OutputAttributes, useDecl2);
                            context.OutputAttributes.Set(useDecl2, usageChannel);
                            //context.OutputAttributes.Set(useDecl2, new DescriptionAttribute("use of '" + ivd.Name + "'"));
                            context.OutputAttributes.Remove <InitialiseTo>(vi.declaration);

                            IExpression copyExpr = Builder.StaticGenericMethod(
                                new Func <PlaceHolder, PlaceHolder>(Clone.Copy), genArgs, useExpr2);
                            var copyStmt = Builder.AssignStmt(useExpr, copyExpr);
                            context.AddStatementAfterCurrent(copyStmt);
                        }
                    }
                }
            }

            // Add the variable factor
            IExpression variableFactorExpr;
            bool        isGateExitRandom = context.InputAttributes.Has <VariationalMessagePassing.GateExitRandomVariable>(ivd);

            if (isGateExitRandom)
            {
                variableFactorExpr = Builder.StaticGenericMethod(
                    new Models.FuncOut <PlaceHolder, PlaceHolder, PlaceHolder>(Gate.ExitingVariable),
                    genArgs, defExpr, marginalExpr);
            }
            else
            {
                Delegate d = algorithm.GetVariableFactor(isDerived, it != null);
                if (isPointEstimate)
                {
                    d = new Models.FuncOut <PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint);
                }
                if (it == null)
                {
                    variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, marginalExpr);
                }
                else
                {
                    IExpression initExpr = Builder.ReplaceExpression(lhs, Builder.VarRefExpr(ivd), it.initialMessagesExpression);
                    variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, initExpr, marginalExpr);
                }
            }
            context.InputAttributes.CopyObjectAttributesTo <GivePriorityTo>(ivd, context.OutputAttributes, variableFactorExpr);
            context.InputAttributes.CopyObjectAttributesTo <Algorithm>(ivd, context.OutputAttributes, variableFactorExpr);
            if (isStochastic)
            {
                context.OutputAttributes.Set(variableFactorExpr, new IsVariableFactor());
            }
            var assignStmt = Builder.AssignStmt(useExpr2 == null ? useExpr : useExpr2, variableFactorExpr);

            context.AddStatementAfterCurrent(assignStmt);
        }
Exemple #8
0
        /// <summary>
        /// Rewrites a reference to a stochastic variable that appears inside repeat blocks which are
        /// not part of the variable's own repeat context, so that it refers to a replicated copy.
        /// </summary>
        /// <remarks>
        /// For every repeat in the reference repeat context, a new "_rpt" variable is derived from the
        /// current variable and assigned <c>PowerPlate.Enter(expr, repeatCount)</c>. The generated
        /// statements are wrapped in any missing containers and inserted before the matching ancestor.
        /// The new variable then becomes the base for the next (inner) repeat.
        /// </remarks>
        /// <param name="expr">The expression to convert.</param>
        /// <returns>
        /// <paramref name="expr"/> unchanged when it does not resolve to a variable declaration, the
        /// variable is not stochastic, the variable is in <c>cutVariables</c>, no
        /// <c>RepeatContext</c> is attached (an error is reported), the reference repeat context has
        /// no repeats, or the expression is being mutated (an error is reported per repeat).
        /// Otherwise, a reference to the last replicated variable created.
        /// </returns>
        protected IExpression ConvertWithReplication(IExpression expr)
        {
            IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);

            // Nothing to replicate when the expression does not resolve to a variable declaration.
            if (baseVar == null)
            {
                return expr;
            }
            // Only stochastic variables are replicated.
            if (!CodeRecognizer.IsStochastic(context, baseVar))
            {
                return expr;
            }
            // Variables listed in cutVariables are left untouched.
            if (cutVariables.Contains(baseVar))
            {
                return expr;
            }

            // Get the repeat context for this variable
            RepeatContext lc = context.InputAttributes.Get<RepeatContext>(baseVar);

            if (lc == null)
            {
                Error("Repeat context not found for '" + baseVar.Name + "'.");
                return expr;
            }

            // Get the reference repeat context for this expression
            var rlc = lc.GetReferenceRepeatContext(context);

            // If the reference is in the same repeat context as the declaration, do nothing.
            if (rlc.repeats.Count == 0)
            {
                return expr;
            }

            // Determine if this expression is being defined (is on the LHS of an assignment)
            bool isDef = Recognizer.IsBeingMutated(context, expr);

            for (int currentRepeat = 0; currentRepeat < rlc.repeatCounts.Count; currentRepeat++)
            {
                IExpression      repeatCount = rlc.repeatCounts[currentRepeat];
                IRepeatStatement repeat      = rlc.repeats[currentRepeat];

                // must replicate across this repeat.
                if (isDef)
                {
                    Error("Cannot re-define a variable in a repeat block.");
                    continue;
                }
                // are we replicating the argument of Gate.Cases?
                IMethodInvokeExpression imie = context.FindAncestor<IMethodInvokeExpression>();
                if (imie != null)
                {
                    if (Recognizer.IsStaticMethod(imie, typeof(Gate), "Cases"))
                    {
                        Error("'if(" + expr + ")' should be placed outside 'repeat(" + repeatCount + ")'");
                    }
                    if (Recognizer.IsStaticMethod(imie, typeof(Gate), "CasesInt"))
                    {
                        Error("'case(" + expr + ")' or 'switch(" + expr + ")' should be placed outside 'repeat(" + repeatCount + ")'");
                    }
                }

                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                IList<IStatement>   stmts   = Builder.StmtCollection();

                // Derive the replicated variable, indexed the same way as the current expression.
                List<IList<IExpression>> inds   = Recognizer.GetIndices(expr);
                IVariableDeclaration     repVar = varInfo.DeriveIndexedVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rpt"), inds);
                if (!context.InputAttributes.Has<DerivedVariable>(repVar))
                {
                    context.OutputAttributes.Set(repVar, new DerivedVariable());
                }
                if (context.InputAttributes.Has<ChannelInfo>(baseVar))
                {
                    VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
                    ChannelInfo         ci         = ChannelInfo.UseChannel(repVarInfo);
                    ci.decl = repVar;
                    context.OutputAttributes.Set(repVar, ci);
                }
                // set the RepeatContext of repVar to include all repeats up to this one (so that it doesn't get Entered again)
                List<IRepeatStatement> repeats = new List<IRepeatStatement>(lc.repeats);
                for (int i = 0; i <= currentRepeat; i++)
                {
                    repeats.Add(rlc.repeats[i]);
                }
                context.OutputAttributes.Remove<RepeatContext>(repVar);
                context.OutputAttributes.Set(repVar, new RepeatContext(repeats));

                // Create replicate factor
                Type returnType = Builder.ToType(repVar.VariableType);
                IMethodInvokeExpression powerPlateMethod = Builder.StaticGenericMethod(
                    new Func<PlaceHolder, double, PlaceHolder>(PowerPlate.Enter),
                    new Type[] { returnType }, expr, repeatCount);

                IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), powerPlateMethod);
                // Copy attributes across from variable to replication expression
                context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, powerPlateMethod);
                context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, powerPlateMethod);
                context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, powerPlateMethod);
                stmts.Add(Builder.ExprStatement(assignExpression));

                // Start from the containers of the current context (built fresh for every repeat).
                Containers containers = new Containers(context);
                // remove inner repeats
                for (int i = currentRepeat + 1; i < rlc.repeatCounts.Count; i++)
                {
                    containers = containers.RemoveOneRepeat(rlc.repeats[i]);
                }
                // repVar is attributed with the containers that still include the current repeat ...
                context.OutputAttributes.Set(repVar, containers);
                // ... while the statements defining it are placed outside of that repeat.
                containers = containers.RemoveOneRepeat(repeat);
                containers = Containers.RemoveUnusedLoops(containers, context, powerPlateMethod);
                if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar))
                {
                    containers = Containers.RemoveStochasticConditionals(containers, context);
                }
                //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
                //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
                int        ancIndex = containers.GetMatchingAncestorIndex(context);
                Containers missing  = containers.GetContainersNotInContext(context, ancIndex);
                stmts = Containers.WrapWithContainers(stmts, missing.inputs);
                // must convert the output since it may contain 'if' conditions
                context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);

                // The replicated variable is the starting point for the next repeat.
                baseVar = repVar;
                expr    = Builder.VarRefExpr(repVar);
            }

            return expr;
        }
        /// <summary>
        /// Returns a reference to the depth-specific clone of the variable referenced by
        /// <paramref name="expr"/>, declaring and caching that clone the first time it is needed.
        /// </summary>
        /// <param name="expr">The expression to convert.</param>
        /// <returns>
        /// <paramref name="expr"/> itself when no clone applies; otherwise the cached clone indexed
        /// with the indices of <paramref name="expr"/>.
        /// </returns>
        private IExpression GetClone(IExpression expr)
        {
            // Leave the expression alone while the recognizer reports that it is being indexed.
            // NOTE(review): presumably so that only the outermost indexer is rewritten - confirm.
            if (Recognizer.IsBeingIndexed(context))
            {
                return expr;
            }

            // Expressions that do not resolve to a variable declaration (e.g. an indexed argument
            // reference), or whose variable has no depth analysis entry, are returned untouched.
            IVariableDeclaration sourceDecl = Recognizer.GetVariableDeclaration(expr);
            DepthInfo usage;
            if (sourceDecl == null || !analysis.depthInfos.TryGetValue(sourceDecl, out usage))
            {
                return expr;
            }

            // No clone unless the variable has more than one use and the analysis holds something
            // other than exactly one depth entry.
            bool hasSeveralUses = usage.useCount > 1;
            if (!hasSeveralUses || usage.indexInfoOfDepth.Count == 1)
            {
                return expr;
            }

            // Variables marked DoNotSendEvidence are cloned only when cloneEvidenceVars is set.
            bool marksEvidence = context.InputAttributes.Has<DoNotSendEvidence>(sourceDecl);
            if (marksEvidence && !cloneEvidenceVars)
            {
                return expr;
            }

            int       depth = Recognizer.GetIndexingDepth(expr);
            IndexInfo entry = usage.indexInfoOfDepth[depth];

            if (entry.clone == null)
            {
                // References made at the definition depth keep using the original variable.
                if (usage.definitionDepth == depth)
                {
                    return expr;
                }

                VariableInformation sourceInfo = VariableInformation.GetVariableInformation(context, sourceDecl);
                IList<IStatement>   cloneStmts = Builder.StmtCollection();
                string              cloneName  = sourceDecl.Name + "_depth" + depth;

                // Declare the clone variable.
                sourceInfo.DefineAllIndexVars(context);
                IVariableDeclaration cloneDecl = sourceInfo.DeriveIndexedVariable(cloneStmts, context, cloneName);
                VariableInformation  cloneInfo = VariableInformation.GetVariableInformation(context, cloneDecl);
                cloneInfo.LiteralIndexingDepth = entry.literalIndexingDepth;
                if (!context.InputAttributes.Has<DerivedVariable>(cloneDecl))
                {
                    context.InputAttributes.Set(cloneDecl, new DerivedVariable());
                }
                if (context.InputAttributes.Has<ChannelInfo>(sourceDecl))
                {
                    ChannelInfo cloneChannel = ChannelInfo.UseChannel(sourceInfo);
                    cloneChannel.decl = cloneDecl;
                    context.OutputAttributes.Set(cloneDecl, cloneChannel);
                }

                // Define the clone as a copy of the original, indexed by whichever is deeper:
                // the reference depth or the definition depth. For example:
                //   x[i] = definition
                //   x_depth0[i] = Copy(x[i])
                //   x_depth2[i][j] = Copy(x[i][j])
                int         copyDepth  = System.Math.Max(depth, usage.definitionDepth);
                IExpression copyTarget = Builder.VarRefExpr(cloneDecl);
                // TODO: clone the next lower depth, not always the original variable
                IExpression copySource = Builder.VarRefExpr(sourceDecl);
                AddCopyStatements(cloneStmts, cloneInfo, copyDepth, copyTarget, copySource);

                // Keep memory use down by declaring the clone outside of loops it does not need:
                // the item is cloned outside the loop and then replicated, rather than replicating
                // the whole array and cloning the item.
                Containers placement = entry.containers;
                placement = Containers.RemoveUnusedLoops(placement, context, Builder.VarRefExpr(sourceDecl));
                if (context.InputAttributes.Has<DoNotSendEvidence>(sourceDecl))
                {
                    placement = Containers.RemoveStochasticConditionals(placement, context);
                }

                // Pick the ancestor that already provides as many of the wanted containers as
                // possible, then wrap the new statements in whatever containers are still missing.
                int        insertIndex = placement.GetMatchingAncestorIndex(context);
                Containers absent      = placement.GetContainersNotInContext(context, insertIndex);
                cloneStmts = Containers.WrapWithContainers(cloneStmts, absent.outputs);
                context.AddStatementsBeforeAncestorIndex(insertIndex, cloneStmts);
                context.InputAttributes.Set(cloneDecl, placement);

                // Cache the clone so later references at this depth reuse it.
                entry.clone = Builder.VarRefExpr(cloneDecl);
            }

            // Re-apply the original indices to the clone.
            List<IList<IExpression>> exprIndices = Recognizer.GetIndices(expr);

            return Builder.JaggedArrayIndex(entry.clone, exprIndices);
        }