/// <summary>
/// Processes a variable declaration expression.
/// </summary>
/// <remarks>
/// The containers in force at the declaration are (re)recorded on the variable.
/// Deterministic variables are passed to ProcessConstant and given a default description if none exists.
/// Stochastic variables have their marginal prototype set from the MarginalPrototype attribute.
/// Argument errors raised while setting the prototype are reported through Error.
/// </remarks>
/// <param name="ivde">The declaration expression being converted.</param>
/// <returns>The original declaration expression, unchanged.</returns>
protected override IExpression ConvertVariableDeclExpr(IVariableDeclarationExpression ivde)
{
    IVariableDeclaration declared = ivde.Variable;
    // Replace any previously stored containers with those of the current context.
    context.InputAttributes.Remove<Containers>(declared);
    context.InputAttributes.Set(declared, new Containers(context));
    bool isStochastic = CodeRecognizer.IsStochastic(context, declared);
    if (isStochastic)
    {
        VariableInformation info = VariableInformation.GetVariableInformation(context, declared);
        // Ensure the marginal prototype is set.
        MarginalPrototype prototypeAttr = Context.InputAttributes.Get<MarginalPrototype>(declared);
        try
        {
            info.SetMarginalPrototypeFromAttribute(prototypeAttr);
        }
        catch (ArgumentException ex)
        {
            Error(ex.Message);
        }
    }
    else
    {
        ProcessConstant(declared);
        bool hasDescription = context.InputAttributes.Has<DescriptionAttribute>(declared);
        if (!hasDescription)
        {
            context.OutputAttributes.Set(declared, new DescriptionAttribute("The constant '" + declared.Name + "'"));
        }
    }
    return ivde;
}
/// <summary>
/// Reorders containers so that conditionals on stochastic conditions come last.
/// </summary>
/// <remarks>
/// The reordering is stable: relative order within both groups is preserved.
/// Each input keeps its paired output at the same position.
/// </remarks>
/// <param name="containers">The containers to reorder (not modified).</param>
/// <param name="context">Transform context used to test stochasticity.</param>
/// <returns>A new Containers object with the reordered input/output pairs.</returns>
internal static Containers SortStochasticConditionals(Containers containers, BasicTransformContext context)
{
    Containers sorted = new Containers();
    Containers deferred = new Containers();
    for (int index = 0; index < containers.inputs.Count; index++)
    {
        IStatement input = containers.inputs[index];
        bool isStochasticConditional = input is IConditionStatement cond && CodeRecognizer.IsStochastic(context, cond.Condition);
        Containers destination = isStochasticConditional ? deferred : sorted;
        destination.inputs.Add(input);
        destination.outputs.Add(containers.outputs[index]);
    }
    // Append the stochastic conditionals after everything else.
    for (int index = 0; index < deferred.inputs.Count; index++)
    {
        sorted.inputs.Add(deferred.inputs[index]);
        sorted.outputs.Add(deferred.outputs[index]);
    }
    return sorted;
}
/// <summary>
/// Selects the condition bindings that remain usable for a gate block.
/// </summary>
/// <remarks>
/// A binding is kept when its left side is not stochastic and neither side refers to variables local to
/// <paramref name="gateBlock"/> (as determined by ContainsLocalVars).
/// </remarks>
/// <param name="gateBlock">The gate block whose local variables must not appear in a kept binding.</param>
/// <param name="conditionContext">Candidate bindings, in order.</param>
/// <returns>A new list holding the kept bindings in their original order.</returns>
private List<ConditionBinding> FilterConditionContext(GateBlock gateBlock, List<ConditionBinding> conditionContext)
{
    List<ConditionBinding> kept = new List<ConditionBinding>();
    foreach (ConditionBinding binding in conditionContext)
    {
        bool usable = !CodeRecognizer.IsStochastic(context, binding.lhs)
            && !ContainsLocalVars(gateBlock, binding.lhs)
            && !ContainsLocalVars(gateBlock, binding.rhs);
        if (usable)
        {
            kept.Add(binding);
        }
    }
    return kept;
}
/// <summary>
/// Re-wraps an expression statement in the enclosing condition when it adds a factor or constraint.
/// </summary>
/// <remarks>
/// Statements are wrapped when they are:
/// - method calls that are not Infer calls;
/// - assignments from a method call, unless the call copies one DoNotSendEvidence variable into another;
/// - declarations of stochastic variables without DoNotSendEvidence.
/// Everything else is returned unchanged. When there is no parent condition, the statement is returned unchanged.
/// </remarks>
/// <param name="ies">The statement being converted.</param>
/// <returns>Either <paramref name="ies"/> itself or a new condition statement containing it.</returns>
protected override IStatement ConvertExpressionStatement(IExpressionStatement ies)
{
    if (parent == null)
    {
        return ies;
    }
    IExpression expr = ies.Expression;
    bool keepIfStatement = false;
    if (expr is IMethodInvokeExpression)
    {
        keepIfStatement = !CodeRecognizer.IsInfer(expr);
    }
    else if (expr is IAssignExpression iae)
    {
        if (iae.Expression is IMethodInvokeExpression imie)
        {
            // Statements that copy evidence variables should not send evidence messages.
            bool copiesEvidence = false;
            if (imie.Arguments.Count > 0)
            {
                IVariableDeclaration targetVar = Recognizer.GetVariableDeclaration(iae.Target);
                IVariableDeclaration sourceVar = Recognizer.GetVariableDeclaration(imie.Arguments[0]);
                copiesEvidence = targetVar != null && context.InputAttributes.Has<DoNotSendEvidence>(targetVar)
                    && sourceVar != null && context.InputAttributes.Has<DoNotSendEvidence>(sourceVar);
            }
            keepIfStatement = !copiesEvidence;
        }
        else
        {
            // Plain assignment: decide based on the target (e.g. a declaration).
            expr = iae.Target;
        }
    }
    if (expr is IVariableDeclarationExpression ivde)
    {
        IVariableDeclaration declared = ivde.Variable;
        keepIfStatement = CodeRecognizer.IsStochastic(context, declared) && !context.InputAttributes.Has<DoNotSendEvidence>(declared);
    }
    if (!keepIfStatement)
    {
        return ies;
    }
    IConditionStatement wrapper = Builder.CondStmt(parent.Condition, Builder.BlockStmt());
    wrapper.Then.Statements.Add(ies);
    return wrapper;
}
/// <summary>
/// Determines whether <paramref name="ivd"/> is constant once every loop variable in
/// <paramref name="constantLoopVars"/> is held fixed.
/// </summary>
/// <param name="ivd">Variable to test; null is treated as constant.</param>
/// <param name="constantLoopVars">Loop variables whose values are held fixed.</param>
/// <returns>
/// False for stochastic variables.
/// True for null, for members of <paramref name="constantLoopVars"/>, and for variables whose
/// LoopContext loop variables are all in <paramref name="constantLoopVars"/>.
/// </returns>
private bool IsConstantWrtLoops(IVariableDeclaration ivd, Set<IVariableDeclaration> constantLoopVars)
{
    if (ivd == null)
    {
        return true;
    }
    if (CodeRecognizer.IsStochastic(context, ivd))
    {
        return false;
    }
    // NOTE(review): the LoopContext attribute is dereferenced without a null check.
    return constantLoopVars.Contains(ivd)
        || constantLoopVars.ContainsAll(context.InputAttributes.Get<LoopContext>(ivd).loopVariables);
}
/// <summary>
/// Decides whether an argument can be treated as a point mass.
/// </summary>
/// <remarks>
/// Deterministic arguments and out-arguments always count as point masses.
/// Any other stochastic argument counts only if its variable carries ForwardPointMass.
/// </remarks>
/// <param name="arg">The argument expression.</param>
/// <returns>True if the argument is treated as a point mass.</returns>
private bool ArgumentIsPointMass(IExpression arg)
{
    bool stochastic = CodeRecognizer.IsStochastic(context, arg);
    if (!stochastic || arg is IAddressOutExpression)
    {
        return true;
    }
    IVariableDeclaration argVar = Recognizer.GetVariableDeclaration(arg);
    return argVar != null && context.InputAttributes.Has<ForwardPointMass>(argVar);
}
/// <summary>
/// Reports an error for each stochastic variable appearing in any of the given index expressions.
/// </summary>
/// <param name="exprs">Index expressions to check.</param>
private void CheckIndicesAreNotStochastic(IList<IExpression> exprs)
{
    foreach (IExpression indexExpr in exprs)
    {
        foreach (var variable in Recognizer.GetVariables(indexExpr))
        {
            if (!CodeRecognizer.IsStochastic(context, variable))
            {
                continue;
            }
            Error("Indexing by a random variable '" + variable.Name + "'. You must wrap this statement with Variable.Switch(" + indexExpr + ")");
        }
    }
}
/// <summary>
/// Builds a binding for every deterministic condition statement among the containers.
/// </summary>
/// <param name="context">Transform context used to test stochasticity.</param>
/// <param name="containers">Container statements, in order.</param>
/// <returns>
/// One ConditionBinding per non-stochastic IConditionStatement, in container order.
/// Other statements are ignored.
/// </returns>
internal static List<ConditionBinding> GetBindings(BasicTransformContext context, IEnumerable<IStatement> containers)
{
    var result = new List<ConditionBinding>();
    foreach (IStatement container in containers)
    {
        if (container is IConditionStatement conditional && !CodeRecognizer.IsStochastic(context, conditional.Condition))
        {
            result.Add(new ConditionBinding(conditional.Condition));
        }
    }
    return result;
}
/// <summary>
/// Records the definition and use of the stochastic variable referenced by <paramref name="expr"/>.
/// </summary>
/// <remarks>
/// Expressions without a variable, or whose variable is not stochastic, are ignored.
/// A mutation that is not an allocation is registered as a definition. ProcessUse is then called
/// with the current conditionContext.
/// </remarks>
/// <param name="expr">The expression to process.</param>
protected void ProcessExpression(IExpression expr)
{
    bool isDef = Recognizer.IsBeingMutated(context, expr);
    IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(expr);
    bool isStochasticVar = ivd != null && CodeRecognizer.IsStochastic(context, ivd);
    if (!isStochasticVar)
    {
        return;
    }
    bool definesValue = isDef && !Recognizer.IsBeingAllocated(context, expr);
    if (definesValue)
    {
        RegisterDefinition(ivd);
    }
    ProcessUse(expr, isDef, conditionContext);
}
/// <summary>
/// Analyses the branches of a deterministic condition with an augmented conditionContext.
/// </summary>
/// <remarks>
/// Stochastic conditions are delegated to the base implementation.
/// For deterministic ones:
/// - the Then block is converted with the condition's binding appended;
/// - the Else block (if any) is converted with the flipped binding appended.
/// conditionContext is restored to its original length afterwards.
/// </remarks>
/// <param name="ics">The condition statement.</param>
/// <returns>The statement returned by the base class for stochastic conditions, otherwise <paramref name="ics"/>.</returns>
protected override IStatement ConvertCondition(IConditionStatement ics)
{
    if (CodeRecognizer.IsStochastic(context, ics.Condition))
    {
        return base.ConvertCondition(ics);
    }
    context.SetPrimaryOutput(ics);
    ConvertExpression(ics.Condition);
    ConditionBinding thenBinding = new ConditionBinding(ics.Condition);
    int savedLength = conditionContext.Count;
    conditionContext.Add(thenBinding);
    ConvertBlock(ics.Then);
    if (ics.Else != null)
    {
        // Discard anything added while converting Then, then switch to the negated binding.
        conditionContext.RemoveRange(savedLength, conditionContext.Count - savedLength);
        ConditionBinding elseBinding = thenBinding.FlipCondition();
        conditionContext.Add(elseBinding);
        ConvertBlock(ics.Else);
    }
    conditionContext.RemoveRange(savedLength, conditionContext.Count - savedLength);
    return ics;
}
/// <summary>
/// Adds a Replicate statement to the output code, if needed.
/// </summary>
/// <remarks>
/// Nothing is emitted when:
/// - the target has no variable;
/// - the variable is deterministic (it is passed to ProcessConstant instead);
/// - SwapIndices is set and the evidence status differs from replicateEvidenceVars;
/// - no usage info exists for the variable;
/// - the variable is used at most once.
/// Otherwise a uses array is declared on first encounter. A statement of the form
/// <c>x_uses = Replicate(x, useCount)</c> (or Gate.ReplicateExiting for gate-exit variables) is then added,
/// provided any extra literal indices in the target are zero.
/// </remarks>
/// <param name="target">LHS of assignment</param>
/// <param name="rhs">RHS of assignment</param>
/// <param name="shouldDelete">True if the original assignment statement should be deleted, i.e. it can be optimized away.</param>
protected void AddReplicateStatement(IExpression target, IExpression rhs, ref bool shouldDelete)
{
    IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);
    if (ivd == null)
    {
        return;
    }
    VariableInformation vi = VariableInformation.GetVariableInformation(context, ivd);
    if (!vi.IsStochastic)
    {
        // Deterministic variables are never replicated.
        ProcessConstant(ivd);
        return;
    }
    bool isEvidenceVar = context.InputAttributes.Has<DoNotSendEvidence>(ivd);
    // With SwapIndices, only variables whose evidence status matches replicateEvidenceVars are handled here.
    if (SwapIndices && replicateEvidenceVars != isEvidenceVar)
    {
        return;
    }
    // iae defines a stochastic variable
    ChannelAnalysisTransform.UsageInfo usageInfo;
    if (!analysis.usageInfo.TryGetValue(ivd, out usageInfo))
    {
        return;
    }
    int useCount = usageInfo.NumberOfUses;
    // A variable with at most one use needs no Replicate.
    if (useCount <= 1)
    {
        return;
    }
    VariableToChannelInformation vtci;
    // firstTime: no entry for ivd in usesOfVariable yet, so the uses array must be declared below.
    // Presumably DeclareUsesArray records the new entry; that is not visible from here.
    bool firstTime = !usesOfVariable.TryGetValue(ivd, out vtci);
    int targetDepth = Recognizer.GetIndexingDepth(target);
    int minDepth = usageInfo.indexingDepths[0];
    // Number of leading indexing brackets the uses array is replicated over; forced to 0 unless SwapIndices.
    int usageDepth = minDepth;
    if (!SwapIndices)
    {
        usageDepth = 0;
    }
    Containers defContainers = context.InputAttributes.Get<Containers>(ivd);
    int ancIndex = defContainers.GetMatchingAncestorIndex(context);
    Containers missing = defContainers.GetContainersNotInContext(context, ancIndex);
    if (firstTime)
    {
        // declaration of uses array
        IList<IStatement> stmts = Builder.StmtCollection();
        vtci = DeclareUsesArray(stmts, ivd, vi, useCount, usageDepth);
        stmts = Containers.WrapWithContainers(stmts, missing.outputs);
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
    }
    // check that extra literal indices in the target are zero.
    // for example, if iae is x[i][0] = (...) then it is safe to add x_uses[i] = Rep(x[i])
    // if iae is x[i][1] = (...) then it is safe to add x_uses[i][1] = Rep(x[i][1])
    // but not x_uses[i] = Rep(x[i]) since this will be a duplicate.
    bool extraLiteralsAreZero = CheckExtraLiteralsAreZero(target, targetDepth, usageDepth);
    if (extraLiteralsAreZero)
    {
        // definition of uses array
        IExpression defExpr = Builder.VarRefExpr(ivd);
        IExpression usesExpr = Builder.VarRefExpr(vtci.usesDecl);
        List<IStatement> loops = new List<IStatement>();
        if (usageDepth == targetDepth)
        {
            // Replicate exactly the element being assigned.
            defExpr = target;
            if (defExpr is IVariableDeclarationExpression)
            {
                defExpr = Builder.VarRefExpr(ivd);
            }
            usesExpr = Builder.ReplaceVariable(defExpr, ivd, vtci.usesDecl);
        }
        else
        {
            // loops over the last indexing bracket
            for (int d = 0; d < usageDepth; d++)
            {
                List<IExpression> indices = new List<IExpression>();
                for (int i = 0; i < vi.sizes[d].Length; i++)
                {
                    IVariableDeclaration v = vi.indexVars[d][i];
                    IStatement loop = Builder.ForStmt(v, vi.sizes[d][i]);
                    loops.Add(loop);
                    indices.Add(Builder.VarRefExpr(v));
                }
                defExpr = Builder.ArrayIndex(defExpr, indices);
                usesExpr = Builder.ArrayIndex(usesExpr, indices);
            }
        }
        if (rhs != null && rhs is IMethodInvokeExpression)
        {
            IMethodInvokeExpression imie = (IMethodInvokeExpression)rhs;
            // NOTE(review): copyPropagation is hard-coded to false, so the branch below never executes.
            bool copyPropagation = false;
            if (Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Factor.Copy)) && copyPropagation)
            {
                // if a variable is a copy, use the original expression since it will give more precise dependencies.
                defExpr = imie.Arguments[0];
                shouldDelete = true;
            }
        }
        // Add the statement:
        //   x_uses = Replicate(x, useCount)
        var genArgs = new Type[] { ivd.VariableType.DotNetType };
        IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
            new Func<PlaceHolder, int, PlaceHolder[]>(Factor.Replicate),
            genArgs, defExpr, Builder.LiteralExpr(useCount));
        bool isGateExitRandom = context.InputAttributes.Has<Algorithms.VariationalMessagePassing.GateExitRandomVariable>(ivd);
        if (isGateExitRandom)
        {
            // Variables marked GateExitRandomVariable use Gate.ReplicateExiting instead of Factor.Replicate.
            repMethod = Builder.StaticGenericMethod(
                new Func<PlaceHolder, int, PlaceHolder[]>(Gate.ReplicateExiting),
                genArgs, defExpr, Builder.LiteralExpr(useCount));
        }
        // Propagate per-variable attributes from the variable onto the Replicate call.
        context.InputAttributes.CopyObjectAttributesTo<Algorithm>(ivd, context.OutputAttributes, repMethod);
        if (context.InputAttributes.Has<DivideMessages>(ivd))
        {
            context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(ivd, context.OutputAttributes, repMethod);
        }
        else if (useCount == 2)
        {
            // division has no benefit for 2 uses, and degrades the schedule
            context.OutputAttributes.Set(repMethod, new DivideMessages(false));
        }
        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(ivd, context.OutputAttributes, repMethod);
        IStatement repSt = Builder.AssignStmt(usesExpr, repMethod);
        if (usageDepth == targetDepth)
        {
            if (isEvidenceVar)
            {
                // place the Replicate after the current statement, but outside of any evidence conditionals
                // Walks ancestors downward from ancIndex - 2 and stops at the first conditional with a stochastic
                // condition; the Replicate is inserted after that ancestor (or after the current statement if none).
                ancIndex = context.FindAncestorIndex<IStatement>();
                for (int i = ancIndex - 2; i > 0; i--)
                {
                    object ancestor = context.GetAncestor(i);
                    if (ancestor is IConditionStatement)
                    {
                        IConditionStatement ics = (IConditionStatement)ancestor;
                        if (CodeRecognizer.IsStochastic(context, ics.Condition))
                        {
                            ancIndex = i;
                            break;
                        }
                    }
                }
                context.AddStatementAfterAncestorIndex(ancIndex, repSt);
            }
            else
            {
                context.AddStatementAfterCurrent(repSt);
            }
        }
        else if (firstTime)
        {
            // place Replicate after assignment but outside of definition loops (can't have any uses there)
            repSt = Containers.WrapWithContainers(repSt, loops);
            repSt = Containers.WrapWithContainers(repSt, missing.outputs);
            context.AddStatementAfterAncestorIndex(ancIndex, repSt);
        }
    }
}
/// <summary>
/// This method does all the work of converting indexing expressions on stochastic arrays.
/// </summary>
/// <remarks>
/// A reference such as <c>array[i]</c> is replaced by a reference to a derived variable (e.g. <c>array_item...</c>).
/// The statements that declare and define that variable are added around the appropriate ancestor.
/// Patterns of the form <c>array[indices[k]]</c>, <c>array[indices[k]][indices2[k]]</c> and
/// <c>array[indices[k]][indices2[k]][indices3[k]]</c> are handled with Collection.GetItems* factors
/// when UseGetItems is set.
/// The result is cached in <c>info.clone</c> and reused on later visits.
/// </remarks>
/// <param name="iaie">The indexing expression being converted.</param>
/// <returns>The replacement expression, or <paramref name="iaie"/> (possibly base-converted) when no clone is made.</returns>
protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
{
    IndexAnalysisTransform.IndexInfo info;
    if (!analysis.indexInfoOf.TryGetValue(iaie, out info))
    {
        return base.ConvertArrayIndexer(iaie);
    }
    // Determine if this is a definition i.e. the variable is on the left hand side of an assignment
    // This must be done before base.ConvertArrayIndexer changes the expression!
    bool isDef = Recognizer.IsBeingMutated(context, iaie);
    if (info.clone != null)
    {
        if (isDef)
        {
            // check that extra literal indices in the target are zero.
            // for example, if iae is x[i][0] = (...) then it is safe to add x_uses[i] = Rep(x[i])
            // if iae is x[i][1] = (...) then it is safe to add x_uses[i][1] = Rep(x[i][1])
            // but not x_uses[i] = Rep(x[i]) since this will be a duplicate.
            bool extraLiteralsAreZero = true;
            int parentIndex = context.InputStack.Count - 2;
            object parent = context.GetAncestor(parentIndex);
            while (parent is IArrayIndexerExpression)
            {
                IArrayIndexerExpression parent_iaie = (IArrayIndexerExpression)parent;
                foreach (IExpression index in parent_iaie.Indices)
                {
                    if (index is ILiteralExpression)
                    {
                        int value = (int)((ILiteralExpression)index).Value;
                        if (value != 0)
                        {
                            extraLiteralsAreZero = false;
                            break;
                        }
                    }
                }
                parentIndex--;
                parent = context.GetAncestor(parentIndex);
            }
            // NOTE(review): this branch is disabled by the literal 'false'.
            if (false && extraLiteralsAreZero)
            {
                // change:
                //   array[0] = f()
                // into:
                //   array_item0 = f()
                //   array[0] = Copy(array_item0)
                IExpression copy = Builder.StaticGenericMethod(new Func<PlaceHolder, PlaceHolder>(Clone.Copy<PlaceHolder>), new Type[] { iaie.GetExpressionType() }, info.clone);
                IStatement copySt = Builder.AssignStmt(iaie, copy);
                context.AddStatementAfterCurrent(copySt);
            }
        }
        return info.clone;
    }
    if (isDef)
    {
        // do not clone the lhs of an array create assignment.
        // The ancestor lookup may find nothing (e.g. for a mutation that is not an assignment), so guard against null.
        IAssignExpression assignExpr = context.FindAncestor<IAssignExpression>();
        if (assignExpr != null && assignExpr.Expression is IArrayCreateExpression)
        {
            return iaie;
        }
    }
    IVariableDeclaration originalBaseVar = Recognizer.GetVariableDeclaration(iaie);
    // If the variable is not stochastic, return
    if (!CodeRecognizer.IsStochastic(context, originalBaseVar))
    {
        return iaie;
    }
    IExpression newExpr = null;
    IVariableDeclaration baseVar = originalBaseVar;
    IVariableDeclaration newvd = null;
    IExpression rhsExpr = null;
    Containers containers = info.containers;
    Type tp = iaie.GetExpressionType();
    if (tp == null)
    {
        Error("Could not determine type of expression: " + iaie);
        return iaie;
    }
    var stmts = Builder.StmtCollection();
    var stmtsAfter = Builder.StmtCollection();
    // does the expression have the form array[indices[k]][indices2[k]][indices3[k]]?
    if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
    {
        IArrayIndexerExpression index3 = (IArrayIndexerExpression)iaie.Indices[0];
        IArrayIndexerExpression iaie2 = (IArrayIndexerExpression)iaie.Target;
        if (index3.Indices.Count == 1 && index3.Indices[0] is IVariableReferenceExpression && iaie2.Target is IArrayIndexerExpression && iaie2.Indices.Count == 1 && iaie2.Indices[0] is IArrayIndexerExpression)
        {
            IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie2.Indices[0];
            IArrayIndexerExpression iaie3 = (IArrayIndexerExpression)iaie2.Target;
            if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression && iaie3.Indices.Count == 1 && iaie3.Indices[0] is IArrayIndexerExpression)
            {
                IArrayIndexerExpression index = (IArrayIndexerExpression)iaie3.Indices[0];
                // Check the shape of 'index' before using its index as a variable reference.
                // For example, array[indices[0]][indices2[k]][indices3[k]] has a literal innermost index
                // and must fall through to the cases below rather than throw InvalidCastException.
                IVariableReferenceExpression innerIndex = (index.Indices.Count == 1) ? (index.Indices[0] as IVariableReferenceExpression) : null;
                IForStatement innerLoop = (innerIndex == null) ? null : Recognizer.GetLoopForVariable(context, innerIndex);
                if (innerIndex != null && index2.Indices[0].Equals(innerIndex) && index3.Indices[0].Equals(innerIndex) && innerLoop != null && AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target))
                {
                    // expression has the form array[indices[k]][indices2[k]][indices3[k]]
                    if (isDef)
                    {
                        Error("fancy indexing not allowed on left hand side");
                        return iaie;
                    }
                    WarnIfLocal(index.Target, iaie3.Target, iaie);
                    WarnIfLocal(index2.Target, iaie3.Target, iaie);
                    WarnIfLocal(index3.Target, iaie3.Target, iaie);
                    containers = RemoveReferencesTo(containers, innerIndex);
                    IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
                    var indices = Recognizer.GetIndices(iaie);
                    // Build name of replacement variable from index values
                    StringBuilder sb = new StringBuilder("_item");
                    AppendIndexString(sb, iaie3);
                    AppendIndexString(sb, iaie2);
                    AppendIndexString(sb, iaie);
                    string name = ToString(iaie3.Target) + sb.ToString();
                    VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                    newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices);
                    if (!context.InputAttributes.Has<DerivedVariable>(newvd))
                    {
                        context.InputAttributes.Set(newvd, new DerivedVariable());
                    }
                    IExpression getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged),
                                                                      new Type[] { tp }, iaie3.Target, index.Target, index2.Target, index3.Target);
                    context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
                    stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                    newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
                    rhsExpr = getItems;
                }
            }
        }
    }
    // does the expression have the form array[indices[k]][indices2[k]]?
    if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
    {
        IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie.Indices[0];
        IArrayIndexerExpression target = (IArrayIndexerExpression)iaie.Target;
        if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression && target.Indices.Count == 1 && target.Indices[0] is IArrayIndexerExpression)
        {
            IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index2.Indices[0];
            IArrayIndexerExpression index = (IArrayIndexerExpression)target.Indices[0];
            IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
            if (index.Indices.Count == 1 && index.Indices[0].Equals(innerIndex) && innerLoop != null && AreLoopsDisjoint(innerLoop, target.Target, index.Target))
            {
                // expression has the form array[indices[k]][indices2[k]]
                if (isDef)
                {
                    Error("fancy indexing not allowed on left hand side");
                    return iaie;
                }
                var innerLoops = new List<IForStatement>();
                innerLoops.Add(innerLoop);
                var indexTarget = index.Target;
                var index2Target = index2.Target;
                // check if the index array is jagged, i.e. array[indices[k][j]]
                while (indexTarget is IArrayIndexerExpression && index2Target is IArrayIndexerExpression)
                {
                    IArrayIndexerExpression indexTargetExpr = (IArrayIndexerExpression)indexTarget;
                    IArrayIndexerExpression index2TargetExpr = (IArrayIndexerExpression)index2Target;
                    if (indexTargetExpr.Indices.Count == 1 && indexTargetExpr.Indices[0] is IVariableReferenceExpression && index2TargetExpr.Indices.Count == 1 && index2TargetExpr.Indices[0] is IVariableReferenceExpression)
                    {
                        IVariableReferenceExpression innerIndexTarget = (IVariableReferenceExpression)indexTargetExpr.Indices[0];
                        IVariableReferenceExpression innerIndex2Target = (IVariableReferenceExpression)index2TargetExpr.Indices[0];
                        IForStatement indexTargetLoop = Recognizer.GetLoopForVariable(context, innerIndexTarget);
                        if (indexTargetLoop != null && AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) && innerIndexTarget.Equals(innerIndex2Target))
                        {
                            innerLoops.Add(indexTargetLoop);
                            indexTarget = indexTargetExpr.Target;
                            index2Target = index2TargetExpr.Target;
                            // This limit must match the number of handled cases below.
                            // Deeper jagged index arrays stay partially indexed instead of hitting NotImplementedException.
                            if (innerLoops.Count == 2)
                            {
                                break;
                            }
                        }
                        else
                        {
                            break;
                        }
                    }
                    else
                    {
                        break;
                    }
                }
                WarnIfLocal(indexTarget, target.Target, iaie);
                WarnIfLocal(index2Target, target.Target, iaie);
                innerLoops.Reverse();
                var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                // Build name of replacement variable from index values
                StringBuilder sb = new StringBuilder("_item");
                AppendIndexString(sb, target);
                AppendIndexString(sb, iaie);
                string name = ToString(target.Target) + sb.ToString();
                VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
                var indices = Recognizer.GetIndices(iaie);
                newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                if (!context.InputAttributes.Has<DerivedVariable>(newvd))
                {
                    context.InputAttributes.Set(newvd, new DerivedVariable());
                }
                IExpression getItems;
                if (innerLoops.Count == 1)
                {
                    getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromJagged),
                                                           new Type[] { tp }, target.Target, indexTarget, index2Target);
                }
                else if (innerLoops.Count == 2)
                {
                    getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<IReadOnlyList<int>>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged),
                                                           new Type[] { tp }, target.Target, indexTarget, index2Target);
                }
                else
                {
                    throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                }
                context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
                stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                rhsExpr = getItems;
            }
            else if (HasAnyCommonLoops(index, index2))
            {
                Warning($"This model will consume excess memory due to the indexing expression {iaie} since {index} and {index2} have larger depth than the compiler can handle.");
            }
        }
    }
    if (newvd == null)
    {
        IArrayIndexerExpression originalExpr = iaie;
        if (UseGetItems)
        {
            iaie = (IArrayIndexerExpression)base.ConvertArrayIndexer(iaie);
        }
        // NOTE(review): this warning is disabled by the trailing '&& false'.
        if (!object.ReferenceEquals(iaie.Target, originalExpr.Target) && false)
        {
            // TODO: determine if this warning is useful or not
            string warningText = "This model may consume excess memory due to the jagged indexing expression {0}";
            Warning(string.Format(warningText, originalExpr));
        }
        // get the baseVar of the new expression.
        baseVar = Recognizer.GetVariableDeclaration(iaie);
        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
        var indices = Recognizer.GetIndices(iaie);
        // Build name of replacement variable from index values
        StringBuilder sb = new StringBuilder("_item");
        AppendIndexString(sb, iaie);
        string name = ToString(iaie.Target) + sb.ToString();
        // does the expression have the form array[indices[k]]?
        if (UseGetItems && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
        {
            IArrayIndexerExpression index = (IArrayIndexerExpression)iaie.Indices[0];
            if (index.Indices.Count == 1 && index.Indices[0] is IVariableReferenceExpression)
            {
                // expression has the form array[indices[k]]
                IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
                IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
                if (innerLoop != null && AreLoopsDisjoint(innerLoop, iaie.Target, index.Target))
                {
                    if (isDef)
                    {
                        Error("fancy indexing not allowed on left hand side");
                        return iaie;
                    }
                    var innerLoops = new List<IForStatement>();
                    innerLoops.Add(innerLoop);
                    var indexTarget = index.Target;
                    // check if the index array is jagged, i.e. array[indices[k][j]]
                    while (indexTarget is IArrayIndexerExpression)
                    {
                        IArrayIndexerExpression index2 = (IArrayIndexerExpression)indexTarget;
                        if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression)
                        {
                            IVariableReferenceExpression innerIndex2 = (IVariableReferenceExpression)index2.Indices[0];
                            IForStatement innerLoop2 = Recognizer.GetLoopForVariable(context, innerIndex2);
                            if (innerLoop2 != null && AreLoopsDisjoint(innerLoop2, iaie.Target, index2.Target))
                            {
                                innerLoops.Add(innerLoop2);
                                indexTarget = index2.Target;
                                // This limit must match the number of handled cases below.
                                if (innerLoops.Count == 3)
                                {
                                    break;
                                }
                            }
                            else
                            {
                                break;
                            }
                        }
                        else
                        {
                            break;
                        }
                    }
                    WarnIfLocal(indexTarget, iaie.Target, originalExpr);
                    innerLoops.Reverse();
                    var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
                    var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
                    newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
                    if (!context.InputAttributes.Has<DerivedVariable>(newvd))
                    {
                        context.InputAttributes.Set(newvd, new DerivedVariable());
                    }
                    IExpression getItems;
                    if (innerLoops.Count == 1)
                    {
                        getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems),
                                                               new Type[] { tp }, iaie.Target, indexTarget);
                    }
                    else if (innerLoops.Count == 2)
                    {
                        getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems),
                                                               new Type[] { tp }, iaie.Target, indexTarget);
                    }
                    else if (innerLoops.Count == 3)
                    {
                        getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems),
                                                               new Type[] { tp }, iaie.Target, indexTarget);
                    }
                    else
                    {
                        throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
                    }
                    context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
                    stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
                    var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
                    newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
                    rhsExpr = getItems;
                }
            }
        }
        if (newvd == null)
        {
            // With GetItems enabled, an expression used fewer than two times is left as-is.
            if (UseGetItems && info.count < 2)
            {
                return iaie;
            }
            try
            {
                newvd = varInfo.DeriveIndexedVariable(stmts, context, name, indices, copyInitializer: isDef);
            }
            catch (Exception ex)
            {
                Error(ex.Message, ex);
                return iaie;
            }
            context.OutputAttributes.Remove<DerivedVariable>(newvd);
            newExpr = Builder.VarRefExpr(newvd);
            rhsExpr = iaie;
            if (isDef)
            {
                // change:
                //   array[0] = f()
                // into:
                //   array_item0 = f()
                //   array[0] = Copy(array_item0)
                IExpression copy = Builder.StaticGenericMethod(new Func<PlaceHolder, PlaceHolder>(Clone.Copy), new Type[] { tp }, newExpr);
                IStatement copySt = Builder.AssignStmt(iaie, copy);
                stmtsAfter.Add(copySt);
                if (!context.InputAttributes.Has<DerivedVariable>(baseVar))
                {
                    context.InputAttributes.Set(baseVar, new DerivedVariable());
                }
            }
            else if (!info.IsAssignedTo)
            {
                // change:
                //   x = f(array[0])
                // into:
                //   array_item0 = Copy(array[0])
                //   x = f(array_item0)
                IExpression copy = Builder.StaticGenericMethod(new Func<PlaceHolder, PlaceHolder>(Clone.Copy), new Type[] { tp }, iaie);
                IStatement copySt = Builder.AssignStmt(Builder.VarRefExpr(newvd), copy);
                stmts.Add(copySt);
                context.InputAttributes.Set(newvd, new DerivedVariable());
            }
        }
    }
    // Reduce memory consumption by declaring the clone outside of unnecessary loops.
    // This way, the item is cloned outside the loop and then replicated, instead of replicating the entire array and cloning the item.
    containers = Containers.RemoveUnusedLoops(containers, context, rhsExpr);
    if (context.InputAttributes.Has<DoNotSendEvidence>(originalBaseVar))
    {
        containers = Containers.RemoveStochasticConditionals(containers, context);
    }
    if (true)
    {
        IStatement st = GetBindingSetContainer(FilterBindingSet(info.bindings, binding => Containers.ContainsExpression(containers.inputs, context, binding.GetExpression())));
        if (st != null)
        {
            containers.Add(st);
        }
    }
    // To put the declaration in the desired containers, we find an ancestor which includes as many of the containers as possible,
    // then wrap the declaration with the remaining containers.
    int ancIndex = containers.GetMatchingAncestorIndex(context);
    Containers missing = containers.GetContainersNotInContext(context, ancIndex);
    stmts = Containers.WrapWithContainers(stmts, missing.outputs);
    context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
    stmtsAfter = Containers.WrapWithContainers(stmtsAfter, missing.outputs);
    context.AddStatementsAfterAncestorIndex(ancIndex, stmtsAfter);
    context.InputAttributes.Set(newvd, containers);
    info.clone = newExpr;
    return newExpr;
}
/// <summary>
/// Rewrites a reference to a stochastic variable so that, for every loop in the reference's loop context
/// that the expression does not vary with, the variable is first replicated over that loop
/// (via <c>Clone.Replicate</c> into a new <c>..._rep</c> array) and the reference is redirected to an
/// element of the replicated array.
/// </summary>
/// <param name="expr">A (possibly indexed) variable reference.</param>
/// <returns>
/// <paramref name="expr"/> unchanged when it has no base variable, the variable is not stochastic, no
/// <c>LoopContext</c> is found (an error is also reported), or the reference loop context has no loops;
/// otherwise the rewritten expression <c>x_rep[loopVar]...</c> with any unreplicated indices re-appended.
/// </returns>
protected IExpression ConvertWithReplication(IExpression expr)
{
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
    // Check if this is an index local variable
    if (baseVar == null) return expr;
    // Check if the variable is stochastic
    if (!CodeRecognizer.IsStochastic(context, baseVar)) return expr;
    // Get the loop context for this variable
    LoopContext lc = context.InputAttributes.Get<LoopContext>(baseVar);
    if (lc == null)
    {
        Error("Loop context not found for '" + baseVar.Name + "'.");
        return expr;
    }
    // Get the reference loop context for this expression
    RefLoopContext rlc = lc.GetReferenceLoopContext(context);
    // If the reference is in the same loop context as the declaration, do nothing.
    if (rlc.loops.Count == 0) return expr;
    // the set of loop variables that are constant wrt the expr
    // (starts with the loop variables of the declaration's own loop context)
    Set<IVariableDeclaration> constantLoopVars = new Set<IVariableDeclaration>();
    constantLoopVars.AddRange(lc.loopVariables);
    // collect set of all loop variable indices in the expression
    // (positions in rlc.loopVariables that are reached only indirectly, through a non-loop index variable;
    // these later trigger an excess-memory warning)
    Set<int> embeddedLoopIndices = new Set<int>();
    List<IList<IExpression>> brackets = Recognizer.GetIndices(expr);
    foreach (IList<IExpression> bracket in brackets)
    {
        foreach (IExpression index in bracket)
        {
            IExpression indExpr = index;
            // for a binary index expression, only its left operand is examined
            if (indExpr is IBinaryExpression ibe)
            {
                indExpr = ibe.Left;
            }
            IVariableDeclaration indVar = Recognizer.GetVariableDeclaration(indExpr);
            if (indVar != null)
            {
                if (!constantLoopVars.Contains(indVar))
                {
                    int loopIndex = rlc.loopVariables.IndexOf(indVar);
                    if (loopIndex != -1)
                    {
                        // indVar is a loop variable
                        constantLoopVars.Add(rlc.loopVariables[loopIndex]);
                    }
                    else
                    {
                        // indVar is not a loop variable
                        // NOTE(review): if indVar has no LoopContext attribute, lc2 is null and the foreach
                        // below throws NullReferenceException — confirm this cannot happen.
                        LoopContext lc2 = context.InputAttributes.Get<LoopContext>(indVar);
                        foreach (var ivd in lc2.loopVariables)
                        {
                            if (!constantLoopVars.Contains(ivd))
                            {
                                int loopIndex2 = rlc.loopVariables.IndexOf(ivd);
                                if (loopIndex2 != -1) embeddedLoopIndices.Add(loopIndex2);
                                else Error($"Index {ivd} is not in {rlc} for expression {expr}");
                            }
                        }
                    }
                }
            }
            else
            {
                // the index is a compound expression: examine every variable it mentions
                foreach(var ivd in Recognizer.GetVariables(indExpr))
                {
                    if (!constantLoopVars.Contains(ivd))
                    {
                        // copied from above
                        LoopContext lc2 = context.InputAttributes.Get<LoopContext>(ivd);
                        foreach (var ivd2 in lc2.loopVariables)
                        {
                            if (!constantLoopVars.Contains(ivd2))
                            {
                                int loopIndex2 = rlc.loopVariables.IndexOf(ivd2);
                                if (loopIndex2 != -1) embeddedLoopIndices.Add(loopIndex2);
                                else Error($"Index {ivd2} is not in {rlc} for expression {expr}");
                            }
                        }
                    }
                }
            }
        }
    }
    // Find loop variables that must be constant due to condition statements.
    // If an enclosing condition binds a reference loop variable to something constant w.r.t. the loops
    // collected so far (IsConstantWrtLoops), that loop variable is treated as constant too.
    List<IStatement> ancestors = context.FindAncestors<IStatement>();
    foreach (IStatement ancestor in ancestors)
    {
        if (!(ancestor is IConditionStatement ics)) continue;
        ConditionBinding binding = new ConditionBinding(ics.Condition);
        IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(binding.lhs);
        IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(binding.rhs);
        int index = rlc.loopVariables.IndexOf(ivd);
        if (index >= 0 && IsConstantWrtLoops(ivd2, constantLoopVars))
        {
            constantLoopVars.Add(ivd);
            continue;
        }
        int index2 = rlc.loopVariables.IndexOf(ivd2);
        if (index2 >= 0 && IsConstantWrtLoops(ivd, constantLoopVars))
        {
            constantLoopVars.Add(ivd2);
            continue;
        }
    }
    // Determine if this expression is being defined (is on the LHS of an assignment)
    bool isDef = Recognizer.IsBeingMutated(context, expr);
    // NOTE(review): this initial value is never read; `containers` is reassigned inside the loop before any use.
    Containers containers = context.InputAttributes.Get<Containers>(baseVar);
    IExpression originalExpr = expr;
    // Replicate once per reference loop the expression does not vary with; each iteration wraps the
    // result of the previous one (baseVar/expr are updated at the end of the loop body).
    for (int currentLoop = 0; currentLoop < rlc.loopVariables.Count; currentLoop++)
    {
        IVariableDeclaration loopVar = rlc.loopVariables[currentLoop];
        if (constantLoopVars.Contains(loopVar)) continue;
        IForStatement loop = rlc.loops[currentLoop];
        // must replicate across this loop.
        if (isDef)
        {
            Error("Cannot re-define a variable in a loop. Variables on the left hand side of an assignment must be indexed by all containing loops.");
            continue;
        }
        if (embeddedLoopIndices.Contains(currentLoop))
        {
            string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
            Warning(string.Format(warningText, originalExpr, loopVar.Name));
        }
        // split expr into a target and extra indices, where target will be replicated and extra indices will be added later
        var extraIndices = new List<IEnumerable<IExpression>>();
        AddUnreplicatedIndices(rlc.loops[currentLoop], expr, extraIndices, out IExpression exprToReplicate);

        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
        IExpression loopSize = Recognizer.LoopSizeExpression(loop);
        IList<IStatement> stmts = Builder.StmtCollection();
        List<IList<IExpression>> inds = Recognizer.GetIndices(exprToReplicate);
        IVariableDeclaration newIndexVar = loopVar;
        // if loopVar is already an indexVar of varInfo, create a new variable
        if (varInfo.HasIndexVar(loopVar))
        {
            newIndexVar = VariableInformation.GenerateLoopVar(context, "_a");
            context.InputAttributes.CopyObjectAttributesTo(loopVar, context.OutputAttributes, newIndexVar);
        }
        // Declare the replicated array (one extra dimension of size loopSize, indexed by newIndexVar).
        IVariableDeclaration repVar = varInfo.DeriveArrayVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rep"), loopSize, newIndexVar, inds, useArrays: true);
        if (!context.InputAttributes.Has<DerivedVariable>(repVar)) context.OutputAttributes.Set(repVar, new DerivedVariable());
        if (context.InputAttributes.Has<ChannelInfo>(baseVar))
        {
            // the replicated array is a use channel of the original variable
            VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
            ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
            ci.decl = repVar;
            context.OutputAttributes.Set(repVar, ci);
        }

        // Create replicate factor
        Type returnType = Builder.ToType(repVar.VariableType);
        IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
            new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate),
            new Type[] {returnType.GetElementType()}, exprToReplicate, loopSize);

        IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), repMethod);
        // Copy attributes across from variable to replication expression
        context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, repMethod);
        context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, repMethod);
        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, repMethod);
        stmts.Add(Builder.ExprStatement(assignExpression));

        // add any containers missing from context.
        containers = new Containers(context);
        // RemoveUnusedLoops will also remove conditionals involving those loop variables.
        // TODO: investigate whether removing these conditionals could cause a problem, e.g. when the condition is a conjunction of many terms.
        containers = Containers.RemoveUnusedLoops(containers, context, repMethod);
        if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar)) containers = Containers.RemoveStochasticConditionals(containers, context);
        //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
        //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
        if (containers.Contains(loop))
        {
            // the replicate statement must not sit inside the loop it replicates over
            Error("Internal: invalid containers for replicating " + baseVar);
            break;
        }
        int ancIndex = containers.GetMatchingAncestorIndex(context);
        Containers missing = containers.GetContainersNotInContext(context, ancIndex);
        stmts = Containers.WrapWithContainers(stmts, missing.inputs);
        context.OutputAttributes.Set(repVar, containers);
        // The new variable's loop context: the for-loops above ancIndex plus any for-loops among the missing containers.
        List<IForStatement> loops = context.FindAncestors<IForStatement>(ancIndex);
        foreach (IStatement container in missing.inputs)
        {
            if (container is IForStatement ifs) loops.Add(ifs);
        }
        context.OutputAttributes.Set(repVar, new LoopContext(loops));
        // must convert the output since it may contain 'if' conditions
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
        // Continue from the replicated variable: expr becomes repVar[loopVar] followed by the unreplicated indices.
        baseVar = repVar;
        expr = Builder.ArrayIndex(Builder.VarRefExpr(repVar), Builder.VarRefExpr(loopVar));
        expr = Builder.JaggedArrayIndex(expr, extraIndices);
    }
    return expr;
}
/// <summary>
/// Handles an assignment to a stochastic variable.
/// On the first assignment it creates the variable's def/marginal/use channels, plus a <c>_deriv</c> channel
/// when VMP derivative messages apply.
/// It then adds, after the current statement, a statement assigning the use channel from a variable-factor call.
/// </summary>
/// <param name="target">The assignment target (variable reference, indexed reference, or declaration expression).</param>
/// <param name="rhs">The right-hand side of the assignment.</param>
/// <param name="shouldDelete">
/// Set to true when <paramref name="rhs"/> is a <c>Clone.Copy</c> of another variable with the same
/// <c>MarginalPrototype</c> attribute and <c>ancIndex &lt; context.InputStack.Count - 2</c>.
/// It presumably tells the caller to drop the original assignment — confirm with caller.
/// </param>
/// <remarks>
/// Returns immediately when the target has no variable declaration.
/// Also returns when <paramref name="rhs"/> is a non-empty array creation without an initializer, because elements are assigned later.
/// Returns early, after recording the variable in <c>variablesAssigned</c>, when the variable is not stochastic.
/// Variables in <c>analysis.variablesExcludingVariableFactor</c> get no variable factor here.
/// Derived variables that are neither inferred nor point estimates also get no variable factor.
/// </remarks>
protected void ProcessAssign(IExpression target, IExpression rhs, ref bool shouldDelete)
{
    IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);
    if (ivd == null)
    {
        return;
    }
    if (rhs is IArrayCreateExpression)
    {
        IArrayCreateExpression iace = (IArrayCreateExpression)rhs;
        // an array creation counts as a definition only if every dimension is the literal 0
        bool zeroLength = iace.Dimensions.All(dimExpr => (dimExpr is ILiteralExpression) && ((ILiteralExpression)dimExpr).Value.Equals(0));
        if (!zeroLength && iace.Initializer == null)
        {
            return; // variable will have assignments to elements
        }
    }
    bool firstTime = !variablesAssigned.Contains(ivd);
    variablesAssigned.Add(ivd);
    bool isInferred = context.InputAttributes.Has<IsInferred>(ivd);
    bool isStochastic = CodeRecognizer.IsStochastic(context, ivd);
    if (!isStochastic)
    {
        return;
    }
    // NOTE(review): from here on isStochastic is always true, so the isStochastic tests below are redundant.
    VariableInformation vi = VariableInformation.GetVariableInformation(context, ivd);
    Containers defContainers = context.InputAttributes.Get<Containers>(ivd);
    int ancIndex = defContainers.GetMatchingAncestorIndex(context);
    Containers missing = defContainers.GetContainersNotInContext(context, ancIndex);
    // definition of a stochastic variable
    IExpression lhs = target;
    if (lhs is IVariableDeclarationExpression)
    {
        lhs = Builder.VarRefExpr(ivd);
    }
    IExpression defExpr = lhs;
    if (firstTime && isStochastic)
    {
        // Create a ChannelInfo attribute for use by later transforms, e.g. MessageTransform
        ChannelInfo defChannel = ChannelInfo.DefChannel(vi);
        defChannel.decl = ivd;
        context.OutputAttributes.Set(ivd, defChannel);
    }
    bool isDerived = context.InputAttributes.Has<DerivedVariable>(ivd);
    // per-variable Algorithm attribute overrides the transform's default algorithm
    IAlgorithm algorithm = this.algorithmDefault;
    Algorithm algAttr = context.InputAttributes.Get<Algorithm>(ivd);
    if (algAttr != null)
    {
        algorithm = algAttr.algorithm;
    }
    if (algorithm is VariationalMessagePassing && ((VariationalMessagePassing)algorithm).UseDerivMessages && isDerived && firstTime)
    {
        // VMP with derivative messages: declare a "<name>_deriv" channel and attach it via DerivMessage
        vi.DefineAllIndexVars(context);
        IList<IStatement> stmts = Builder.StmtCollection();
        IVariableDeclaration derivDecl = vi.DeriveIndexedVariable(stmts, context, ivd.Name + "_deriv");
        context.OutputAttributes.Set(ivd, new DerivMessage(derivDecl));
        ChannelInfo derivChannel = ChannelInfo.DefChannel(vi);
        derivChannel.decl = derivDecl;
        context.OutputAttributes.Set(derivChannel.decl, derivChannel);
        context.OutputAttributes.Set(derivChannel.decl, new DescriptionAttribute("deriv of '" + ivd.Name + "'"));
        // Add the declarations
        stmts = Containers.WrapWithContainers(stmts, missing.outputs);
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
    }
    bool isPointEstimate = context.InputAttributes.Has<PointEstimate>(ivd);
    if (this.analysis.variablesExcludingVariableFactor.Contains(ivd))
    {
        this.variablesLackingVariableFactor.Add(ivd);
        // ivd will get a marginal channel in ConvertMethodInvoke
        useOfVariable[ivd] = ivd;
        return;
    }
    if (isDerived && !isInferred && !isPointEstimate)
    {
        return;
    }
    // non-null only when the (currently disabled) makeClone branch below runs
    IExpression useExpr2 = null;
    if (firstTime)
    {
        // create marginal and use channels
        vi.DefineAllIndexVars(context);
        IList<IStatement> stmts = Builder.StmtCollection();
        CreateMarginalChannel(ivd, vi, stmts);
        if (isStochastic)
        {
            CreateUseChannel(ivd, vi, stmts);
            context.InputAttributes.Set(useOfVariable[ivd], defContainers);
        }
        // Add the declarations
        stmts = Containers.WrapWithContainers(stmts, missing.outputs);
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts);
    }
    if (isStochastic && !useOfVariable.ContainsKey(ivd))
    {
        Error("cannot find use channel of " + ivd);
        return;
    }
    // Same indexing as the target, but on the marginal / use channel variables.
    IExpression marginalExpr = Builder.ReplaceVariable(lhs, ivd, marginalOfVariable[ivd]);
    IExpression useExpr = isStochastic ? Builder.ReplaceVariable(lhs, ivd, useOfVariable[ivd]) : marginalExpr;
    InitialiseTo it = context.InputAttributes.Get<InitialiseTo>(ivd);
    Type[] genArgs = new Type[] { defExpr.GetExpressionType() };
    if (rhs is IMethodInvokeExpression)
    {
        IMethodInvokeExpression imie = (IMethodInvokeExpression)rhs;
        if (Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Clone.Copy)) && ancIndex < context.InputStack.Count - 2)
        {
            IExpression arg = imie.Arguments[0];
            IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(arg);
            if (ivd2 != null && context.InputAttributes.Get<MarginalPrototype>(ivd) == context.InputAttributes.Get<MarginalPrototype>(ivd2))
            {
                // if a variable is a copy, use the original expression since it will give more precise dependencies.
                defExpr = arg;
                shouldDelete = true;
                // NOTE(review): makeClone is hard-coded to false, so the branch below is dead code.
                bool makeClone = false;
                if (makeClone)
                {
                    VariableInformation vi2 = VariableInformation.GetVariableInformation(context, ivd2);
                    IList<IStatement> stmts = Builder.StmtCollection();
                    List<IList<IExpression>> indices = Recognizer.GetIndices(defExpr);
                    IVariableDeclaration useDecl2 = vi2.DeriveIndexedVariable(stmts, context, ivd2.Name + "_use", indices);
                    useExpr2 = Builder.VarRefExpr(useDecl2);
                    Containers defContainers2 = context.InputAttributes.Get<Containers>(ivd2);
                    int ancIndex2 = defContainers2.GetMatchingAncestorIndex(context);
                    Containers missing2 = defContainers2.GetContainersNotInContext(context, ancIndex2);
                    stmts = Containers.WrapWithContainers(stmts, missing2.outputs);
                    context.AddStatementsBeforeAncestorIndex(ancIndex2, stmts);
                    context.InputAttributes.Set(useDecl2, defContainers2);

                    // TODO: call CreateUseChannel
                    ChannelInfo usageChannel = ChannelInfo.UseChannel(vi2);
                    usageChannel.decl = useDecl2;
                    context.InputAttributes.CopyObjectAttributesTo<InitialiseTo>(vi.declaration, context.OutputAttributes, useDecl2);
                    context.InputAttributes.CopyObjectAttributesTo<DerivMessage>(vi.declaration, context.OutputAttributes, useDecl2);
                    context.OutputAttributes.Set(useDecl2, usageChannel);
                    //context.OutputAttributes.Set(useDecl2, new DescriptionAttribute("use of '" + ivd.Name + "'"));
                    context.OutputAttributes.Remove<InitialiseTo>(vi.declaration);
                    IExpression copyExpr = Builder.StaticGenericMethod(
                        new Func<PlaceHolder, PlaceHolder>(Clone.Copy), genArgs, useExpr2);
                    var copyStmt = Builder.AssignStmt(useExpr, copyExpr);
                    context.AddStatementAfterCurrent(copyStmt);
                }
            }
        }
    }
    // Add the variable factor
    // Gate.ExitingVariable for gate-exit random variables; otherwise Clone.VariablePoint for point estimates
    // or the algorithm's variable factor, with an extra initialiser argument when InitialiseTo is present.
    IExpression variableFactorExpr;
    bool isGateExitRandom = context.InputAttributes.Has<VariationalMessagePassing.GateExitRandomVariable>(ivd);
    if (isGateExitRandom)
    {
        variableFactorExpr = Builder.StaticGenericMethod(
            new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Gate.ExitingVariable),
            genArgs, defExpr, marginalExpr);
    }
    else
    {
        Delegate d = algorithm.GetVariableFactor(isDerived, it != null);
        if (isPointEstimate)
        {
            d = new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint);
        }
        if (it == null)
        {
            variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, marginalExpr);
        }
        else
        {
            // index the initial-messages expression the same way as the target
            IExpression initExpr = Builder.ReplaceExpression(lhs, Builder.VarRefExpr(ivd), it.initialMessagesExpression);
            variableFactorExpr = Builder.StaticGenericMethod(d, genArgs, defExpr, initExpr, marginalExpr);
        }
    }
    context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(ivd, context.OutputAttributes, variableFactorExpr);
    context.InputAttributes.CopyObjectAttributesTo<Algorithm>(ivd, context.OutputAttributes, variableFactorExpr);
    if (isStochastic)
    {
        context.OutputAttributes.Set(variableFactorExpr, new IsVariableFactor());
    }
    // use = VariableFactor(def, [init,] marginal), emitted after the current statement
    var assignStmt = Builder.AssignStmt(useExpr2 == null ? useExpr : useExpr2, variableFactorExpr);
    context.AddStatementAfterCurrent(assignStmt);
}
/// <summary> /// This method does all the work of converting literal indexing expressions. /// </summary> /// <param name="iaie"></param> /// <returns></returns> protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie) { IndexAnalysisTransform.IndexInfo info; if (!analysis.indexInfoOf.TryGetValue(iaie, out info)) { return(base.ConvertArrayIndexer(iaie)); } // Determine if this is a definition i.e. the variable is on the left hand side of an assignment // This must be done before base.ConvertArrayIndexer changes the expression! bool isDef = Recognizer.IsBeingMutated(context, iaie); if (info.clone != null) { if (isDef) { // check that extra literal indices in the target are zero. // for example, if iae is x[i][0] = (...) then it is safe to add x_uses[i] = Rep(x[i]) // if iae is x[i][1] = (...) then it is safe to add x_uses[i][1] = Rep(x[i][1]) // but not x_uses[i] = Rep(x[i]) since this will be a duplicate. bool extraLiteralsAreZero = true; int parentIndex = context.InputStack.Count - 2; object parent = context.GetAncestor(parentIndex); while (parent is IArrayIndexerExpression parent_iaie) { foreach (IExpression index in parent_iaie.Indices) { if (index is ILiteralExpression ile) { int value = (int)ile.Value; if (value != 0) { extraLiteralsAreZero = false; break; } } } parentIndex--; parent = context.GetAncestor(parentIndex); } if (false && extraLiteralsAreZero) { // change: // array[0] = f() // into: // array_item0 = f() // array[0] = Copy(array_item0) IExpression copy = Builder.StaticGenericMethod(new Func <PlaceHolder, PlaceHolder>(Clone.Copy <PlaceHolder>), new Type[] { iaie.GetExpressionType() }, info.clone); IStatement copySt = Builder.AssignStmt(iaie, copy); context.AddStatementAfterCurrent(copySt); } } return(info.clone); } if (isDef) { // do not clone the lhs of an array create assignment. 
IAssignExpression assignExpr = context.FindAncestor <IAssignExpression>(); if (assignExpr.Expression is IArrayCreateExpression) { return(iaie); } } IVariableDeclaration originalBaseVar = Recognizer.GetVariableDeclaration(iaie); // If the variable is not stochastic, return if (!CodeRecognizer.IsStochastic(context, originalBaseVar)) { return(iaie); } IExpression newExpr = null; IVariableDeclaration baseVar = originalBaseVar; IVariableDeclaration newvd = null; IExpression rhsExpr = null; Containers containers = info.containers; Type tp = iaie.GetExpressionType(); if (tp == null) { Error("Could not determine type of expression: " + iaie); return(iaie); } var stmts = Builder.StmtCollection(); var stmtsAfter = Builder.StmtCollection(); // does the expression have the form array[indices[k]][indices2[k]][indices3[k]]? if (newvd == null && UseGetItems && iaie.Indices.Count == 1) { if (iaie.Target is IArrayIndexerExpression iaie2 && iaie.Indices[0] is IArrayIndexerExpression index3 && index3.Indices.Count == 1 && index3.Indices[0] is IVariableReferenceExpression innerIndex3 && iaie2.Target is IArrayIndexerExpression iaie3 && iaie2.Indices.Count == 1 && iaie2.Indices[0] is IArrayIndexerExpression index2 && index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression innerIndex2 && innerIndex2.Equals(innerIndex3) && iaie3.Indices.Count == 1 && iaie3.Indices[0] is IArrayIndexerExpression index && index.Indices.Count == 1 && index.Indices[0] is IVariableReferenceExpression innerIndex && innerIndex.Equals(innerIndex2)) { IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex); if (innerLoop != null && AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target)) { // expression has the form array[indices[k]][indices2[k]][indices3[k]] if (isDef) { Error("fancy indexing not allowed on left hand side"); return(iaie); } WarnIfLocal(index.Target, iaie3.Target, iaie); WarnIfLocal(index2.Target, iaie3.Target, iaie); WarnIfLocal(index3.Target, 
iaie3.Target, iaie); containers = RemoveReferencesTo(containers, innerIndex); IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop); var indices = Recognizer.GetIndices(iaie); // Build name of replacement variable from index values StringBuilder sb = new StringBuilder("_item"); AppendIndexString(sb, iaie3); AppendIndexString(sb, iaie2); AppendIndexString(sb, iaie); string name = ToString(iaie3.Target) + sb.ToString(); VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar); newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices); if (!context.InputAttributes.Has <DerivedVariable>(newvd)) { context.InputAttributes.Set(newvd, new DerivedVariable()); } IExpression getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <IReadOnlyList <PlaceHolder> > >, IReadOnlyList <int>, IReadOnlyList <int>, IReadOnlyList <int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged), new Type[] { tp }, iaie3.Target, index.Target, index2.Target, index3.Target); context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, getItems); stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems)); newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex); rhsExpr = getItems; } } } // does the expression have the form array[indices[k]][indices2[k]]? 
if (newvd == null && UseGetItems && iaie.Indices.Count == 1) { if (iaie.Target is IArrayIndexerExpression target && iaie.Indices[0] is IArrayIndexerExpression index2 && index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression innerIndex && target.Indices.Count == 1 && target.Indices[0] is IArrayIndexerExpression index) { IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex); if (index.Indices.Count == 1 && index.Indices[0].Equals(innerIndex) && innerLoop != null && AreLoopsDisjoint(innerLoop, target.Target, index.Target)) { // expression has the form array[indices[k]][indices2[k]] if (isDef) { Error("fancy indexing not allowed on left hand side"); return(iaie); } var innerLoops = new List <IForStatement>(); innerLoops.Add(innerLoop); var indexTarget = index.Target; var index2Target = index2.Target; // check if the index array is jagged, i.e. array[indices[k][j]] while (indexTarget is IArrayIndexerExpression indexTargetExpr && index2Target is IArrayIndexerExpression index2TargetExpr) { if (indexTargetExpr.Indices.Count == 1 && indexTargetExpr.Indices[0] is IVariableReferenceExpression innerIndexTarget && index2TargetExpr.Indices.Count == 1 && index2TargetExpr.Indices[0] is IVariableReferenceExpression innerIndex2Target) { IForStatement indexTargetLoop = Recognizer.GetLoopForVariable(context, innerIndexTarget); if (indexTargetLoop != null && AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) && innerIndexTarget.Equals(innerIndex2Target)) { innerLoops.Add(indexTargetLoop); indexTarget = indexTargetExpr.Target; index2Target = index2TargetExpr.Target; } else { break; } } else { break; } } WarnIfLocal(indexTarget, target.Target, iaie); WarnIfLocal(index2Target, target.Target, iaie); innerLoops.Reverse(); var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) }); var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) }); // Build name of 
replacement variable from index values StringBuilder sb = new StringBuilder("_item"); AppendIndexString(sb, target); AppendIndexString(sb, iaie); string name = ToString(target.Target) + sb.ToString(); VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar); var indices = Recognizer.GetIndices(iaie); newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices); if (!context.InputAttributes.Has <DerivedVariable>(newvd)) { context.InputAttributes.Set(newvd, new DerivedVariable()); } IExpression getItems; if (innerLoops.Count == 1) { getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <PlaceHolder> >, IReadOnlyList <int>, IReadOnlyList <int>, PlaceHolder[]>(Collection.GetItemsFromJagged), new Type[] { tp }, target.Target, indexTarget, index2Target); } else if (innerLoops.Count == 2) { getItems = Builder.StaticGenericMethod(new Func <IReadOnlyList <IReadOnlyList <PlaceHolder> >, IReadOnlyList <IReadOnlyList <int> >, IReadOnlyList <IReadOnlyList <int> >, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged), new Type[] { tp }, target.Target, indexTarget, index2Target); } else { throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}"); } context.InputAttributes.CopyObjectAttributesTo <Algorithm>(baseVar, context.OutputAttributes, getItems); stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems)); var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i]))); newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices); rhsExpr = getItems; } else if (HasAnyCommonLoops(index, index2)) { Warning($"This model will consume excess memory due to the indexing expression {iaie} since {index} and {index2} have larger depth than the compiler can handle."); } }
/// <summary>
/// Records a stochastic variable reference in the innermost enclosing gate block.
/// It is stored either as a definition (<paramref name="isDef"/> true) or as a use.
/// </summary>
/// <param name="expr">The (possibly indexed) variable reference.</param>
/// <param name="isDef">True if <paramref name="expr"/> is being defined; false if it is being read.</param>
/// <param name="conditionContext">The condition bindings in effect at the reference.</param>
/// <remarks>
/// Nothing is recorded for non-variable expressions, deterministic variables, or references outside any gate block.
/// Variables whose <c>GateBlock</c> attribute is the innermost block are also skipped.
/// Repeated definitions of one variable are collapsed to their common parent.
/// A new use absorbs every stored use that could overlap it, widening to their common parent.
/// Absorption repeats until a full pass absorbs nothing.
/// </remarks>
protected void ProcessUse(IExpression expr, bool isDef, List<ConditionBinding> conditionContext)
{
    IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(expr);
    if (ivd == null || !CodeRecognizer.IsStochastic(context, ivd) || gateBlockContext.Count == 0)
    {
        return;
    }
    GateBlock innermost = gateBlockContext[gateBlockContext.Count - 1];
    if (context.InputAttributes.Get<GateBlock>(ivd) == innermost)
    {
        // the variable is local to this gate block
        return;
    }

    var reference = new ExpressionWithBindings
    {
        Expression = ReplaceLocalIndices(innermost, expr)
    };
    List<ConditionBinding> relevantBindings = FilterConditionContext(innermost, conditionContext);
    if (relevantBindings.Count > 0)
    {
        reference.Bindings.Add(relevantBindings);
    }

    if (isDef)
    {
        // every definition of the same variable is widened to a shared parent
        innermost.variablesDefined[ivd] =
            innermost.variablesDefined.TryGetValue(ivd, out ExpressionWithBindings previousDef)
                ? GetCommonParent(reference, previousDef)
                : reference;
        return;
    }

    if (!innermost.variablesUsed.TryGetValue(ivd, out List<ExpressionWithBindings> uses))
    {
        innermost.variablesUsed[ivd] = new List<ExpressionWithBindings> { reference };
        return;
    }

    // Merge overlapping uses into `reference`; whenever a pass merges something, retry with the widened reference.
    var disjoint = new List<ExpressionWithBindings>();
    bool absorbedAny;
    do
    {
        foreach (ExpressionWithBindings existing in uses)
        {
            ExpressionWithBindings widened = GetCommonParent(reference, existing);
            if (CouldOverlap(reference, existing))
            {
                reference = widened;
            }
            else
            {
                disjoint.Add(existing);
            }
        }
        absorbedAny = disjoint.Count != uses.Count;
        if (absorbedAny)
        {
            uses.Clear();
            uses.AddRange(disjoint);
            disjoint.Clear();
        }
    } while (absorbedAny);

    uses.Add(reference);
    innermost.variablesUsed[ivd] = uses;
}
/// <summary>
/// Records the indexing depth at which a stochastic variable is defined or used.
/// For uses it also records the use count, the minimum depth, and per-depth <see cref="IndexInfo"/>.
/// That per-depth info holds the intersected containers and the number of leading all-literal index brackets.
/// </summary>
/// <param name="expr">A (possibly indexed) variable reference.</param>
/// <param name="isDef">True if the reference is the target of a definition.</param>
private void RegisterDepth(IExpression expr, bool isDef)
{
    // Skip when the current expression is itself being indexed by an enclosing indexer
    // (presumably so that only the full indexing expression is registered).
    if (Recognizer.IsBeingIndexed(context)) return;

    int depth = Recognizer.GetIndexingDepth(expr);
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
    // Not a variable reference (e.g. an indexed argument reference), or not stochastic: nothing to track.
    if (baseVar == null || !CodeRecognizer.IsStochastic(context, baseVar)) return;

    // Marginal channels are not tracked.
    ChannelInfo channel = context.InputAttributes.Get<ChannelInfo>(baseVar);
    if (channel != null && channel.IsMarginal) return;

    if (!depthInfos.TryGetValue(baseVar, out DepthInfo depthInfo))
    {
        depthInfo = new DepthInfo();
        depthInfos[baseVar] = depthInfo;
    }

    if (isDef)
    {
        depthInfo.definitionDepth = depth;
        return;
    }

    depthInfo.useCount++;
    if (depthInfo.minDepth > depth)
    {
        depthInfo.minDepth = depth;
    }

    // Count the leading index brackets whose indices are all literals.
    int literalDepth = Recognizer.GetIndices(expr)
        .TakeWhile(bracket => bracket.All(index => index is ILiteralExpression))
        .Count();

    if (depthInfo.indexInfoOfDepth.TryGetValue(depth, out IndexInfo existing))
    {
        Containers current = new Containers(context);
        existing.containers = Containers.Intersect(existing.containers, current);
        existing.literalIndexingDepth = System.Math.Min(existing.literalIndexingDepth, literalDepth);
    }
    else
    {
        depthInfo.indexInfoOfDepth[depth] = new IndexInfo
        {
            containers = new Containers(context),
            literalIndexingDepth = literalDepth,
        };
    }
}
/// <summary>
/// Analyses an <c>if</c> statement in three steps.
/// First it finds or creates the <see cref="GateBlock"/> for the condition and records the case value.
/// Next it converts the Then/Else branches with the corresponding condition binding in scope.
/// Finally it re-reports the variables defined and used inside the block to the enclosing gate block, if there is one.
/// </summary>
/// <param name="ics">The condition statement being converted.</param>
/// <returns>
/// Always <paramref name="ics"/> itself.
/// An error is reported if the condition does not compare to a literal or loop counter.
/// </returns>
protected override IStatement ConvertCondition(IConditionStatement ics)
{
    context.SetPrimaryOutput(ics);
    ConvertExpression(ics.Condition);
    ConditionBinding binding = GateTransform.GetConditionBinding(ics.Condition, context, out IForStatement loop);
    IExpression caseValue = binding.rhs;
    if (!GateTransform.IsLiteralOrLoopVar(context, caseValue, out loop))
    {
        Error("If statement condition must compare to a literal or loop counter, was: " + ics.Condition);
        return ics;
    }

    // Stochastic conditions share a gate block per lhs; deterministic conditions are keyed by the whole
    // condition so that definitions are not unified across them.
    IExpression gateBlockKey = CodeRecognizer.IsStochastic(context, binding.lhs)
        ? binding.lhs
        : binding.GetExpression();

    Set<ConditionBinding> bindings = ConditionBinding.Copy(conditionContext);
    if (!gateBlocks.TryGetValue(bindings, out Dictionary<IExpression, GateBlock> blockMap))
    {
        // first time seeing these bindings
        blockMap = new Dictionary<IExpression, GateBlock>();
        gateBlocks[bindings] = blockMap;
    }
    if (!blockMap.TryGetValue(gateBlockKey, out GateBlock gateBlock))
    {
        // first time seeing this key
        gateBlock = new GateBlock();
        blockMap[gateBlockKey] = gateBlock;
    }

    // A gate block compares either to loop counters or to literals, not a mixture of both.
    bool comparesToLoopCounter = loop != null;
    if (gateBlock.hasLoopCaseValue && !comparesToLoopCounter)
    {
        Error("Cannot compare " + binding.lhs + " to a literal, since it was previously compared to a loop counter. Put this test inside the loop.");
    }
    if (!gateBlock.hasLoopCaseValue && gateBlock.caseValues.Count > 0 && comparesToLoopCounter)
    {
        Error("Cannot compare " + binding.lhs + " to a loop counter, since it was previously compared to a literal. Put the literal case inside the loop.");
    }
    gateBlock.caseValues.Add(caseValue);
    if (comparesToLoopCounter)
    {
        gateBlock.hasLoopCaseValue = true;
    }

    gateBlockContext.Add(gateBlock);
    context.OutputAttributes.Set(ics, gateBlock);

    int startIndex = conditionContext.Count;
    // Drops every binding pushed since this condition was entered.
    void RestoreConditionContext() => conditionContext.RemoveRange(startIndex, conditionContext.Count - startIndex);

    conditionContext.Add(binding);
    ConvertBlock(ics.Then);
    if (ics.Else != null)
    {
        RestoreConditionContext();
        binding = binding.FlipCondition();
        conditionContext.Add(binding);
        ConvertBlock(ics.Else);
    }
    RestoreConditionContext();
    gateBlockContext.RemoveAt(gateBlockContext.Count - 1);

    if (gateBlockContext.Count == 0)
    {
        return ics;
    }

    // all variables defined/used in the inner block must be processed by the outer block
    foreach (ExpressionWithBindings definition in gateBlock.variablesDefined.Values)
    {
        if (definition.Bindings.Count == 0)
        {
            ProcessUse(definition.Expression, true, conditionContext);
            continue;
        }
        foreach (List<ConditionBinding> extra in definition.Bindings)
        {
            ProcessUse(definition.Expression, true, Union(conditionContext, extra));
        }
    }
    foreach (List<ExpressionWithBindings> uses in gateBlock.variablesUsed.Values)
    {
        foreach (ExpressionWithBindings use in uses)
        {
            if (use.Bindings.Count == 0)
            {
                ProcessUse(use.Expression, false, conditionContext);
                continue;
            }
            foreach (ICollection<ConditionBinding> extra in use.Bindings)
            {
                ProcessUse(use.Expression, false, Union(conditionContext, extra));
            }
        }
    }
    return ics;
}
/// <summary>
/// Classifies the definition of <paramref name="targetVar"/> by <paramref name="expr"/>.
/// If <paramref name="expr"/> is a call that yields a point mass (<c>Clone.VariablePoint</c>, or a
/// deterministic factor whose relevant arguments are all point masses), the variable is marked
/// <see cref="ForwardPointMass"/> immediately and registered in <c>variablesDefinedPointMass</c>.
/// Otherwise, unless <paramref name="expr"/> is an array creation, the variable is recorded in
/// <c>variablesDefinedNonPointMass</c> and any earlier point-mass registration and attribute are withdrawn.
/// </summary>
/// <param name="expr">The defining right-hand side.</param>
/// <param name="targetVar">The variable being defined.</param>
/// <param name="isLhs">Not read by this method.</param>
protected void ProcessDefinition(IExpression expr, IVariableDeclaration targetVar, bool isLhs)
{
    bool targetIsPointMass = false;
    if (expr is IMethodInvokeExpression imie)
    {
        targetIsPointMass = DefinesPointMass(imie);
        if (targetIsPointMass)
        {
            RegisterPointMassDefinition(imie);
        }
    }

    bool isArrayCreation = expr is IArrayCreateExpression;
    if (!(targetIsPointMass || isArrayCreation))
    {
        variablesDefinedNonPointMass.Add(targetVar);
        if (variablesDefinedPointMass.ContainsKey(targetVar))
        {
            variablesDefinedPointMass.Remove(targetVar);
            context.OutputAttributes.Remove<ForwardPointMass>(targetVar);
        }
    }

    // True when the call's result is a point mass.
    bool DefinesPointMass(IMethodInvokeExpression call)
    {
        // TODO: consider using a method attribute for this
        if (Recognizer.IsStaticGenericMethod(call, new Models.FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Clone.VariablePoint)))
        {
            return true;
        }
        FactorManager.FactorInfo factorInfo = CodeRecognizer.GetFactorInfo(context, call);
        if (!factorInfo.IsDeterministicFactor)
        {
            return false;
        }
        int allElementsIndex = factorInfo.ReturnedInAllElementsParameterIndex;
        if (allElementsIndex != -1 && ArgumentIsPointMass(call.Arguments[allElementsIndex]))
        {
            return true;
        }
        return call.Arguments.All(ArgumentIsPointMass);
    }

    // Marks targetVar as a forward point mass and records the call in variablesDefinedPointMass.
    void RegisterPointMassDefinition(IMethodInvokeExpression call)
    {
        // do this immediately so all uses are updated
        if (!context.InputAttributes.Has<ForwardPointMass>(targetVar))
        {
            context.OutputAttributes.Set(targetVar, new ForwardPointMass());
        }
        // the rest is done later
        if (!variablesDefinedPointMass.TryGetValue(targetVar, out List<IMethodInvokeExpression> pointMassCalls))
        {
            pointMassCalls = new List<IMethodInvokeExpression>();
            variablesDefinedPointMass.Add(targetVar, pointMassCalls);
        }
        if (IsItemSelectionCall(call))
        {
            pointMassCalls.Add(call);
        }
    }

    // this code needs to be synchronized with MessageTransform.ConvertMethodInvoke
    bool IsItemSelectionCall(IMethodInvokeExpression call)
    {
        return Recognizer.IsStaticGenericMethod(call, new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate))
            || Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems))
            || Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems))
            || Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems))
            || Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromJagged))
            || Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged))
            || Recognizer.IsStaticGenericMethod(call, new Func<IReadOnlyList<IReadOnlyList<PlaceHolder>>, IReadOnlyList<IReadOnlyList<int>>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItemsFromJagged));
    }

    // Non-stochastic arguments and out arguments count as point masses; stochastic inputs only when
    // their variable already carries ForwardPointMass.
    bool ArgumentIsPointMass(IExpression arg)
    {
        bool isOutArgument = arg is IAddressOutExpression;
        if (!CodeRecognizer.IsStochastic(context, arg) || isOutArgument)
        {
            return true;
        }
        IVariableDeclaration argVar = Recognizer.GetVariableDeclaration(arg);
        return argVar != null && context.InputAttributes.Has<ForwardPointMass>(argVar);
    }
}
/// <summary>
/// Records a definition or use of a stochastic variable in the innermost enclosing gate block
/// (<c>currentBlock.variablesDefined</c> or <c>currentBlock.variablesUsed</c>).
/// Nothing is recorded when the expression has no variable declaration, the variable is not
/// stochastic, there is no enclosing gate block, or the variable is local to that block.
/// </summary>
/// <param name="expr">The (possibly indexed) variable reference.</param>
/// <param name="isDef">True for a definition, false for a use.</param>
/// <param name="conditionContext">The condition bindings in effect at the reference.</param>
protected void ProcessUse(IExpression expr, bool isDef, List<ConditionBinding> conditionContext)
{
    IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(expr);
    if (ivd == null || !CodeRecognizer.IsStochastic(context, ivd) || gateBlockContext.Count == 0)
    {
        return;
    }
    GateBlock currentBlock = gateBlockContext[gateBlockContext.Count - 1];
    if (context.InputAttributes.Get<GateBlock>(ivd) == currentBlock)
    {
        return; // local variable of the gateBlock
    }

    IExpression normalized = ReplaceLocalIndices(currentBlock, expr);
    // Keep only deterministic bindings that do not mention variables local to the block.
    List<ConditionBinding> outerBindings = conditionContext
        .Where(b => !CodeRecognizer.IsStochastic(context, b.lhs)
                    && !ContainsLocalVars(currentBlock, b.lhs)
                    && !ContainsLocalVars(currentBlock, b.rhs))
        .ToList();
    ExpressionWithBindings eb = new ExpressionWithBindings(normalized, outerBindings);

    if (isDef)
    {
        // all definitions of the same variable must have a common parent
        currentBlock.variablesDefined[ivd] = currentBlock.variablesDefined.TryGetValue(ivd, out ExpressionWithBindings previous)
            ? GetCommonParent(eb, previous)
            : eb;
        return;
    }

    if (!currentBlock.variablesUsed.TryGetValue(ivd, out List<ExpressionWithBindings> ebs))
    {
        currentBlock.variablesUsed[ivd] = new List<ExpressionWithBindings> { eb };
        return;
    }

    // Absorb every use that overlaps eb into their common parent. Growing eb can make it overlap
    // uses it previously missed, so repeat until a full pass absorbs nothing.
    List<ExpressionWithBindings> disjoint = new List<ExpressionWithBindings>();
    bool absorbedAny;
    do
    {
        disjoint.Clear();
        foreach (ExpressionWithBindings other in ebs)
        {
            if (CouldOverlap(eb, other, ignoreBindings: GateTransform.DeterministicEnterExit))
            {
                eb = GetCommonParent(eb, other);
            }
            else
            {
                disjoint.Add(other);
            }
        }
        absorbedAny = disjoint.Count != ebs.Count;
        if (absorbedAny)
        {
            ebs.Clear();
            ebs.AddRange(disjoint);
        }
    } while (absorbedAny);

    ebs.Add(eb);
    currentBlock.variablesUsed[ivd] = ebs;
}
// Determine the message type of target from the message type of the factor arguments
/// <summary>
/// Processes one factor call in the given message direction.
/// First it collects, per factor parameter, an argument expression and its type. Constants are used as-is.
/// Stochastic non-child arguments use their forward message (<c>fwdMessageVars</c>). Child arguments use
/// their backward message (<c>bckMessageVars</c>) only when <paramref name="direction"/> is Backwards.
/// Then, for each stochastic argument whose role matches the direction (children when going forwards,
/// non-children when going backwards):
/// <list type="bullet">
/// <item>Forwards: looks up an "Init", then suffix+"Init", then suffix operator method to find the message type.
/// It creates or reuses a message variable and records an init expression in <c>messageInitExprs</c>.</item>
/// <item>Backwards: the lookup is skipped, so the message is created with <c>CreateBackwardMessageFromForward</c>.</item>
/// </list>
/// Returns early, without reporting, when a required input message expression is missing.
/// </summary>
/// <param name="factor">The factor expression; passed to <c>GetNodeInfo</c>.</param>
/// <param name="direction">The direction in which messages are being typed.</param>
protected void ProcessFactor(IExpression factor, MessageDirection direction)
{
    NodeInfo info = GetNodeInfo(factor);
    // fill in argumentTypes
    Dictionary<string, Type> argumentTypes = new Dictionary<string, Type>();
    Dictionary<string, IExpression> arguments = new Dictionary<string, IExpression>();
    for (int i = 0; i < info.info.ParameterNames.Count; i++)
    {
        string parameterName = info.info.ParameterNames[i];
        // 'isChild' is true when this parameter is the factor's return value or an out argument.
        // Child arguments only get an inward message here when going Backwards.
        bool isChild = info.isReturnOrOut[i];
        IExpression arg = info.arguments[i];
        bool isConstant = !CodeRecognizer.IsStochastic(context, arg);
        if (isConstant)
        {
            arguments[parameterName] = arg;
            Type inwardType = arg.GetExpressionType();
            argumentTypes[parameterName] = inwardType;
        }
        else if (!isChild)
        {
            IExpression msgExpr = GetMessageExpression(arg, fwdMessageVars);
            if (msgExpr == null)
            {
                // forward message not available yet: give up on this factor silently
                return;
            }
            arguments[parameterName] = msgExpr;
            Type inwardType = msgExpr.GetExpressionType();
            if (inwardType == null)
            {
                Error("inferred an incorrect message type for " + arg);
                return;
            }
            argumentTypes[parameterName] = inwardType;
        }
        else if (direction == MessageDirection.Backwards)
        {
            IExpression msgExpr = GetMessageExpression(arg, bckMessageVars);
            if (msgExpr == null)
            {
                //Console.WriteLine("creating backward message for "+arg);
                CreateBackwardMessageFromForward(arg, null);
                msgExpr = GetMessageExpression(arg, bckMessageVars);
                if (msgExpr == null)
                {
                    return;
                }
            }
            arguments[parameterName] = msgExpr;
            Type inwardType = msgExpr.GetExpressionType();
            if (inwardType == null)
            {
                Error("inferred an incorrect message type for " + arg);
                return;
            }
            argumentTypes[parameterName] = inwardType;
        }
    }
    // A per-factor Algorithm attribute overrides the transform's default algorithm.
    IAlgorithm alg = algorithm;
    Algorithm algAttr = context.InputAttributes.Get<Algorithm>(info.imie);
    if (algAttr != null)
    {
        alg = algAttr.algorithm;
    }
    List<ICompilerAttribute> factorAttributes = context.InputAttributes.GetAll<ICompilerAttribute>(info.imie);
    string methodSuffix = alg.GetOperatorMethodSuffix(factorAttributes);
    // infer types of children
    for (int i = 0; i < info.info.ParameterNames.Count; i++)
    {
        string parameterName = info.info.ParameterNames[i];
        bool isChild = info.isReturnOrOut[i];
        // Forwards processes children only; Backwards processes non-children only.
        if (isChild != (direction == MessageDirection.Forwards))
        {
            continue;
        }
        IExpression target = info.arguments[i];
        bool isConstant = !CodeRecognizer.IsStochastic(context, target);
        if (isConstant)
        {
            continue;
        }
        IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(target);
        if (ivd == null)
        {
            continue;
        }
        Type targetType = null;
        MessageFcnInfo fcninfo = null;
        if (direction == MessageDirection.Forwards)
        {
            // Try "Init", then methodSuffix + "Init", then the plain methodSuffix operator.
            // The last is rejected if it needs a result or resultIndex parameter.
            // NOTE(review): this whole branch runs only for Forwards, so the inner
            // 'if (direction == MessageDirection.Forwards)' checks are always true and the following
            // 'fcninfo = null;' statements are unreachable.
            try
            {
                fcninfo = GetMessageFcnInfo(info.info, "Init", parameterName, argumentTypes);
            }
            catch (Exception)
            {
                try
                {
                    fcninfo = GetMessageFcnInfo(info.info, methodSuffix + "Init", parameterName, argumentTypes);
                }
                catch (Exception ex)
                {
                    //Error("could not determine message type of "+ivd.Name, ex);
                    try
                    {
                        fcninfo = GetMessageFcnInfo(info.info, methodSuffix, parameterName, argumentTypes);
                        if (fcninfo.PassResult)
                        {
                            throw new MissingMethodException(StringUtil.MethodFullNameToString(fcninfo.Method) + " is not suitable for initialization since it takes a result parameter. Please provide a separate Init method.");
                        }
                        if (fcninfo.PassResultIndex)
                        {
                            throw new MissingMethodException(StringUtil.MethodFullNameToString(fcninfo.Method) + " is not suitable for initialization since it takes a resultIndex parameter. Please provide a separate Init method.");
                        }
                    }
                    catch (Exception ex2)
                    {
                        if (direction == MessageDirection.Forwards)
                        {
                            Error("could not determine " + direction + " message type of " + ivd.Name + ": " + ex.Message, ex2);
                            continue;
                        }
                        fcninfo = null;
                    }
                }
            }
            if (fcninfo != null)
            {
                // An unresolved generic return type cannot serve as a message type.
                targetType = fcninfo.Method.ReturnType;
                if (targetType.IsGenericParameter)
                {
                    if (direction == MessageDirection.Forwards)
                    {
                        Error("could not determine " + direction + " message type of " + ivd.Name + " in " + StringUtil.MethodFullNameToString(fcninfo.Method));
                        continue;
                    }
                    fcninfo = null;
                }
            }
            if (fcninfo != null)
            {
                // vi is only referenced by the commented-out GetDistributionArrayCreateExpression call below.
                VariableInformation vi = VariableInformation.GetVariableInformation(context, ivd);
                try
                {
                    targetType = MessageTransform.GetDistributionType(ivd.VariableType.DotNetType, target.GetExpressionType(), targetType, true);
                }
                catch (Exception ex)
                {
                    if (direction == MessageDirection.Forwards)
                    {
                        Error("could not determine " + direction + " message type of " + ivd.Name, ex);
                        continue;
                    }
                    fcninfo = null;
                }
            }
        }
        Dictionary<IVariableDeclaration, IVariableDeclaration> messageVars = (direction == MessageDirection.Forwards) ? fwdMessageVars : bckMessageVars;
        if (fcninfo != null)
        {
            // Reuse an existing message variable for ivd, otherwise declare ivd.Name + "_F"/"_B".
            string name = ivd.Name + (direction == MessageDirection.Forwards ? "_F" : "_B");
            IVariableDeclaration msgVar;
            if (!messageVars.TryGetValue(ivd, out msgVar))
            {
                msgVar = Builder.VarDecl(name, targetType);
            }
            if (true)
            {
                // construct the init expression
                // Factory-typed parameters get the factory variable; all others must map to a
                // factor argument collected in the first loop, otherwise the operator is rejected.
                List<IExpression> args = new List<IExpression>();
                ParameterInfo[] parameters = fcninfo.Method.GetParameters();
                foreach (ParameterInfo parameter in parameters)
                {
                    string argName = parameter.Name;
                    if (IsFactoryType(parameter.ParameterType))
                    {
                        IVariableDeclaration factoryVar = GetFactoryVariable(parameter.ParameterType);
                        args.Add(Builder.VarRefExpr(factoryVar));
                    }
                    else
                    {
                        FactorEdge factorEdge = fcninfo.factorEdgeOfParameter[parameter.Name];
                        string factorParameterName = factorEdge.ParameterName;
                        bool isOutgoingMessage = factorEdge.IsOutgoingMessage;
                        if (!arguments.ContainsKey(factorParameterName))
                        {
                            if (direction == MessageDirection.Forwards)
                            {
                                Error(StringUtil.MethodFullNameToString(fcninfo.Method) + " is not suitable for initialization since it requires '" + parameter.Name + "'. Please provide a separate Init method.");
                            }
                            fcninfo = null;
                            break;
                        }
                        IExpression arg = arguments[factorParameterName];
                        args.Add(arg);
                    }
                }
                if (fcninfo != null)
                {
                    IMethodInvokeExpression imie = Builder.StaticMethod(fcninfo.Method, args.ToArray());
                    //IExpression initExpr = MessageTransform.GetDistributionArrayCreateExpression(ivd.VariableType.DotNetType, target.GetExpressionType(), imie, vi);
                    IExpression initExpr = imie;
                    // keyed by (message variable, factor)
                    KeyValuePair<IVariableDeclaration, IExpression> key = new KeyValuePair<IVariableDeclaration, IExpression>(msgVar, factor);
                    messageInitExprs[key] = initExpr;
                }
            }
            if (fcninfo != null)
            {
                messageVars[ivd] = msgVar;
            }
        }
        if (fcninfo == null)
        {
            if (direction == MessageDirection.Forwards)
            {
                continue;
            }
            //Console.WriteLine("creating backward message for "+target);
            CreateBackwardMessageFromForward(target, factor);
        }
        // Make the newly typed message available as an input for later parameters of this factor.
        IExpression msgExpr = GetMessageExpression(target, messageVars);
        arguments[parameterName] = msgExpr;
        Type inwardType = msgExpr.GetExpressionType();
        argumentTypes[parameterName] = inwardType;
    }
}
#pragma warning restore 162
#endif
/// <summary>
/// Builds a dependency graph over <c>factorExprs</c>. There is one node per factor, plus an edge from
/// every factor that mutates (returns into / writes an out argument of) a stochastic variable to every
/// factor that reads that variable.
/// It then runs <c>ProcessFactor</c> forwards in depth-first finish order along predecessor links, so that
/// (absent cycles) each factor follows the factors that mutate its inputs. Finally it runs
/// <c>ProcessFactor</c> backwards in the reverse of that order.
/// </summary>
private void PostProcess()
{
    // create a dependency graph between MethodInvokes
    Dictionary<IVariableDeclaration, List<NodeIndex>> mutationsOfVariable = new Dictionary<IVariableDeclaration, List<NodeIndex>>();
    IndexedGraph g = new IndexedGraph(factorExprs.Count);
    foreach (NodeIndex node in g.Nodes)
    {
        IExpression factor = factorExprs[node];
        NodeInfo info = GetNodeInfo(factor);
        for (int i = 0; i < info.arguments.Count; i++)
        {
            if (!info.isReturnOrOut[i])
            {
                continue;
            }
            // this is a mutation. add to mutationsOfVariable.
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(info.arguments[i]);
            if (ivd == null || !CodeRecognizer.IsStochastic(context, ivd))
            {
                continue;
            }
            if (!mutationsOfVariable.TryGetValue(ivd, out List<NodeIndex> nodes))
            {
                nodes = new List<NodeIndex>();
                mutationsOfVariable[ivd] = nodes;
            }
            nodes.Add(node);
        }
    }
    foreach (NodeIndex node in g.Nodes)
    {
        IExpression factor = factorExprs[node];
        NodeInfo info = GetNodeInfo(factor);
        for (int i = 0; i < info.arguments.Count; i++)
        {
            if (info.isReturnOrOut[i])
            {
                continue;
            }
            // not a mutation. create a dependency on all mutations.
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(info.arguments[i]);
            if (ivd == null || !CodeRecognizer.IsStochastic(context, ivd))
            {
                continue;
            }
            // A stochastic variable that no factor mutates has no sources to depend on;
            // skip it rather than throwing KeyNotFoundException.
            if (mutationsOfVariable.TryGetValue(ivd, out List<NodeIndex> sources))
            {
                foreach (NodeIndex source in sources)
                {
                    g.AddEdge(source, node);
                }
            }
        }
    }
    List<NodeIndex> topo_nodes = new List<NodeIndex>();
    DepthFirstSearch<NodeIndex> dfs = new DepthFirstSearch<NodeIndex>(g.SourcesOf, g);
    dfs.FinishNode += delegate (NodeIndex node)
    {
        IExpression factor = factorExprs[node];
        ProcessFactor(factor, MessageDirection.Forwards);
        topo_nodes.Add(node);
    };
    // process nodes forward
    dfs.SearchFrom(g.Nodes);
    // process nodes backward
    for (int i = topo_nodes.Count - 1; i >= 0; i--)
    {
        NodeIndex node = topo_nodes[i];
        IExpression factor = factorExprs[node];
        ProcessFactor(factor, MessageDirection.Backwards);
    }
}
/// <summary>
/// Replicates a reference to a stochastic variable across the repeat blocks that enclose the reference
/// but not the variable's declaration (as given by <c>lc.GetReferenceRepeatContext</c>).
/// For each such repeat it declares a derived variable, named via <c>VariableInformation.GenerateName</c>
/// from the variable name plus "_rpt", and assigns it <c>PowerPlate.Enter(expr, repeatCount)</c>.
/// Those statements are inserted before the matching ancestor, and <paramref name="expr"/> is rebound to
/// the new variable. The reference to the last variable created is returned.
/// <paramref name="expr"/> is returned unchanged when it has no variable declaration, is not stochastic, is a
/// cut variable, has no repeat context (an error is reported), or needs no replication.
/// </summary>
/// <param name="expr">An expression referencing a (possibly indexed) variable.</param>
/// <returns>A reference to the replicated variable, or <paramref name="expr"/> itself.</returns>
protected IExpression ConvertWithReplication(IExpression expr)
{
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
    // Check if this is an index local variable
    if (baseVar == null)
    {
        return(expr);
    }
    // Check if the variable is stochastic
    if (!CodeRecognizer.IsStochastic(context, baseVar))
    {
        return(expr);
    }
    // Variables in cutVariables are never replicated.
    if (cutVariables.Contains(baseVar))
    {
        return(expr);
    }
    // Get the repeat context for this variable
    RepeatContext lc = context.InputAttributes.Get<RepeatContext>(baseVar);
    if (lc == null)
    {
        Error("Repeat context not found for '" + baseVar.Name + "'.");
        return(expr);
    }
    // Get the reference loop context for this expression
    var rlc = lc.GetReferenceRepeatContext(context);
    // If the reference is in the same loop context as the declaration, do nothing.
    if (rlc.repeats.Count == 0)
    {
        return(expr);
    }
    // Determine if this expression is being defined (is on the LHS of an assignment)
    bool isDef = Recognizer.IsBeingMutated(context, expr);
    // NOTE(review): this initial value is never read; 'containers' is reassigned inside the loop before any use.
    Containers containers = context.InputAttributes.Get<Containers>(baseVar);
    // NOTE(review): the guard above tests rlc.repeats.Count but the loop runs over rlc.repeatCounts.Count;
    // presumably both lists have the same length — confirm.
    for (int currentRepeat = 0; currentRepeat < rlc.repeatCounts.Count; currentRepeat++)
    {
        IExpression repeatCount = rlc.repeatCounts[currentRepeat];
        IRepeatStatement repeat = rlc.repeats[currentRepeat];
        // must replicate across this loop.
        if (isDef)
        {
            // Because the loop continues, this error is reported once per repeat.
            Error("Cannot re-define a variable in a repeat block.");
            continue;
        }
        // are we replicating the argument of Gate.Cases?
        IMethodInvokeExpression imie = context.FindAncestor<IMethodInvokeExpression>();
        if (imie != null)
        {
            if (Recognizer.IsStaticMethod(imie, typeof(Gate), "Cases"))
            {
                Error("'if(" + expr + ")' should be placed outside 'repeat(" + repeatCount + ")'");
            }
            if (Recognizer.IsStaticMethod(imie, typeof(Gate), "CasesInt"))
            {
                Error("'case(" + expr + ")' or 'switch(" + expr + ")' should be placed outside 'repeat(" + repeatCount + ")'");
            }
        }
        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
        IList<IStatement> stmts = Builder.StmtCollection();
        List<IList<IExpression>> inds = Recognizer.GetIndices(expr);
        // stmts is passed in so DeriveIndexedVariable can presumably add the declaration statements — confirm.
        IVariableDeclaration repVar = varInfo.DeriveIndexedVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rpt"), inds);
        if (!context.InputAttributes.Has<DerivedVariable>(repVar))
        {
            context.OutputAttributes.Set(repVar, new DerivedVariable());
        }
        if (context.InputAttributes.Has<ChannelInfo>(baseVar))
        {
            VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
            ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
            ci.decl = repVar;
            context.OutputAttributes.Set(repVar, ci);
        }
        // set the RepeatContext of repVar to include all repeats up to this one (so that it doesn't get Entered again)
        List<IRepeatStatement> repeats = new List<IRepeatStatement>(lc.repeats);
        for (int i = 0; i <= currentRepeat; i++)
        {
            repeats.Add(rlc.repeats[i]);
        }
        context.OutputAttributes.Remove<RepeatContext>(repVar);
        context.OutputAttributes.Set(repVar, new RepeatContext(repeats));
        // Create replicate factor
        Type returnType = Builder.ToType(repVar.VariableType);
        IMethodInvokeExpression powerPlateMethod = Builder.StaticGenericMethod(
            new Func<PlaceHolder, double, PlaceHolder>(PowerPlate.Enter), new Type[] { returnType }, expr, repeatCount);
        IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), powerPlateMethod);
        // Copy attributes across from variable to replication expression
        context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, powerPlateMethod);
        context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, powerPlateMethod);
        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, powerPlateMethod);
        stmts.Add(Builder.ExprStatement(assignExpression));
        // add any containers missing from context.
        containers = new Containers(context);
        // remove inner repeats
        for (int i = currentRepeat + 1; i < rlc.repeatCounts.Count; i++)
        {
            containers = containers.RemoveOneRepeat(rlc.repeats[i]);
        }
        context.OutputAttributes.Set(repVar, containers);
        // The replication statements are positioned using containers without the current repeat.
        containers = containers.RemoveOneRepeat(repeat);
        containers = Containers.RemoveUnusedLoops(containers, context, powerPlateMethod);
        if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar))
        {
            containers = Containers.RemoveStochasticConditionals(containers, context);
        }
        //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
        //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
        int ancIndex = containers.GetMatchingAncestorIndex(context);
        Containers missing = containers.GetContainersNotInContext(context, ancIndex);
        stmts = Containers.WrapWithContainers(stmts, missing.inputs);
        // must convert the output since it may contain 'if' conditions
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
        // NOTE(review): from the next iteration on, baseVar is this replica. The ChannelInfo, Algorithm,
        // DivideMessages, GivePriorityTo and DoNotSendEvidence lookups above then read repVar's attributes,
        // not the original variable's — confirm this is intended.
        baseVar = repVar;
        expr = Builder.VarRefExpr(repVar);
    }
    return(expr);
}
// copied from ReplicationTransform
/// <summary>
/// Determines whether <paramref name="expr"/> refers to a stochastic variable.
/// </summary>
/// <param name="expr">The expression to inspect.</param>
/// <returns>
/// <c>true</c> when a variable declaration can be recovered from <paramref name="expr"/> and
/// <c>CodeRecognizer.IsStochastic</c> reports it as stochastic; otherwise <c>false</c>.
/// </returns>
protected bool IsStochasticVariableReference(IExpression expr) =>
    Recognizer.GetVariableDeclaration(expr) is IVariableDeclaration declaration
    && CodeRecognizer.IsStochastic(context, declaration);
/// <summary>
/// Records copy relationships for assignments whose right-hand side is a copy-like call.
/// These are <c>Clone.Copy</c>, <c>ArrayHelper.SetTo</c>, <c>ArrayHelper.SetAllElementsTo</c> and the
/// <c>Get*PointOp&lt;&gt;.ItemsAverageConditional</c> operators.
/// Each relationship is stored in the target variable's <see cref="CopyOfAttribute"/>, together with the
/// enclosing deterministic if statements.
/// Assignments inside an <see cref="Initializer"/> statement are ignored. No copy is recorded when the target
/// has its own initialiser (forward, or non-array-create backward) that differs from the source's.
/// </summary>
/// <param name="iae">The assignment being visited.</param>
/// <returns>The assignment, unchanged.</returns>
protected override IExpression ConvertAssign(IAssignExpression iae)
{
    foreach (IStatement stmt in context.FindAncestors<IStatement>())
    {
        // an initializer statement may perform a copy, but it is not valid to replace the lhs
        // in that case.
        if (context.InputAttributes.Has<Initializer>(stmt))
        {
            return iae;
        }
    }
    // Look for assignments where the right hand side is a SetTo call
    if (iae.Expression is IMethodInvokeExpression imie)
    {
        bool isCopy = Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Clone.Copy));
        bool isSetTo = Recognizer.IsStaticGenericMethod(imie, typeof(ArrayHelper), "SetTo");
        bool isSetAllElementsTo = Recognizer.IsStaticGenericMethod(imie, typeof(ArrayHelper), "SetAllElementsTo");
        bool isGetItemsPoint = Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsPointOp<>), "ItemsAverageConditional");
        bool isGetJaggedItemsPoint = Recognizer.IsStaticGenericMethod(imie, typeof(GetJaggedItemsPointOp<>), "ItemsAverageConditional");
        bool isGetDeepJaggedItemsPoint = Recognizer.IsStaticGenericMethod(imie, typeof(GetDeepJaggedItemsPointOp<>), "ItemsAverageConditional");
        bool isGetItemsFromJaggedPoint = Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsFromJaggedPointOp<>), "ItemsAverageConditional");
        bool isGetItemsFromDeepJaggedPoint = Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsFromDeepJaggedPointOp<>), "ItemsAverageConditional");
        bool isGetJaggedItemsFromJaggedPoint = Recognizer.IsStaticGenericMethod(imie, typeof(GetJaggedItemsFromJaggedPointOp<>), "ItemsAverageConditional");
        bool isAnyGetItemsPoint = isGetItemsPoint || isGetJaggedItemsPoint || isGetDeepJaggedItemsPoint
            || isGetItemsFromJaggedPoint || isGetItemsFromDeepJaggedPoint || isGetJaggedItemsFromJaggedPoint;
        if (isCopy || isSetTo || isSetAllElementsTo || isAnyGetItemsPoint)
        {
            IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(iae.Target);
            // Find the condition context: only deterministic conditions are recorded.
            var ifs = context.FindAncestors<IConditionStatement>();
            var condContext = new List<IConditionStatement>();
            foreach (var ifSt in ifs)
            {
                if (!CodeRecognizer.IsStochastic(context, ifSt.Condition))
                {
                    condContext.Add(ifSt);
                }
            }
            var copyAttr = context.InputAttributes.GetOrCreate<CopyOfAttribute>(ivd, () => new CopyOfAttribute());
            IExpression rhs;
            if (isSetTo || isSetAllElementsTo)
            {
                // Mark as copy of the second argument
                rhs = imie.Arguments[1];
            }
            else
            {
                // Mark as copy of the first argument
                rhs = imie.Arguments[0];
            }
            InitialiseTo init = context.InputAttributes.Get<InitialiseTo>(ivd);
            if (init != null)
            {
                IVariableDeclaration ivdRhs = Recognizer.GetVariableDeclaration(rhs);
                InitialiseTo initRhs = (ivdRhs == null) ? null : context.InputAttributes.Get<InitialiseTo>(ivdRhs);
                if (initRhs == null || !initRhs.initialMessagesExpression.Equals(init.initialMessagesExpression))
                {
                    // Do not replace a variable with a unique initialiser
                    return iae;
                }
            }
            var initBack = context.InputAttributes.Get<InitialiseBackwardTo>(ivd);
            if (initBack != null && !(initBack.initialMessagesExpression is IArrayCreateExpression))
            {
                IVariableDeclaration ivdRhs = Recognizer.GetVariableDeclaration(rhs);
                InitialiseBackwardTo initRhs = (ivdRhs == null) ? null : context.InputAttributes.Get<InitialiseBackwardTo>(ivdRhs);
                // Compare against the target's backward initialiser; the forward one ('init') is unrelated
                // here and may be null.
                if (initRhs == null || !initRhs.initialMessagesExpression.Equals(initBack.initialMessagesExpression))
                {
                    // Do not replace a variable with a unique initialiser
                    return iae;
                }
            }
            if (isCopy || isSetTo)
            {
                RemoveMatchingSuffixes(iae.Target, rhs, condContext, out IExpression lhsPrefix, out IExpression rhsPrefix);
                copyAttr.copyMap[lhsPrefix] = new CopyOfAttribute.CopyContext { Expression = rhsPrefix, ConditionContext = condContext };
            }
            else if (isSetAllElementsTo)
            {
                copyAttr.copiedInEveryElementMap[iae.Target] = new CopyOfAttribute.CopyContext { Expression = rhs, ConditionContext = condContext };
            }
            else if (isAnyGetItemsPoint)
            {
                var target = ((IArrayIndexerExpression)iae.Target).Target;
                // Arguments 1..inputDepth are the index arrays (assumes three non-index arguments
                // in total: the source array first and two trailing ones — confirm against the operators).
                int inputDepth = imie.Arguments.Count - 3;
                List<IExpression> indexExprs = new List<IExpression>();
                for (int i = 0; i < inputDepth; i++)
                {
                    indexExprs.Add(imie.Arguments[1 + i]);
                }
                int outputDepth;
                if (isGetDeepJaggedItemsPoint)
                {
                    outputDepth = 3;
                }
                else if (isGetJaggedItemsPoint || isGetJaggedItemsFromJaggedPoint)
                {
                    outputDepth = 2;
                }
                else
                {
                    outputDepth = 1;
                }
                copyAttr.copyAtIndexMap[target] = new CopyOfAttribute.CopyContext2
                {
                    Depth = outputDepth,
                    ConditionContext = condContext,
                    ExpressionAtIndex = (lhsIndices) =>
                    {
                        return Builder.JaggedArrayIndex(rhs, indexExprs.ListSelect(indexExpr => new[] { Builder.JaggedArrayIndex(indexExpr, lhsIndices) }));
                    }
                };
            }
            else
            {
                throw new NotImplementedException();
            }
        }
    }
    return iae;
}
/// <summary>
/// Records an <c>IndexInfo</c> in <c>indexInfoOf</c> for each distinct indexed reference to a stochastic
/// variable whose indices are not all loop counters. The entry tracks:
/// <list type="bullet">
/// <item>the intersection of the containers of all occurrences;</item>
/// <item>the condition bindings of each occurrence (cleared once any occurrence has none);</item>
/// <item>the occurrence count;</item>
/// <item>whether the expression is ever assigned to.</item>
/// </list>
/// The lhs of an assignment whose rhs is an array creation is skipped.
/// </summary>
/// <param name="iaie">The array indexer expression being visited.</param>
/// <returns>The expression, unchanged.</returns>
protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
{
    bool isDef = Recognizer.IsBeingMutated(context, iaie);
    if (isDef)
    {
        // do not clone the lhs of an array create assignment.
        // FindAncestor returns null when there is no enclosing assignment (presumably when the
        // mutation happens some other way, e.g. through an out argument), so guard the dereference.
        IAssignExpression assignExpr = context.FindAncestor<IAssignExpression>();
        if (assignExpr?.Expression is IArrayCreateExpression)
        {
            return iaie;
        }
    }
    base.ConvertArrayIndexer(iaie);
    IndexInfo info;
    // TODO: Instead of storing an IndexInfo for each distinct expression, we should try to unify expressions, as in GateAnalysisTransform.
    // For example, we could unify a[0,i] and a[0,0] and use the same clone array for both.
    if (indexInfoOf.TryGetValue(iaie, out info))
    {
        Containers containers = new Containers(context);
        // An empty binding list means some occurrence is unconditional, so bindings stay cleared.
        if (info.bindings.Count > 0)
        {
            List<ConditionBinding> bindings = GetBindings(context, containers.inputs);
            if (bindings.Count == 0)
            {
                info.bindings.Clear();
            }
            else
            {
                info.bindings.Add(bindings);
            }
        }
        info.containers = Containers.Intersect(info.containers, containers);
        info.count++;
        if (isDef)
        {
            info.IsAssignedTo = true;
        }
        return iaie;
    }
    CheckIndicesAreNotStochastic(iaie.Indices);
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(iaie);
    // If not an indexed variable reference, skip it (e.g. an indexed argument reference)
    if (baseVar == null)
    {
        return iaie;
    }
    // If the variable is not stochastic, skip it
    if (!CodeRecognizer.IsStochastic(context, baseVar))
    {
        return iaie;
    }
    // If the indices are all loop variables, skip it
    var indices = Recognizer.GetIndices(iaie);
    bool allLoopIndices = indices.All(bracket => bracket.All(indexExpr =>
        indexExpr is IVariableReferenceExpression ivre && Recognizer.GetLoopForVariable(context, ivre) != null));
    if (allLoopIndices)
    {
        return iaie;
    }
    info = new IndexInfo();
    info.containers = new Containers(context);
    List<ConditionBinding> bindings2 = GetBindings(context, info.containers.inputs);
    if (bindings2.Count > 0)
    {
        info.bindings.Add(bindings2);
    }
    info.count = 1;
    info.IsAssignedTo = isDef;
    indexInfoOf[iaie] = info;
    return iaie;
}