/// <summary>
/// Visits an array creation expression. When the creation is the entire right-hand side of an assignment
/// whose target variable carries a <see cref="VariableToChannelInformation"/> attribute, and the created
/// type is no longer an array type (the last level of indexing), the pending UsesEqualDef statements
/// recorded for that variable are emitted after the allocation - at most once per variable.
/// </summary>
/// <param name="iace">The array creation expression being converted.</param>
/// <returns>Always <paramref name="iace"/> itself; the expression is never rewritten.</returns>
protected override IExpression ConvertArrayCreate(IArrayCreateExpression iace)
{
    // Only array creations that form the whole right-hand side of an assignment are relevant.
    IAssignExpression assignment = context.FindAncestor<IAssignExpression>();
    if (assignment == null || assignment.Expression != iace)
    {
        return iace;
    }

    IVariableDeclaration targetDecl = Recognizer.GetVariableDeclaration(assignment.Target);
    VariableToChannelInformation channelInfo = context.InputAttributes.Get<VariableToChannelInformation>(targetDecl);
    if (channelInfo == null)
    {
        // not a stochastic variable
        return iace;
    }

    // Check if this is the last level of indexing
    bool isLastLevel = !(iace.Type is IArrayType);
    if (!isLastLevel || channelInfo.usesEqualDefsStatements == null)
    {
        return iace;
    }

    if (channelInfo.IsUsesEqualDefsStatementInserted)
    {
        // The statements were already emitted by an earlier allocation. A duplicate allocation is
        // deliberately tolerated here (the "Duplicate array allocation." error is disabled).
        return iace;
    }

    // Insert the UsesEqualDef statement after the array is fully allocated.
    // Note the array elements will not have been defined yet.
    LoopContext declLoopContext = context.InputAttributes.Get<LoopContext>(targetDecl);
    RefLoopContext refLoopContext = declLoopContext.GetReferenceLoopContext(context);

    // IMPORTANT TODO: add this statement at the right level!
    // Default anchor is the enclosing statement; if the reference sits inside extra loops,
    // anchor on the outermost of those loops instead.
    IStatement anchor = context.FindAncestor<IStatement>();
    if (refLoopContext.loops.Count > 0)
    {
        anchor = refLoopContext.loops[0];
    }

    int anchorIndex = context.GetAncestorIndex(anchor);
    Containers declContainers = context.InputAttributes.Get<Containers>(targetDecl);
    Containers absentContainers = declContainers.GetContainersNotInContext(context, anchorIndex);
    channelInfo.usesEqualDefsStatements = Containers.WrapWithContainers(channelInfo.usesEqualDefsStatements, absentContainers.outputs);
    context.AddStatementsAfter(anchor, channelInfo.usesEqualDefsStatements);
    channelInfo.IsUsesEqualDefsStatementInserted = true;
    return iace;
}
/// <summary>
/// Converts a method invocation.
/// <list type="bullet">
/// <item><description>Infer calls are handed to <c>ConvertInfer</c>.</description></item>
/// <item><description>The expression behind every out-argument is recorded in <c>targetsOfCurrentAssignment</c>.</description></item>
/// <item><description>A <c>Factor.JaggedSubarray</c> call whose array argument is a plain variable reference that
/// lacks a variable factor and has no marginal yet gets a marginal channel created and is rewritten into
/// <c>Factor.JaggedSubarrayWithMarginal</c>.</description></item>
/// </list>
/// Everything else is delegated to the base implementation.
/// </summary>
/// <param name="imie">The method invocation being converted.</param>
/// <returns>The converted expression.</returns>
protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
{
    if (CodeRecognizer.IsInfer(imie))
    {
        return ConvertInfer(imie);
    }

    // Remember what is written through out-arguments of this call.
    foreach (IExpression argument in imie.Arguments)
    {
        if (argument is IAddressOutExpression outArgument)
        {
            targetsOfCurrentAssignment.Add(outArgument.Expression);
        }
    }

    bool isJaggedSubarray = Recognizer.IsStaticGenericMethod(
        imie,
        new Func<IList<PlaceHolder>, int[][], PlaceHolder[][]>(Factor.JaggedSubarray));
    if (!isJaggedSubarray)
    {
        return base.ConvertMethodInvoke(imie);
    }

    IExpression arrayExpr = imie.Arguments[0];
    IVariableDeclaration arrayDecl = Recognizer.GetVariableDeclaration(arrayExpr);
    bool needsMarginal =
        arrayDecl != null &&
        (arrayExpr is IVariableReferenceExpression) &&
        this.variablesLackingVariableFactor.Contains(arrayDecl) &&
        !marginalOfVariable.ContainsKey(arrayDecl);
    if (!needsMarginal)
    {
        return base.ConvertMethodInvoke(imie);
    }

    // Create the marginal channel and place its statements in the containers of the array's definition.
    VariableInformation arrayInfo = VariableInformation.GetVariableInformation(context, arrayDecl);
    IList<IStatement> channelStmts = Builder.StmtCollection();
    CreateMarginalChannel(arrayDecl, arrayInfo, channelStmts);
    Containers definitionContainers = context.InputAttributes.Get<Containers>(arrayDecl);
    int ancestorIndex = definitionContainers.GetMatchingAncestorIndex(context);
    Containers absentContainers = definitionContainers.GetContainersNotInContext(context, ancestorIndex);
    channelStmts = Containers.WrapWithContainers(channelStmts, absentContainers.outputs);
    context.AddStatementsBeforeAncestorIndex(ancestorIndex, channelStmts);

    // none of the arguments should need to be transformed
    IExpression indicesExpr = imie.Arguments[1];
    IExpression marginalExpr = Builder.VarRefExpr(marginalOfVariable[arrayDecl]);
    Type[] genericArgs = new Type[] { Utilities.Util.GetElementType(arrayExpr.GetExpressionType()) };
    IMethodInvokeExpression withMarginal = Builder.StaticGenericMethod(
        new Models.FuncOut<IList<PlaceHolder>, int[][], IList<PlaceHolder>, PlaceHolder[][]>(Factor.JaggedSubarrayWithMarginal),
        genericArgs,
        arrayExpr,
        indicesExpr,
        marginalExpr);
    return withMarginal;
}
/// <summary>
/// Emits, when required, the declaration of the uses array of the assigned variable together with the
/// statement <c>x_uses = Replicate(x, useCount)</c>. Non-stochastic targets are forwarded to
/// <c>ProcessConstant</c>; variables with at most one recorded use get no Replicate statement.
/// </summary>
/// <param name="target">LHS of assignment</param>
/// <param name="rhs">RHS of assignment</param>
/// <param name="shouldDelete">
/// Set to true if the original assignment statement should be deleted, i.e. it can be optimized away.
/// This only happens on the copy-propagation path, which is currently switched off inside the method.
/// </param>
protected void AddReplicateStatement(IExpression target, IExpression rhs, ref bool shouldDelete)
{
    IVariableDeclaration decl = Recognizer.GetVariableDeclaration(target);
    if (decl == null)
    {
        return;
    }

    VariableInformation declInfo = VariableInformation.GetVariableInformation(context, decl);
    if (!declInfo.IsStochastic)
    {
        ProcessConstant(decl);
        return;
    }

    bool isEvidenceVar = context.InputAttributes.Has<DoNotSendEvidence>(decl);
    if (SwapIndices && replicateEvidenceVars != isEvidenceVar)
    {
        return;
    }

    // The assignment defines a stochastic variable; replication is only needed for more than one use.
    if (!analysis.usageInfo.TryGetValue(decl, out ChannelAnalysisTransform.UsageInfo usage))
    {
        return;
    }
    int useCount = usage.NumberOfUses;
    if (useCount <= 1)
    {
        return;
    }

    bool firstTime = !usesOfVariable.TryGetValue(decl, out VariableToChannelInformation channelInfo);
    int targetDepth = Recognizer.GetIndexingDepth(target);
    int minDepth = usage.indexingDepths[0];
    // Without index swapping the uses array is always declared at depth 0.
    int usageDepth = SwapIndices ? minDepth : 0;

    Containers definitionContainers = context.InputAttributes.Get<Containers>(decl);
    int ancIndex = definitionContainers.GetMatchingAncestorIndex(context);
    Containers missing = definitionContainers.GetContainersNotInContext(context, ancIndex);
    if (firstTime)
    {
        // declaration of uses array
        IList<IStatement> declarationStmts = Builder.StmtCollection();
        channelInfo = DeclareUsesArray(declarationStmts, decl, declInfo, useCount, usageDepth);
        declarationStmts = Containers.WrapWithContainers(declarationStmts, missing.outputs);
        context.AddStatementsBeforeAncestorIndex(ancIndex, declarationStmts);
    }

    // check that extra literal indices in the target are zero.
    // for example, if iae is x[i][0] = (...) then it is safe to add x_uses[i] = Rep(x[i])
    // if iae is x[i][1] = (...) then it is safe to add x_uses[i][1] = Rep(x[i][1])
    // but not x_uses[i] = Rep(x[i]) since this will be a duplicate.
    if (!CheckExtraLiteralsAreZero(target, targetDepth, usageDepth))
    {
        return;
    }

    // definition of uses array
    IExpression defExpr = Builder.VarRefExpr(decl);
    IExpression usesExpr = Builder.VarRefExpr(channelInfo.usesDecl);
    List<IStatement> loops = new List<IStatement>();
    if (usageDepth == targetDepth)
    {
        // Replicate exactly the element being assigned.
        defExpr = target;
        if (defExpr is IVariableDeclarationExpression)
        {
            defExpr = Builder.VarRefExpr(decl);
        }
        usesExpr = Builder.ReplaceVariable(defExpr, decl, channelInfo.usesDecl);
    }
    else
    {
        // loops over the last indexing bracket
        for (int depth = 0; depth < usageDepth; depth++)
        {
            List<IExpression> bracket = new List<IExpression>();
            for (int dim = 0; dim < declInfo.sizes[depth].Length; dim++)
            {
                IVariableDeclaration indexVar = declInfo.indexVars[depth][dim];
                loops.Add(Builder.ForStmt(indexVar, declInfo.sizes[depth][dim]));
                bracket.Add(Builder.VarRefExpr(indexVar));
            }
            defExpr = Builder.ArrayIndex(defExpr, bracket);
            usesExpr = Builder.ArrayIndex(usesExpr, bracket);
        }
    }

    if (rhs is IMethodInvokeExpression rhsInvoke)
    {
        // Copy propagation is switched off; flip this flag to re-enable it.
        bool copyPropagation = false;
        if (Recognizer.IsStaticGenericMethod(rhsInvoke, new Func<PlaceHolder, PlaceHolder>(Factor.Copy)) && copyPropagation)
        {
            // if a variable is a copy, use the original expression since it will give more precise dependencies.
            defExpr = rhsInvoke.Arguments[0];
            shouldDelete = true;
        }
    }

    // Add the statement:
    //   x_uses = Replicate(x, useCount)
    var genArgs = new Type[] { decl.VariableType.DotNetType };
    IMethodInvokeExpression replicateCall = Builder.StaticGenericMethod(
        new Func<PlaceHolder, int, PlaceHolder[]>(Factor.Replicate),
        genArgs,
        defExpr,
        Builder.LiteralExpr(useCount));
    bool isGateExitRandom = context.InputAttributes.Has<Algorithms.VariationalMessagePassing.GateExitRandomVariable>(decl);
    if (isGateExitRandom)
    {
        replicateCall = Builder.StaticGenericMethod(
            new Func<PlaceHolder, int, PlaceHolder[]>(Gate.ReplicateExiting),
            genArgs,
            defExpr,
            Builder.LiteralExpr(useCount));
    }

    // Carry the relevant attributes of the variable over to the Replicate call.
    context.InputAttributes.CopyObjectAttributesTo<Algorithm>(decl, context.OutputAttributes, replicateCall);
    if (context.InputAttributes.Has<DivideMessages>(decl))
    {
        context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(decl, context.OutputAttributes, replicateCall);
    }
    else if (useCount == 2)
    {
        // division has no benefit for 2 uses, and degrades the schedule
        context.OutputAttributes.Set(replicateCall, new DivideMessages(false));
    }
    context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(decl, context.OutputAttributes, replicateCall);

    IStatement replicateStmt = Builder.AssignStmt(usesExpr, replicateCall);
    if (usageDepth != targetDepth)
    {
        if (firstTime)
        {
            // place Replicate after assignment but outside of definition loops (can't have any uses there)
            replicateStmt = Containers.WrapWithContainers(replicateStmt, loops);
            replicateStmt = Containers.WrapWithContainers(replicateStmt, missing.outputs);
            context.AddStatementAfterAncestorIndex(ancIndex, replicateStmt);
        }
        return;
    }

    if (!isEvidenceVar)
    {
        context.AddStatementAfterCurrent(replicateStmt);
        return;
    }

    // place the Replicate after the current statement, but outside of any evidence conditionals
    ancIndex = context.FindAncestorIndex<IStatement>();
    for (int i = ancIndex - 2; i > 0; i--)
    {
        if (context.GetAncestor(i) is IConditionStatement conditional &&
            CodeRecognizer.IsStochastic(context, conditional.Condition))
        {
            ancIndex = i;
            break;
        }
    }
    context.AddStatementAfterAncestorIndex(ancIndex, replicateStmt);
}
/// <summary>
/// This method does all the work of converting literal indexing expressions.
/// For an indexer that has an entry in <c>analysis.indexInfoOf</c> and whose base variable is stochastic,
/// the expression is replaced by a reference to a derived variable. The forms are tried in this order:
/// <list type="number">
/// <item><description><c>array[indices[k]][indices2[k]][indices3[k]]</c>, via <c>Collection.GetItemsFromDeepJagged</c>;</description></item>
/// <item><description><c>array[indices[k]][indices2[k]]</c> (index arrays may be one level jagged), via
/// <c>Collection.GetItemsFromJagged</c> / <c>Collection.GetJaggedItemsFromJagged</c>;</description></item>
/// <item><description><c>array[indices[k]]</c> (index array may be up to two levels jagged), via
/// <c>Collection.GetItems</c> / <c>GetJaggedItems</c> / <c>GetDeepJaggedItems</c>;</description></item>
/// <item><description>otherwise a single derived item variable, connected to the array element with <c>Clone.Copy</c>.</description></item>
/// </list>
/// The replacement is cached in <c>info.clone</c>, so later occurrences of the same indexer reuse it.
/// </summary>
/// <param name="iaie">The array indexer expression being converted.</param>
/// <returns>
/// The replacement expression, or the (possibly base-converted) indexer when no replacement applies
/// or an error was reported.
/// </returns>
/// <exception cref="NotImplementedException">
/// Thrown if the number of collected index loops has no matching GetItems factor (safety net).
/// </exception>
protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
{
    IndexAnalysisTransform.IndexInfo info;
    if (!analysis.indexInfoOf.TryGetValue(iaie, out info))
    {
        return base.ConvertArrayIndexer(iaie);
    }

    // Determine if this is a definition i.e. the variable is on the left hand side of an assignment
    // This must be done before base.ConvertArrayIndexer changes the expression!
    bool isDef = Recognizer.IsBeingMutated(context, iaie);
    if (info.clone != null)
    {
        // A replacement was already created for this indexer: reuse it.
        if (isDef)
        {
            // check that extra literal indices in the target are zero.
            // for example, if iae is x[i][0] = (...) then it is safe to add x_uses[i] = Rep(x[i])
            // if iae is x[i][1] = (...) then it is safe to add x_uses[i][1] = Rep(x[i][1])
            // but not x_uses[i] = Rep(x[i]) since this will be a duplicate.
            // NOTE: the result currently only feeds the disabled branch below.
            bool extraLiteralsAreZero = true;
            int parentIndex = context.InputStack.Count - 2;
            object parent = context.GetAncestor(parentIndex);
            while (parent is IArrayIndexerExpression)
            {
                IArrayIndexerExpression parent_iaie = (IArrayIndexerExpression)parent;
                foreach (IExpression index in parent_iaie.Indices)
                {
                    if (index is ILiteralExpression)
                    {
                        int value = (int)((ILiteralExpression)index).Value;
                        if (value != 0)
                        {
                            extraLiteralsAreZero = false;
                            break;
                        }
                    }
                }
                parentIndex--;
                parent = context.GetAncestor(parentIndex);
            }
            // Intentionally disabled.
            if (false && extraLiteralsAreZero)
            {
                // change:
                //   array[0] = f()
                // into:
                //   array_item0 = f()
                //   array[0] = Copy(array_item0)
                IExpression copy = Builder.StaticGenericMethod(new Func<PlaceHolder, PlaceHolder>(Clone.Copy<PlaceHolder>), new Type[] { iaie.GetExpressionType() }, info.clone);
                IStatement copySt = Builder.AssignStmt(iaie, copy);
                context.AddStatementAfterCurrent(copySt);
            }
        }
        return info.clone;
    }

    if (isDef)
    {
        // do not clone the lhs of an array create assignment.
        // FindAncestor returns null when there is no enclosing assignment (presumably the indexer is
        // mutated some other way, e.g. through an out argument), so guard before dereferencing.
        IAssignExpression assignExpr = context.FindAncestor<IAssignExpression>();
        if (assignExpr != null && assignExpr.Expression is IArrayCreateExpression)
        {
            return iaie;
        }
    }

    IVariableDeclaration originalBaseVar = Recognizer.GetVariableDeclaration(iaie);
    // If the variable is not stochastic, return
    if (!CodeRecognizer.IsStochastic(context, originalBaseVar))
    {
        return iaie;
    }

    IExpression newExpr = null;
    IVariableDeclaration baseVar = originalBaseVar;
    IVariableDeclaration newvd = null;
    IExpression rhsExpr = null;
    Containers containers = info.containers;
    Type tp = iaie.GetExpressionType();
    if (tp == null)
    {
        Error("Could not determine type of expression: " + iaie);
        return iaie;
    }
    // Statements to place before / after the chosen ancestor statement.
    var stmts = Builder.StmtCollection();
    var stmtsAfter = Builder.StmtCollection();

    // does the expression have the form array[indices[k]][indices2[k]][indices3[k]]?
    if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
    {
        IArrayIndexerExpression index3 = (IArrayIndexerExpression)iaie.Indices[0];
        IArrayIndexerExpression iaie2 = (IArrayIndexerExpression)iaie.Target;
        if (index3.Indices.Count == 1 && index3.Indices[0] is IVariableReferenceExpression &&
            iaie2.Target is IArrayIndexerExpression &&
            iaie2.Indices.Count == 1 && iaie2.Indices[0] is IArrayIndexerExpression)
        {
            IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie2.Indices[0];
            IArrayIndexerExpression iaie3 = (IArrayIndexerExpression)iaie2.Target;
            if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
                iaie3.Indices.Count == 1 && iaie3.Indices[0] is IArrayIndexerExpression)
            {
                IArrayIndexerExpression index = (IArrayIndexerExpression)iaie3.Indices[0];
                // Validate the shape of the first index BEFORE casting its inner index. Otherwise an
                // expression such as array[indices[0]][indices2[k]][indices3[k]] would throw here
                // instead of falling through to the simpler forms handled below.
                if (index.Indices.Count == 1 && index.Indices[0] is IVariableReferenceExpression)
                {
                    IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
                    IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                    if (index2.Indices[0].Equals(innerIndex) &&
                        index3.Indices[0].Equals(innerIndex) &&
                        innerLoop != null &&
                        AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target))
                    {
                        // expression has the form array[indices[k]][indices2[k]][indices3[k]]
                        if (isDef)
                        {
                            Error("fancy indexing not allowed on left hand side");
                            return iaie;
                        }
                        WarnIfLocal(index.Target, iaie3.Target, iaie);
                        WarnIfLocal(index2.Target, iaie3.Target, iaie);
                        WarnIfLocal(index3.Target, iaie3.Target, iaie);
                        containers = RemoveReferencesTo(containers, innerIndex);
                        IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
                        var indices = Recognizer.GetIndices(iaie);
                        // Build name of replacement variable from index values
                        StringBuilder sb = new StringBuilder("_item");
                        AppendIndexString(sb, iaie3);
                        AppendIndexString(sb, iaie2);
                        AppendIndexString(sb, iaie);
                        string name = ToString(iaie3.Target) + sb.ToString();
                        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                        newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices);
                        if (!context.InputAttributes.Has<DerivedVariable>(newvd))
                        {
                            context.InputAttributes.Set(newvd, new DerivedVariable());
                        }
                        IExpression getItems = Builder.StaticGenericMethod(
                            new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged),
                            new Type[] { tp },
                            iaie3.Target, index.Target, index2.Target, index3.Target);
                        context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
                        stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                        newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
                        rhsExpr = getItems;
                    }
                }
            }
        }
    }

    // does the expression have the form array[indices[k]][indices2[k]]?
    if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
    {
        IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie.Indices[0];
        IArrayIndexerExpression target = (IArrayIndexerExpression)iaie.Target;
        if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
            target.Indices.Count == 1 && target.Indices[0] is IArrayIndexerExpression)
        {
            IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index2.Indices[0];
            IArrayIndexerExpression index = (IArrayIndexerExpression)target.Indices[0];
            IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
            if (index.Indices.Count == 1 &&
                index.Indices[0].Equals(innerIndex) &&
                innerLoop != null &&
                AreLoopsDisjoint(innerLoop, target.Target, index.Target))
            {
                // expression has the form array[indices[k]][indices2[k]]
                if (isDef)
                {
                    Error("fancy indexing not allowed on left hand side");
                    return iaie;
                }
                var innerLoops = new List<IForStatement>();
                innerLoops.Add(innerLoop);
                var indexTarget = index.Target;
                var index2Target = index2.Target;
                // check if the index array is jagged, i.e. array[indices[k][j]]
                while (indexTarget is IArrayIndexerExpression && index2Target is IArrayIndexerExpression)
                {
                    IArrayIndexerExpression indexTargetExpr = (IArrayIndexerExpression)indexTarget;
                    IArrayIndexerExpression index2TargetExpr = (IArrayIndexerExpression)index2Target;
                    if (indexTargetExpr.Indices.Count == 1 && indexTargetExpr.Indices[0] is IVariableReferenceExpression &&
                        index2TargetExpr.Indices.Count == 1 && index2TargetExpr.Indices[0] is IVariableReferenceExpression)
                    {
                        IVariableReferenceExpression innerIndexTarget = (IVariableReferenceExpression)indexTargetExpr.Indices[0];
                        IVariableReferenceExpression innerIndex2Target = (IVariableReferenceExpression)index2TargetExpr.Indices[0];
                        IForStatement indexTargetLoop = Recognizer.GetLoopForVariable(context, innerIndexTarget);
                        if (indexTargetLoop != null &&
                            AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) &&
                            innerIndexTarget.Equals(innerIndex2Target))
                        {
                            innerLoops.Add(indexTargetLoop);
                            indexTarget = indexTargetExpr.Target;
                            index2Target = index2TargetExpr.Target;
                            // This limit must match the number of handled cases below. Deeper levels of
                            // the index arrays stay part of indexTarget/index2Target, as in the
                            // single-index branch further down.
                            if (innerLoops.Count == 2)
                            {
                                break;
                            }
                        }
                        else
                        {
                            break;
                        }
                    }
                    else
                    {
                        break;
                    }
                }
                WarnIfLocal(indexTarget, target.Target, iaie);
                WarnIfLocal(index2Target, target.Target, iaie);
                // Loops were collected innermost-first; the derived array is declared outermost-first.
                innerLoops.Reverse();
                var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                // Build name of replacement variable from index values
                StringBuilder sb = new StringBuilder("_item");
                AppendIndexString(sb, target);
                AppendIndexString(sb, iaie);
                string name = ToString(target.Target) + sb.ToString();
                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                var indices = Recognizer.GetIndices(iaie);
                newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                if (!context.InputAttributes.Has<DerivedVariable>(newvd))
                {
                    context.InputAttributes.Set(newvd, new DerivedVariable());
                }
                IExpression getItems;
                if (innerLoops.Count == 1)
                {
                    getItems = Builder.StaticGenericMethod(
                        new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromJagged),
                        new Type[] { tp },
                        target.Target, indexTarget, index2Target);
                }
                else if (innerLoops.Count == 2)
                {
                    getItems = Builder.StaticGenericMethod(
                        new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<IReadOnlyList<int>>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged),
                        new Type[] { tp },
                        target.Target, indexTarget, index2Target);
                }
                else
                {
                    throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                }
                context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
                stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                rhsExpr = getItems;
            }
            else if (HasAnyCommonLoops(index, index2))
            {
                Warning($"This model will consume excess memory due to the indexing expression {iaie} since {index} and {index2} have larger depth than the compiler can handle.");
            }
        }
    }

    if (newvd == null)
    {
        IArrayIndexerExpression originalExpr = iaie;
        if (UseGetItems)
        {
            iaie = (IArrayIndexerExpression)base.ConvertArrayIndexer(iaie);
        }
        // Intentionally disabled.
        if (!object.ReferenceEquals(iaie.Target, originalExpr.Target) && false)
        {
            // TODO: determine if this warning is useful or not
            string warningText = "This model may consume excess memory due to the jagged indexing expression {0}";
            Warning(string.Format(warningText, originalExpr));
        }

        // get the baseVar of the new expression.
        baseVar = Recognizer.GetVariableDeclaration(iaie);
        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);

        var indices = Recognizer.GetIndices(iaie);
        // Build name of replacement variable from index values
        StringBuilder sb = new StringBuilder("_item");
        AppendIndexString(sb, iaie);
        string name = ToString(iaie.Target) + sb.ToString();

        // does the expression have the form array[indices[k]]?
        if (UseGetItems && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
        {
            IArrayIndexerExpression index = (IArrayIndexerExpression)iaie.Indices[0];
            if (index.Indices.Count == 1 && index.Indices[0] is IVariableReferenceExpression)
            {
                // expression has the form array[indices[k]]
                IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
                IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                if (innerLoop != null &&
                    AreLoopsDisjoint(innerLoop, iaie.Target, index.Target))
                {
                    if (isDef)
                    {
                        Error("fancy indexing not allowed on left hand side");
                        return iaie;
                    }
                    var innerLoops = new List<IForStatement>();
                    innerLoops.Add(innerLoop);
                    var indexTarget = index.Target;
                    // check if the index array is jagged, i.e. array[indices[k][j]]
                    while (indexTarget is IArrayIndexerExpression)
                    {
                        IArrayIndexerExpression index2 = (IArrayIndexerExpression)indexTarget;
                        if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression)
                        {
                            IVariableReferenceExpression innerIndex2 = (IVariableReferenceExpression)index2.Indices[0];
                            IForStatement innerLoop2 = Recognizer.GetLoopForVariable(context, innerIndex2);
                            if (innerLoop2 != null &&
                                AreLoopsDisjoint(innerLoop2, iaie.Target, index2.Target))
                            {
                                innerLoops.Add(innerLoop2);
                                indexTarget = index2.Target;
                                // This limit must match the number of handled cases below.
                                if (innerLoops.Count == 3)
                                {
                                    break;
                                }
                            }
                            else
                            {
                                break;
                            }
                        }
                        else
                        {
                            break;
                        }
                    }
                    WarnIfLocal(indexTarget, iaie.Target, originalExpr);
                    // Loops were collected innermost-first; the derived array is declared outermost-first.
                    innerLoops.Reverse();
                    var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                    var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                    newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                    if (!context.InputAttributes.Has<DerivedVariable>(newvd))
                    {
                        context.InputAttributes.Set(newvd, new DerivedVariable());
                    }
                    IExpression getItems;
                    if (innerLoops.Count == 1)
                    {
                        getItems = Builder.StaticGenericMethod(
                            new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems),
                            new Type[] { tp },
                            iaie.Target, indexTarget);
                    }
                    else if (innerLoops.Count == 2)
                    {
                        getItems = Builder.StaticGenericMethod(
                            new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems),
                            new Type[] { tp },
                            iaie.Target, indexTarget);
                    }
                    else if (innerLoops.Count == 3)
                    {
                        getItems = Builder.StaticGenericMethod(
                            new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems),
                            new Type[] { tp },
                            iaie.Target, indexTarget);
                    }
                    else
                    {
                        throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                    }
                    context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
                    stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                    var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                    newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                    rhsExpr = getItems;
                }
            }
        }

        if (newvd == null)
        {
            // No GetItems form matched: fall back to a single derived item variable.
            if (UseGetItems && info.count < 2)
            {
                return iaie;
            }
            try
            {
                newvd = varInfo.DeriveIndexedVariable(stmts, context, name, indices, copyInitializer: isDef);
            }
            catch (Exception ex)
            {
                // Report the failure through the transform's error channel and leave the expression as is.
                Error(ex.Message, ex);
                return iaie;
            }
            context.OutputAttributes.Remove<DerivedVariable>(newvd);
            newExpr = Builder.VarRefExpr(newvd);
            rhsExpr = iaie;
            if (isDef)
            {
                // change:
                //   array[0] = f()
                // into:
                //   array_item0 = f()
                //   array[0] = Copy(array_item0)
                IExpression copy = Builder.StaticGenericMethod(new Func<PlaceHolder, PlaceHolder>(Clone.Copy), new Type[] { tp }, newExpr);
                IStatement copySt = Builder.AssignStmt(iaie, copy);
                stmtsAfter.Add(copySt);
                if (!context.InputAttributes.Has<DerivedVariable>(baseVar))
                {
                    context.InputAttributes.Set(baseVar, new DerivedVariable());
                }
            }
            else if (!info.IsAssignedTo)
            {
                // change:
                //   x = f(array[0])
                // into:
                //   array_item0 = Copy(array[0])
                //   x = f(array_item0)
                IExpression copy = Builder.StaticGenericMethod(new Func<PlaceHolder, PlaceHolder>(Clone.Copy), new Type[] { tp }, iaie);
                IStatement copySt = Builder.AssignStmt(Builder.VarRefExpr(newvd), copy);
                //if (attr != null) context.OutputAttributes.Set(copySt, attr);
                stmts.Add(copySt);
                context.InputAttributes.Set(newvd, new DerivedVariable());
            }
        }
    }

    // Reduce memory consumption by declaring the clone outside of unnecessary loops.
    // This way, the item is cloned outside the loop and then replicated, instead of replicating the entire array and cloning the item.
    containers = Containers.RemoveUnusedLoops(containers, context, rhsExpr);
    if (context.InputAttributes.Has<DoNotSendEvidence>(originalBaseVar))
    {
        containers = Containers.RemoveStochasticConditionals(containers, context);
    }
    if (true)
    {
        IStatement st = GetBindingSetContainer(FilterBindingSet(info.bindings,
            binding => Containers.ContainsExpression(containers.inputs, context, binding.GetExpression())));
        if (st != null)
        {
            containers.Add(st);
        }
    }

    // To put the declaration in the desired containers, we find an ancestor which includes as many of the containers as possible,
    // then wrap the declaration with the remaining containers.
    int ancIndex = containers.GetMatchingAncestorIndex(context);
    Containers missing = containers.GetContainersNotInContext(context, ancIndex);
    stmts = Containers.WrapWithContainers(stmts, missing.outputs);
    context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
    stmtsAfter = Containers.WrapWithContainers(stmtsAfter, missing.outputs);
    context.AddStatementsAfterAncestorIndex(ancIndex, stmtsAfter);
    context.InputAttributes.Set(newvd, containers);
    // Cache the replacement so that later visits of this indexer reuse it.
    info.clone = newExpr;
    return newExpr;
}
/// <summary>
/// Converts a method invocation and, when <c>UseJaggedSubarray</c> is set, rewrites calls of the form
/// <c>Subarray(arrayExpr, indices[i])</c> - where <c>arrayExpr</c> is stochastic and neither it nor
/// <c>indices</c> depends on the loop over <c>i</c> - into
/// <c>Copy(arrayExpr_indices[i])</c>, with <c>arrayExpr_indices = JaggedSubarray(arrayExpr, indices)</c>
/// declared outside of that loop.
/// </summary>
/// <param name="imie">The method invocation being converted.</param>
/// <returns>
/// The rewritten expression, or the result of the base conversion when the pattern does not apply.
/// </returns>
protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
{
    IExpression result = base.ConvertMethodInvoke(imie);
    if (!(result is IMethodInvokeExpression))
    {
        return result;
    }
    imie = (IMethodInvokeExpression)result;

    if (!UseJaggedSubarray ||
        !Recognizer.IsStaticGenericMethod(imie, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, IReadOnlyList<PlaceHolder>>(Collection.Subarray)))
    {
        return imie;
    }

    // check for the form Subarray(arrayExpr, indices[i]) where arrayExpr does not depend on i
    IExpression arrayExpr = imie.Arguments[0];
    IExpression arg1 = imie.Arguments[1];
    if (!(arg1 is IArrayIndexerExpression))
    {
        return imie;
    }
    IArrayIndexerExpression index = (IArrayIndexerExpression)arg1;
    if (index.Indices.Count != 1 || !(index.Indices[0] is IVariableReferenceExpression))
    {
        return imie;
    }

    // index has the form indices[i]
    List<IStatement> targetLoops = Containers.GetLoopsNeededForExpression(context, arrayExpr, -1, false);
    List<IStatement> indexLoops = Containers.GetLoopsNeededForExpression(context, index.Target, -1, false);
    Set<IStatement> parentLoops = new Set<IStatement>();
    parentLoops.AddRange(targetLoops);
    parentLoops.AddRange(indexLoops);
    IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
    IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
    if (innerLoop == null)
    {
        // i is not the index variable of an enclosing loop, so there is no loop to hoist the subarray
        // out of (and no loop size to build the jagged array from). Leave the call unchanged.
        return imie;
    }
    foreach (IStatement loop in parentLoops)
    {
        if (Containers.ContainersAreEqual(loop, innerLoop))
        {
            // arrayExpr depends on i
            return imie;
        }
    }

    IVariableDeclaration arrayVar = Recognizer.GetVariableDeclaration(arrayExpr);
    // If the variable is not stochastic, return
    if (arrayVar == null)
    {
        return imie;
    }
    VariableInformation arrayInfo = VariableInformation.GetVariableInformation(context, arrayVar);
    if (!arrayInfo.IsStochastic)
    {
        return imie;
    }

    // Size and index variable of the result of a single Subarray call, taken from the bracket of the
    // index array that follows the brackets already present in indices[i].
    object indexVar = Recognizer.GetDeclaration(index);
    VariableInformation indexInfo = VariableInformation.GetVariableInformation(context, indexVar);
    int depth = Recognizer.GetIndexingDepth(index);
    IExpression resultSize = indexInfo.sizes[depth][0];
    var indices = Recognizer.GetIndices(index);
    int replaceCount = 0;
    resultSize = indexInfo.ReplaceIndexVars(context, resultSize, indices, null, ref replaceCount);
    indexInfo.DefineIndexVarsUpToDepth(context, depth + 1);
    IVariableDeclaration resultIndex = indexInfo.indexVars[depth][0];
    Type arrayType = arrayExpr.GetExpressionType();
    Type elementType = Util.GetElementType(arrayType);

    // create a new variable arrayExpr_indices = JaggedSubarray(arrayExpr, indices)
    string name = ToString(arrayExpr) + "_" + ToString(index.Target);

    var stmts = Builder.StmtCollection();
    var arrayIndices = Recognizer.GetIndices(arrayExpr);
    var bracket = Builder.ExprCollection();
    bracket.Add(Builder.ArrayIndex(index, Builder.VarRefExpr(resultIndex)));
    arrayIndices.Add(bracket);
    IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
    // 'temp' is only derived to obtain variable information for the inner array; its statements
    // are discarded before deriving the actual jagged variable from it.
    IVariableDeclaration temp = arrayInfo.DeriveArrayVariable(stmts, context, name, resultSize, resultIndex, arrayIndices);
    VariableInformation tempInfo = VariableInformation.GetVariableInformation(context, temp);
    stmts.Clear();
    IVariableDeclaration newvd = tempInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex));
    if (!context.InputAttributes.Has<DerivedVariable>(newvd))
    {
        context.InputAttributes.Set(newvd, new DerivedVariable());
    }

    IExpression rhs = Builder.StaticGenericMethod(
        new Func<IReadOnlyList<PlaceHolder>, int[][], PlaceHolder[][]>(Collection.JaggedSubarray),
        new Type[] { elementType },
        arrayExpr, index.Target);
    context.InputAttributes.CopyObjectAttributesTo<Algorithm>(newvd, context.OutputAttributes, rhs);
    stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), rhs));

    // Reduce memory consumption by declaring the clone outside of unnecessary loops.
    // This way, the item is cloned outside the loop and then replicated, instead of replicating the entire array and cloning the item.
    Containers containers = new Containers(context);
    containers = RemoveReferencesTo(containers, innerIndex);
    containers = Containers.RemoveUnusedLoops(containers, context, rhs);
    if (context.InputAttributes.Has<DoNotSendEvidence>(arrayVar))
    {
        containers = Containers.RemoveStochasticConditionals(containers, context);
    }

    // To put the declaration in the desired containers, we find an ancestor which includes as many of the containers as possible,
    // then wrap the declaration with the remaining containers.
    int ancIndex = containers.GetMatchingAncestorIndex(context);
    Containers missing = containers.GetContainersNotInContext(context, ancIndex);
    stmts = Containers.WrapWithContainers(stmts, missing.outputs);
    context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
    context.InputAttributes.Set(newvd, containers);

    // convert into arrayExpr_indices[i]
    IExpression newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
    newExpr = Builder.StaticGenericMethod(new Func<PlaceHolder, PlaceHolder>(Clone.Copy<PlaceHolder>), new Type[] { newExpr.GetExpressionType() }, newExpr);
    return newExpr;
}
/// <summary>
/// Rewrites a reference to a stochastic variable so that it is explicitly replicated across each
/// enclosing loop of the reference loop context whose loop variable is not constant with respect
/// to the expression.
/// </summary>
/// <remarks>
/// A loop variable counts as constant when it belongs to the variable's own loop context, when it
/// directly indexes the expression, or when an enclosing condition statement binds it to something
/// that <c>IsConstantWrtLoops</c> accepts. For every remaining loop a new "_rep" array variable is
/// declared and assigned from <c>Clone.Replicate</c>, and the expression is replaced by that array
/// indexed by the loop variable (plus any indices split off by <c>AddUnreplicatedIndices</c>).
/// </remarks>
/// <param name="expr">The (possibly indexed) variable reference being converted.</param>
/// <returns>
/// <paramref name="expr"/> unchanged when no replication applies; otherwise the replicated expression.
/// </returns>
protected IExpression ConvertWithReplication(IExpression expr)
{
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
    // Check if this is an index local variable
    if (baseVar == null)
    {
        return expr;
    }
    // Check if the variable is stochastic
    if (!CodeRecognizer.IsStochastic(context, baseVar))
    {
        return expr;
    }

    // Get the loop context for this variable
    LoopContext lc = context.InputAttributes.Get<LoopContext>(baseVar);
    if (lc == null)
    {
        Error("Loop context not found for '" + baseVar.Name + "'.");
        return expr;
    }

    // Get the reference loop context for this expression
    RefLoopContext rlc = lc.GetReferenceLoopContext(context);
    // If the reference is in the same loop context as the declaration, do nothing.
    if (rlc.loops.Count == 0)
    {
        return expr;
    }

    // the set of loop variables that are constant wrt the expr
    Set<IVariableDeclaration> constantLoopVars = new Set<IVariableDeclaration>();
    constantLoopVars.AddRange(lc.loopVariables);

    // positions (in rlc.loopVariables) of loops that the index expressions depend on indirectly
    Set<int> embeddedLoopIndices = new Set<int>();
    foreach (IList<IExpression> bracket in Recognizer.GetIndices(expr))
    {
        foreach (IExpression index in bracket)
        {
            // For a binary index expression, only its left operand is examined.
            IExpression indExpr = index is IBinaryExpression ibe ? ibe.Left : index;
            IVariableDeclaration indVar = Recognizer.GetVariableDeclaration(indExpr);
            if (indVar == null)
            {
                // The index is not a single variable: inspect every variable it mentions.
                foreach (var mentionedVar in Recognizer.GetVariables(indExpr))
                {
                    if (!constantLoopVars.Contains(mentionedVar))
                    {
                        AddEmbeddedLoopIndices(mentionedVar, rlc, constantLoopVars, embeddedLoopIndices, expr);
                    }
                }
                continue;
            }
            if (constantLoopVars.Contains(indVar))
            {
                continue;
            }
            int loopIndex = rlc.loopVariables.IndexOf(indVar);
            if (loopIndex != -1)
            {
                // indVar is a loop variable
                constantLoopVars.Add(rlc.loopVariables[loopIndex]);
            }
            else
            {
                // indVar is not a loop variable
                AddEmbeddedLoopIndices(indVar, rlc, constantLoopVars, embeddedLoopIndices, expr);
            }
        }
    }

    // Find loop variables that must be constant due to condition statements.
    foreach (IStatement ancestor in context.FindAncestors<IStatement>())
    {
        if (!(ancestor is IConditionStatement ics))
        {
            continue;
        }
        ConditionBinding binding = new ConditionBinding(ics.Condition);
        IVariableDeclaration lhsVar = Recognizer.GetVariableDeclaration(binding.lhs);
        IVariableDeclaration rhsVar = Recognizer.GetVariableDeclaration(binding.rhs);
        if (rlc.loopVariables.IndexOf(lhsVar) >= 0 && IsConstantWrtLoops(rhsVar, constantLoopVars))
        {
            constantLoopVars.Add(lhsVar);
            continue;
        }
        if (rlc.loopVariables.IndexOf(rhsVar) >= 0 && IsConstantWrtLoops(lhsVar, constantLoopVars))
        {
            constantLoopVars.Add(rhsVar);
        }
    }

    // Determine if this expression is being defined (is on the LHS of an assignment)
    bool isDef = Recognizer.IsBeingMutated(context, expr);
    // NOTE(review): this initial value is never read; it is overwritten inside the loop before
    // its first use. Kept because the lookup cannot be proven side-effect free from here.
    Containers containers = context.InputAttributes.Get<Containers>(baseVar);
    IExpression originalExpr = expr;
    for (int currentLoop = 0; currentLoop < rlc.loopVariables.Count; currentLoop++)
    {
        IVariableDeclaration loopVar = rlc.loopVariables[currentLoop];
        if (constantLoopVars.Contains(loopVar))
        {
            continue;
        }
        IForStatement loop = rlc.loops[currentLoop];
        // must replicate across this loop.
        if (isDef)
        {
            Error("Cannot re-define a variable in a loop. Variables on the left hand side of an assignment must be indexed by all containing loops.");
            continue;
        }
        if (embeddedLoopIndices.Contains(currentLoop))
        {
            string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
            Warning(string.Format(warningText, originalExpr, loopVar.Name));
        }

        // split expr into a target and extra indices, where target will be replicated and extra indices will be added later
        var extraIndices = new List<IEnumerable<IExpression>>();
        AddUnreplicatedIndices(loop, expr, extraIndices, out IExpression exprToReplicate);

        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
        IExpression loopSize = Recognizer.LoopSizeExpression(loop);
        IList<IStatement> stmts = Builder.StmtCollection();
        List<IList<IExpression>> inds = Recognizer.GetIndices(exprToReplicate);
        IVariableDeclaration newIndexVar = loopVar;
        // if loopVar is already an indexVar of varInfo, create a new variable
        if (varInfo.HasIndexVar(loopVar))
        {
            newIndexVar = VariableInformation.GenerateLoopVar(context, "_a");
            context.InputAttributes.CopyObjectAttributesTo(loopVar, context.OutputAttributes, newIndexVar);
        }
        IVariableDeclaration repVar = varInfo.DeriveArrayVariable(
            stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rep"),
            loopSize, newIndexVar, inds, useArrays: true);
        if (!context.InputAttributes.Has<DerivedVariable>(repVar))
        {
            context.OutputAttributes.Set(repVar, new DerivedVariable());
        }
        if (context.InputAttributes.Has<ChannelInfo>(baseVar))
        {
            VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
            ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
            ci.decl = repVar;
            context.OutputAttributes.Set(repVar, ci);
        }

        // Create replicate factor
        Type returnType = Builder.ToType(repVar.VariableType);
        IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
            new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate),
            new Type[] { returnType.GetElementType() },
            exprToReplicate,
            loopSize);
        IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), repMethod);
        // Copy attributes across from variable to replication expression
        context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, repMethod);
        context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, repMethod);
        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, repMethod);
        stmts.Add(Builder.ExprStatement(assignExpression));

        // add any containers missing from context.
        containers = new Containers(context);
        // RemoveUnusedLoops will also remove conditionals involving those loop variables.
        // TODO: investigate whether removing these conditionals could cause a problem, e.g. when the condition is a conjunction of many terms.
        containers = Containers.RemoveUnusedLoops(containers, context, repMethod);
        if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar))
        {
            containers = Containers.RemoveStochasticConditionals(containers, context);
        }
        if (containers.Contains(loop))
        {
            Error("Internal: invalid containers for replicating " + baseVar);
            break;
        }
        int ancIndex = containers.GetMatchingAncestorIndex(context);
        Containers missing = containers.GetContainersNotInContext(context, ancIndex);
        stmts = Containers.WrapWithContainers(stmts, missing.inputs);
        context.OutputAttributes.Set(repVar, containers);

        // The loop context of repVar: the for-loops found from ancIndex plus the missing for-loops wrapped around it.
        List<IForStatement> loops = context.FindAncestors<IForStatement>(ancIndex);
        foreach (IStatement container in missing.inputs)
        {
            if (container is IForStatement ifs)
            {
                loops.Add(ifs);
            }
        }
        context.OutputAttributes.Set(repVar, new LoopContext(loops));
        // must convert the output since it may contain 'if' conditions
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);

        // Subsequent loops replicate the newly created variable.
        baseVar = repVar;
        expr = Builder.ArrayIndex(Builder.VarRefExpr(repVar), Builder.VarRefExpr(loopVar));
        expr = Builder.JaggedArrayIndex(expr, extraIndices);
    }

    return expr;
}

/// <summary>
/// For a variable appearing in an index expression, records the positions (within
/// <c>rlc.loopVariables</c>) of the non-constant loop variables in that variable's own loop context.
/// Reports an error for a loop variable that is not part of <paramref name="rlc"/>, and for a
/// variable that has no loop context attribute.
/// </summary>
private void AddEmbeddedLoopIndices(
    IVariableDeclaration indexVar,
    RefLoopContext rlc,
    Set<IVariableDeclaration> constantLoopVars,
    Set<int> embeddedLoopIndices,
    IExpression expr)
{
    LoopContext indexLoopContext = context.InputAttributes.Get<LoopContext>(indexVar);
    if (indexLoopContext == null)
    {
        // Report in the same way as a missing loop context on the base variable, rather than dereferencing null.
        Error("Loop context not found for '" + indexVar.Name + "'.");
        return;
    }
    foreach (var loopVar in indexLoopContext.loopVariables)
    {
        if (constantLoopVars.Contains(loopVar))
        {
            continue;
        }
        int position = rlc.loopVariables.IndexOf(loopVar);
        if (position != -1)
        {
            embeddedLoopIndices.Add(position);
        }
        else
        {
            Error($"Index {loopVar} is not in {rlc} for expression {expr}");
        }
    }
}
/// <summary>
/// Processes an assignment whose target is a stochastic variable. On the first assignment to the
/// variable it registers a definition <c>ChannelInfo</c> and declares the marginal and use channels;
/// it then adds, after the current statement, an assignment of the variable factor expression to
/// the use channel.
/// </summary>
/// <param name="target">Left-hand side of the assignment.</param>
/// <param name="rhs">Right-hand side of the assignment.</param>
/// <param name="shouldDelete">
/// Set to <c>true</c> when <paramref name="rhs"/> is a <c>Clone.Copy</c> of a variable that has the
/// same <c>MarginalPrototype</c> attribute as the target (and the definition ancestor index is below
/// <c>context.InputStack.Count - 2</c>). It is never set to <c>false</c> here.
/// </param>
protected void ProcessAssign(IExpression target, IExpression rhs, ref bool shouldDelete)
{
    IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);
    if (ivd == null)
    {
        return;
    }
    if (rhs is IArrayCreateExpression iace)
    {
        bool zeroLength = iace.Dimensions.All(
            dimExpr => dimExpr is ILiteralExpression literal && literal.Value.Equals(0));
        if (!zeroLength && iace.Initializer == null)
        {
            return; // variable will have assignments to elements
        }
    }

    bool firstTime = !variablesAssigned.Contains(ivd);
    variablesAssigned.Add(ivd);
    bool isInferred = context.InputAttributes.Has<IsInferred>(ivd);
    // Everything below applies to stochastic variables only.
    if (!CodeRecognizer.IsStochastic(context, ivd))
    {
        return;
    }

    VariableInformation vi = VariableInformation.GetVariableInformation(context, ivd);
    Containers defContainers = context.InputAttributes.Get<Containers>(ivd);
    int ancIndex = defContainers.GetMatchingAncestorIndex(context);
    Containers missing = defContainers.GetContainersNotInContext(context, ancIndex);

    // definition of a stochastic variable
    IExpression lhs = target;
    if (lhs is IVariableDeclarationExpression)
    {
        lhs = Builder.VarRefExpr(ivd);
    }
    IExpression defExpr = lhs;
    if (firstTime)
    {
        // Create a ChannelInfo attribute for use by later transforms, e.g. MessageTransform
        ChannelInfo defChannel = ChannelInfo.DefChannel(vi);
        defChannel.decl = ivd;
        context.OutputAttributes.Set(ivd, defChannel);
    }

    bool isDerived = context.InputAttributes.Has<DerivedVariable>(ivd);
    // An Algorithm attribute on the variable overrides the default algorithm.
    IAlgorithm algorithm = this.algorithmDefault;
    Algorithm algAttr = context.InputAttributes.Get<Algorithm>(ivd);
    if (algAttr != null)
    {
        algorithm = algAttr.algorithm;
    }
    if (algorithm is VariationalMessagePassing vmp && vmp.UseDerivMessages && isDerived && firstTime)
    {
        // Declare a "_deriv" companion variable with its own definition channel.
        vi.DefineAllIndexVars(context);
        IList<IStatement> stmts = Builder.StmtCollection();
        IVariableDeclaration derivDecl = vi.DeriveIndexedVariable(stmts, context, ivd.Name + "_deriv");
        context.OutputAttributes.Set(ivd, new DerivMessage(derivDecl));
        ChannelInfo derivChannel = ChannelInfo.DefChannel(vi);
        derivChannel.decl = derivDecl;
        context.OutputAttributes.Set(derivChannel.decl, derivChannel);
        context.OutputAttributes.Set(derivChannel.decl, new DescriptionAttribute("deriv of '" + ivd.Name + "'"));
        // Add the declarations
        stmts = Containers.WrapWithContainers(stmts, missing.outputs);
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
    }

    bool isPointEstimate = context.InputAttributes.Has<PointEstimate>(ivd);
    if (this.analysis.variablesExcludingVariableFactor.Contains(ivd))
    {
        this.variablesLackingVariableFactor.Add(ivd);
        // ivd will get a marginal channel in ConvertMethodInvoke
        useOfVariable[ivd] = ivd;
        return;
    }
    if (isDerived && !isInferred && !isPointEstimate)
    {
        return;
    }

    if (firstTime)
    {
        // create marginal and use channels
        vi.DefineAllIndexVars(context);
        IList<IStatement> stmts = Builder.StmtCollection();
        CreateMarginalChannel(ivd, vi, stmts);
        CreateUseChannel(ivd, vi, stmts);
        context.InputAttributes.Set(useOfVariable[ivd], defContainers);
        // Add the declarations
        stmts = Containers.WrapWithContainers(stmts, missing.outputs);
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
    }
    if (!useOfVariable.ContainsKey(ivd))
    {
        Error("cannot find use channel of " + ivd);
        return;
    }

    IExpression marginalExpr = Builder.ReplaceVariable(lhs, ivd, marginalOfVariable[ivd]);
    IExpression useExpr = Builder.ReplaceVariable(lhs, ivd, useOfVariable[ivd]);
    InitialiseTo it = context.InputAttributes.Get<InitialiseTo>(ivd);
    // The generic argument is taken from the definition expression before it may be redirected to a copy source below.
    Type[] genArgs = new Type[] { defExpr.GetExpressionType() };

    if (rhs is IMethodInvokeExpression imie
        && Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Clone.Copy))
        && ancIndex < context.InputStack.Count - 2)
    {
        IExpression arg = imie.Arguments[0];
        IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(arg);
        if (ivd2 != null
            && context.InputAttributes.Get<MarginalPrototype>(ivd) == context.InputAttributes.Get<MarginalPrototype>(ivd2))
        {
            // if a variable is a copy, use the original expression since it will give more precise dependencies.
            defExpr = arg;
            shouldDelete = true;
        }
    }

    // Add the variable factor
    IExpression variableFactorExpr;
    if (context.InputAttributes.Has<VariationalMessagePassing.GateExitRandomVariable>(ivd))
    {
        variableFactorExpr = Builder.StaticGenericMethod(
            new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Gate.ExitingVariable),
            genArgs, defExpr, marginalExpr);
    }
    else
    {
        Delegate d = algorithm.GetVariableFactor(isDerived, it != null);
        if (isPointEstimate)
        {
            d = new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint);
        }
        if (it == null)
        {
            variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, marginalExpr);
        }
        else
        {
            // An InitialiseTo attribute supplies an extra initial-messages argument.
            IExpression initExpr = Builder.ReplaceExpression(lhs, Builder.VarRefExpr(ivd), it.initialMessagesExpression);
            variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, initExpr, marginalExpr);
        }
    }
    context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(ivd, context.OutputAttributes, variableFactorExpr);
    context.InputAttributes.CopyObjectAttributesTo<Algorithm>(ivd, context.OutputAttributes, variableFactorExpr);
    context.OutputAttributes.Set(variableFactorExpr, new IsVariableFactor());

    var assignStmt = Builder.AssignStmt(useExpr, variableFactorExpr);
    context.AddStatementAfterCurrent(assignStmt);
}
/// <summary>
/// Rewrites a reference to a stochastic variable that is used inside repeat blocks which do not
/// enclose its declaration. For each such repeat block a new "_rpt" variable is declared and
/// assigned from <c>PowerPlate.Enter</c>, and the reference is redirected to that variable.
/// </summary>
/// <param name="expr">The variable reference being converted.</param>
/// <returns>
/// <paramref name="expr"/> unchanged when no rewrite applies; otherwise a reference to the last
/// "_rpt" variable created.
/// </returns>
protected IExpression ConvertWithReplication(IExpression expr)
{
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
    // Leave alone: non-variable references, non-stochastic variables, and cut variables.
    if (baseVar == null
        || !CodeRecognizer.IsStochastic(context, baseVar)
        || cutVariables.Contains(baseVar))
    {
        return expr;
    }

    // Repeat context recorded for the variable
    RepeatContext declContext = context.InputAttributes.Get<RepeatContext>(baseVar);
    if (declContext == null)
    {
        Error("Repeat context not found for '" + baseVar.Name + "'.");
        return expr;
    }

    // Repeat blocks of this reference relative to the variable's repeat context; none means nothing to do.
    var refContext = declContext.GetReferenceRepeatContext(context);
    if (refContext.repeats.Count == 0)
    {
        return expr;
    }

    // Is the expression being mutated (i.e. on the LHS of an assignment)?
    bool isDef = Recognizer.IsBeingMutated(context, expr);
    // NOTE(review): this initial value is overwritten inside the loop before it is ever read.
    Containers containers = context.InputAttributes.Get<Containers>(baseVar);

    for (int currentRepeat = 0; currentRepeat < refContext.repeatCounts.Count; currentRepeat++)
    {
        IExpression repeatCount = refContext.repeatCounts[currentRepeat];
        IRepeatStatement repeat = refContext.repeats[currentRepeat];

        // A definition cannot be carried across a repeat block.
        if (isDef)
        {
            Error("Cannot re-define a variable in a repeat block.");
            continue;
        }

        // Report gate conditions (Gate.Cases / Gate.CasesInt) that sit inside the repeat block.
        IMethodInvokeExpression enclosingCall = context.FindAncestor<IMethodInvokeExpression>();
        if (enclosingCall != null)
        {
            if (Recognizer.IsStaticMethod(enclosingCall, typeof(Gate), "Cases"))
            {
                Error("'if(" + expr + ")' should be placed outside 'repeat(" + repeatCount + ")'");
            }
            if (Recognizer.IsStaticMethod(enclosingCall, typeof(Gate), "CasesInt"))
            {
                Error("'case(" + expr + ")' or 'switch(" + expr + ")' should be placed outside 'repeat(" + repeatCount + ")'");
            }
        }

        // Declare the "_rpt" variable.
        VariableInformation baseInfo = VariableInformation.GetVariableInformation(context, baseVar);
        IList<IStatement> declarations = Builder.StmtCollection();
        List<IList<IExpression>> exprIndices = Recognizer.GetIndices(expr);
        string rptName = VariableInformation.GenerateName(context, baseInfo.Name + "_rpt");
        IVariableDeclaration rptVar = baseInfo.DeriveIndexedVariable(declarations, context, rptName, exprIndices);
        if (!context.InputAttributes.Has<DerivedVariable>(rptVar))
        {
            context.OutputAttributes.Set(rptVar, new DerivedVariable());
        }
        if (context.InputAttributes.Has<ChannelInfo>(baseVar))
        {
            VariableInformation rptInfo = VariableInformation.GetVariableInformation(context, rptVar);
            ChannelInfo useChannel = ChannelInfo.UseChannel(rptInfo);
            useChannel.decl = rptVar;
            context.OutputAttributes.Set(rptVar, useChannel);
        }

        // set the RepeatContext of the new variable to include all repeats up to this one (so that it doesn't get Entered again)
        List<IRepeatStatement> coveredRepeats = new List<IRepeatStatement>(declContext.repeats);
        for (int k = 0; k <= currentRepeat; k++)
        {
            coveredRepeats.Add(refContext.repeats[k]);
        }
        context.OutputAttributes.Remove<RepeatContext>(rptVar);
        context.OutputAttributes.Set(rptVar, new RepeatContext(coveredRepeats));

        // Build "rptVar = PowerPlate.Enter(expr, repeatCount)".
        Type rptType = Builder.ToType(rptVar.VariableType);
        IMethodInvokeExpression enterCall = Builder.StaticGenericMethod(
            new Func<PlaceHolder, double, PlaceHolder>(PowerPlate.Enter),
            new Type[] { rptType },
            expr,
            repeatCount);
        IExpression enterAssignment = Builder.AssignExpr(Builder.VarRefExpr(rptVar), enterCall);
        // Carry selected attributes from the variable over to the new factor expression.
        context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, enterCall);
        context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, enterCall);
        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, enterCall);
        declarations.Add(Builder.ExprStatement(enterAssignment));

        // Start from the current containers and strip the repeats nested inside this one.
        containers = new Containers(context);
        for (int inner = currentRepeat + 1; inner < refContext.repeatCounts.Count; inner++)
        {
            containers = containers.RemoveOneRepeat(refContext.repeats[inner]);
        }
        // The new variable's Containers attribute still includes the current repeat...
        context.OutputAttributes.Set(rptVar, containers);
        // ...whereas the containers used to place the declaration do not.
        containers = containers.RemoveOneRepeat(repeat);
        containers = Containers.RemoveUnusedLoops(containers, context, enterCall);
        if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar))
        {
            containers = Containers.RemoveStochasticConditionals(containers, context);
        }

        // Insert before the matching ancestor, wrapped in whichever containers the context lacks.
        int ancIndex = containers.GetMatchingAncestorIndex(context);
        Containers missing = containers.GetContainersNotInContext(context, ancIndex);
        declarations = Containers.WrapWithContainers(declarations, missing.inputs);
        // must convert the output since it may contain 'if' conditions
        context.AddStatementsBeforeAncestorIndex(ancIndex, declarations, true);

        // The next repeat block (if any) operates on the variable just created.
        baseVar = rptVar;
        expr = Builder.VarRefExpr(rptVar);
    }

    return expr;
}
/// <summary>
/// Replaces an indexed variable reference by a reference to a per-depth clone of the variable,
/// creating and declaring the clone ("&lt;name&gt;_depth&lt;N&gt;") the first time that depth is seen.
/// </summary>
/// <param name="expr">The expression being converted.</param>
/// <returns>
/// <paramref name="expr"/> unchanged when no clone applies; otherwise the clone indexed by the
/// same indices as <paramref name="expr"/>.
/// </returns>
private IExpression GetClone(IExpression expr)
{
    // Nothing to do while the expression is itself being indexed.
    if (Recognizer.IsBeingIndexed(context))
    {
        return expr;
    }

    // If not an indexed variable reference, skip it (e.g. an indexed argument reference)
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
    if (baseVar == null)
    {
        return expr;
    }

    // Only variables with recorded depth information, more than one use, and more than one depth entry qualify.
    if (!analysis.depthInfos.TryGetValue(baseVar, out DepthInfo depthInfo))
    {
        return expr;
    }
    if (depthInfo.useCount <= 1 || depthInfo.indexInfoOfDepth.Count == 1)
    {
        return expr;
    }

    // Variables marked DoNotSendEvidence are cloned only when cloneEvidenceVars is set.
    bool isEvidenceVar = context.InputAttributes.Has<DoNotSendEvidence>(baseVar);
    if (isEvidenceVar && !cloneEvidenceVars)
    {
        return expr;
    }

    int depth = Recognizer.GetIndexingDepth(expr);
    IndexInfo info = depthInfo.indexInfoOfDepth[depth];
    if (info.clone == null)
    {
        // References at the definition depth keep using the variable itself.
        if (depth == depthInfo.definitionDepth)
        {
            return expr;
        }

        VariableInformation baseInfo = VariableInformation.GetVariableInformation(context, baseVar);
        string cloneName = baseVar.Name + "_depth" + depth;
        IList<IStatement> cloneStmts = Builder.StmtCollection();

        // declare the clone
        baseInfo.DefineAllIndexVars(context);
        IVariableDeclaration cloneVar = baseInfo.DeriveIndexedVariable(cloneStmts, context, cloneName);
        VariableInformation cloneInfo = VariableInformation.GetVariableInformation(context, cloneVar);
        cloneInfo.LiteralIndexingDepth = info.literalIndexingDepth;
        if (!context.InputAttributes.Has<DerivedVariable>(cloneVar))
        {
            context.InputAttributes.Set(cloneVar, new DerivedVariable());
        }
        if (context.InputAttributes.Has<ChannelInfo>(baseVar))
        {
            ChannelInfo useChannel = ChannelInfo.UseChannel(baseInfo);
            useChannel.decl = cloneVar;
            context.OutputAttributes.Set(cloneVar, useChannel);
        }

        // define the clone
        // if depth < definitionDepth, index by the definitionDepth
        // else index by depth
        // e.g.
        //   x[i] = definition
        //   x_depth0[i] = Copy(x[i])
        //   x_depth2[i][j] = Copy(x[i][j])
        int indexingDepth = System.Math.Max(depth, depthInfo.definitionDepth);
        IExpression copyTarget = Builder.VarRefExpr(cloneVar);
        // TODO: clone the next lower depth, not always the baseVar
        IExpression copySource = Builder.VarRefExpr(baseVar);
        AddCopyStatements(cloneStmts, cloneInfo, indexingDepth, copyTarget, copySource);

        // Reduce memory consumption by declaring the clone outside of unnecessary loops.
        // This way, the item is cloned outside the loop and then replicated, instead of replicating the entire array and cloning the item.
        Containers cloneContainers = info.containers;
        cloneContainers = Containers.RemoveUnusedLoops(cloneContainers, context, Builder.VarRefExpr(baseVar));
        if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar))
        {
            cloneContainers = Containers.RemoveStochasticConditionals(cloneContainers, context);
        }

        // To put the declaration in the desired containers, we find an ancestor which includes as many of the containers as possible,
        // then wrap the declaration with the remaining containers.
        int ancIndex = cloneContainers.GetMatchingAncestorIndex(context);
        Containers missing = cloneContainers.GetContainersNotInContext(context, ancIndex);
        cloneStmts = Containers.WrapWithContainers(cloneStmts, missing.outputs);
        context.AddStatementsBeforeAncestorIndex(ancIndex, cloneStmts);
        context.InputAttributes.Set(cloneVar, cloneContainers);

        // Cache the clone so later references at this depth reuse it.
        info.clone = Builder.VarRefExpr(cloneVar);
    }

    // Index the clone exactly as the original expression was indexed.
    List<IList<IExpression>> indices = Recognizer.GetIndices(expr);
    return Builder.JaggedArrayIndex(info.clone, indices);
}