/// <summary>
/// Unrolls a short loop whose size and starting value are both integer literals.
/// </summary>
/// <remarks>
/// When the loop size expression is a literal below 20 and the loop variable is initialised from a
/// literal, the body is converted once per iteration while <c>iterationContext</c> binds the loop
/// variable to that iteration's literal value. Each converted copy is inserted before the current
/// statement and the loop itself is dropped. Every other loop goes through the base conversion.
/// </remarks>
/// <param name="ifs">The loop to convert.</param>
/// <returns><c>null</c> when the loop was unrolled; otherwise the result of the base conversion.</returns>
protected override IStatement ConvertFor(IForStatement ifs)
{
    if (!(Recognizer.LoopSizeExpression(ifs) is ILiteralExpression sizeLiteral))
    {
        return base.ConvertFor(ifs);
    }
    int size = (int)sizeLiteral.Value;
    if (size >= 20)
    {
        // Too many iterations to unroll.
        return base.ConvertFor(ifs);
    }
    ConditionBinding initializer = GetInitializerBinding(ifs);
    if (!(initializer.rhs is ILiteralExpression startLiteral))
    {
        return base.ConvertFor(ifs);
    }
    int first = (int)startLiteral.Value;
    for (int value = first; value < size; value++)
    {
        // Bind the loop variable to this iteration's value only while the body is converted.
        iterationContext.Add(new ConditionBinding(initializer.lhs, Builder.LiteralExpr(value)));
        IBlockStatement unrolledBody = ConvertBlock(ifs.Body);
        context.AddStatementsBeforeCurrent(unrolledBody.Statements);
        iterationContext.RemoveAt(iterationContext.Count - 1);
    }
    return null;
}
/// <summary>
/// Converts an if statement, folding it away when its converted condition is a literal.
/// </summary>
/// <remarks>
/// For a literal condition, the statements of the selected branch (Then when the literal is true,
/// Else otherwise) are converted one by one, non-null results are inserted before the current
/// statement, and <c>null</c> is returned. Otherwise the Then block is converted with the
/// condition's binding appended to <c>conditionContext</c>, and the Else block (if present) with the
/// flipped binding. <c>conditionContext</c> is restored afterwards. <c>null</c> is returned when
/// both converted branches are empty.
/// </remarks>
/// <param name="ics">The condition statement to convert.</param>
/// <returns>The converted statement, or <c>null</c> if it was folded or ended up empty.</returns>
protected override IStatement ConvertCondition(IConditionStatement ics)
{
    IConditionStatement converted = Builder.CondStmt();
    converted.Condition = ConvertExpression(ics.Condition);
    if (converted.Condition is ILiteralExpression literalCondition)
    {
        IBlockStatement taken = (bool)literalCondition.Value ? ics.Then : ics.Else;
        if (taken != null)
        {
            foreach (IStatement original in taken.Statements)
            {
                IStatement result = ConvertStatement(original);
                if (result != null)
                {
                    context.AddStatementBeforeCurrent(result);
                }
            }
        }
        return null;
    }

    context.SetPrimaryOutput(converted);
    // The loop reported by GetConditionBinding is not needed here.
    ConditionBinding thenBinding = GateTransform.GetConditionBinding(converted.Condition, context, out IForStatement _);
    int mark = conditionContext.Count;
    conditionContext.Add(thenBinding);
    converted.Then = ConvertBlock(ics.Then);
    if (ics.Else != null)
    {
        conditionContext.RemoveRange(mark, conditionContext.Count - mark);
        conditionContext.Add(thenBinding.FlipCondition());
        converted.Else = ConvertBlock(ics.Else);
    }
    conditionContext.RemoveRange(mark, conditionContext.Count - mark);

    bool nothingLeft = converted.Then.Statements.Count == 0
        && (converted.Else == null || converted.Else.Statements.Count == 0);
    return nothingLeft ? null : converted;
}
/// <summary>
/// Two bindings are equal when <paramref name="obj"/> is a <c>ConditionBinding</c> whose
/// left-hand side and right-hand side both equal this binding's.
/// </summary>
/// <param name="obj">The object to compare against.</param>
/// <returns><c>true</c> if both sides match; otherwise <c>false</c>.</returns>
public override bool Equals(object obj) =>
    obj is ConditionBinding other && lhs.Equals(other.lhs) && rhs.Equals(other.rhs);
/// <summary>
/// Builds a <c>BindingSet</c> containing one dictionary per input binding collection.
/// </summary>
/// <remarks>
/// Each collection is converted with <c>ConditionBinding.ToDictionary(collection, true)</c>, and
/// empty dictionaries are skipped.
/// </remarks>
/// <param name="bindingSet">The binding collections to convert.</param>
/// <returns>The resulting binding set.</returns>
internal static BindingSet FromBindings(IEnumerable<ICollection<ConditionBinding>> bindingSet)
{
    var combined = new BindingSet();
    foreach (var collection in bindingSet)
    {
        // Only boolean bindings are returned since we know there must be two cases;
        // integer bindings could have only one case.
        var lookup = ConditionBinding.ToDictionary(collection, true);
        if (lookup.Count == 0)
        {
            continue;
        }
        combined.set.Add(lookup);
    }
    return combined;
}
/// <summary>
/// Creates a <c>ConditionBinding</c> for every non-stochastic condition statement in
/// <paramref name="containers"/>.
/// </summary>
/// <param name="context">Transform context used to test whether a condition is stochastic.</param>
/// <param name="containers">Container statements to scan; statements that are not conditions are ignored.</param>
/// <returns>The bindings, in the order their condition statements were encountered.</returns>
internal static List<ConditionBinding> GetBindings(BasicTransformContext context, IEnumerable<IStatement> containers)
{
    var result = new List<ConditionBinding>();
    foreach (IStatement container in containers)
    {
        if (container is IConditionStatement conditional && !CodeRecognizer.IsStochastic(context, conditional.Condition))
        {
            result.Add(new ConditionBinding(conditional.Condition));
        }
    }
    return result;
}
/// <summary>
/// Records the condition bindings under which <paramref name="ivd"/> is being defined.
/// </summary>
/// <remarks>
/// The recorded collection is a copy of <c>conditionContext</c> with the bindings stored in
/// <c>declarationBindings[ivd]</c> removed. It is added only when something remains. When the
/// variable already has an entry in <c>definitionBindings</c> and that entry is empty, nothing
/// further is recorded.
/// NOTE(review): a definition that contributes no bindings (e.g. made outside any condition) does
/// not clear an already non-empty entry — confirm this is intended.
/// NOTE(review): assumes <paramref name="ivd"/> already has an entry in <c>declarationBindings</c>.
/// </remarks>
/// <param name="ivd">The variable being defined.</param>
protected void RegisterDefinition(IVariableDeclaration ivd)
{
    if (definitionBindings.TryGetValue(ivd, out Set<ICollection<ConditionBinding>> recorded))
    {
        // An existing empty entry is never extended.
        if (recorded.Count == 0)
        {
            return;
        }
    }
    else
    {
        recorded = new Set<ICollection<ConditionBinding>>();
        definitionBindings[ivd] = recorded;
    }

    if (conditionContext.Count == 0)
    {
        return;
    }
    Set<ConditionBinding> extra = ConditionBinding.Copy(conditionContext);
    extra.Remove(declarationBindings[ivd]);
    if (extra.Count > 0)
    {
        recorded.Add(extra);
    }
}
/// <summary>
/// Assignments to non-stochastic non-loop integer variables are added to the conditionContext
/// </summary>
/// <remarks>
/// The right-hand side is converted before the target. When the target's variable is
/// non-stochastic, has type <c>int</c> and is not a loop variable, a binding from the target to
/// the converted right-hand side is appended to <c>conditionContext</c>.
/// </remarks>
/// <param name="iae">The assignment being analysed.</param>
/// <returns>The original assignment, unchanged.</returns>
protected override IExpression ConvertAssign(IAssignExpression iae)
{
    IExpression value = ConvertExpression(iae.Expression);
    IExpression assigned = ConvertExpression(iae.Target);
    IVariableDeclaration declaration = Recognizer.GetVariableDeclaration(assigned);
    if (declaration == null)
    {
        return iae;
    }

    var info = VariableInformation.GetVariableInformation(context, declaration);
    bool isDeterministicIntegerLocal = !info.IsStochastic
        && info.varType.Equals(typeof(int))
        && Recognizer.GetLoopForVariable(context, declaration) == null;
    if (!isDeterministicIntegerLocal)
    {
        return iae;
    }

    // A declaration-with-assignment is recorded as a plain reference to the declared variable.
    if (assigned is IVariableDeclarationExpression declarationExpr)
    {
        assigned = Builder.VarRefExpr(declarationExpr.Variable);
    }
    // The binding is intentionally not removed when the lexical scope ends,
    // because locals aren't correctly scoped yet.
    conditionContext.Add(new ConditionBinding(assigned, value));
    return iae;
}
/// <summary>
/// Analyse the condition body using an augmented conditionContext
/// </summary>
/// <remarks>
/// Stochastic conditions are handed to the base implementation. For any other condition, the Then
/// block is analysed with the condition's binding appended to <c>conditionContext</c>, and the Else
/// block (if any) with the flipped binding. <c>conditionContext</c> is restored to its original
/// length before returning.
/// </remarks>
/// <param name="ics">The condition statement to analyse.</param>
/// <returns>The original statement.</returns>
protected override IStatement ConvertCondition(IConditionStatement ics)
{
    if (CodeRecognizer.IsStochastic(context, ics.Condition))
    {
        return base.ConvertCondition(ics);
    }

    context.SetPrimaryOutput(ics);
    ConvertExpression(ics.Condition);
    var thenBinding = new ConditionBinding(ics.Condition);
    int mark = conditionContext.Count;
    conditionContext.Add(thenBinding);
    ConvertBlock(ics.Then);
    if (ics.Else != null)
    {
        // Swap the Then binding for its negation while the Else block is analysed.
        conditionContext.RemoveRange(mark, conditionContext.Count - mark);
        conditionContext.Add(thenBinding.FlipCondition());
        ConvertBlock(ics.Else);
    }
    conditionContext.RemoveRange(mark, conditionContext.Count - mark);
    return ics;
}
/// <summary>
/// Returns <paramref name="expr"/> rewritten so that it is replicated across every loop of the
/// reference loop context that the expression is not already constant with respect to.
/// </summary>
/// <remarks>
/// The expression is returned unchanged when it has no base variable declaration, when the base
/// variable is not stochastic, when the base variable has no <c>LoopContext</c> attribute (an error
/// is reported), or when the reference loop context contains no loops.
/// A loop variable counts as constant if it is one of the base variable's own loop variables, if it
/// directly indexes the expression, or if an enclosing condition statement equates it to a variable
/// that <c>IsConstantWrtLoops</c> accepts.
/// For every other loop, a new "_rep" array variable is declared and assigned from
/// <c>Clone.Replicate</c>, the declaration is inserted before the matching ancestor, and the
/// expression is rewritten to index that array by the loop variable. Any unreplicated extra indices
/// are then re-applied. If the expression is being mutated, an error is reported instead of
/// replicating.
/// </remarks>
/// <param name="expr">The expression to (possibly) replicate.</param>
/// <returns>The rewritten expression, or <paramref name="expr"/> itself when no replication is needed.</returns>
protected IExpression ConvertWithReplication(IExpression expr)
{
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(expr);
    // Check if this is an index local variable
    if (baseVar == null) return expr;
    // Check if the variable is stochastic
    if (!CodeRecognizer.IsStochastic(context, baseVar)) return expr;

    // Get the loop context for this variable
    LoopContext lc = context.InputAttributes.Get<LoopContext>(baseVar);
    if (lc == null)
    {
        Error("Loop context not found for '" + baseVar.Name + "'.");
        return expr;
    }

    // Get the reference loop context for this expression
    RefLoopContext rlc = lc.GetReferenceLoopContext(context);
    // If the reference is in the same loop context as the declaration, do nothing.
    if (rlc.loops.Count == 0) return expr;

    // the set of loop variables that are constant wrt the expr
    Set<IVariableDeclaration> constantLoopVars = new Set<IVariableDeclaration>();
    constantLoopVars.AddRange(lc.loopVariables);

    // collect set of all loop variable indices in the expression
    // (positions in rlc.loopVariables of loops reached only indirectly, through another index variable)
    Set<int> embeddedLoopIndices = new Set<int>();
    List<IList<IExpression>> brackets = Recognizer.GetIndices(expr);
    foreach (IList<IExpression> bracket in brackets)
    {
        foreach (IExpression index in bracket)
        {
            IExpression indExpr = index;
            // For a binary index expression only the left operand is inspected.
            if (indExpr is IBinaryExpression ibe)
            {
                indExpr = ibe.Left;
            }
            IVariableDeclaration indVar = Recognizer.GetVariableDeclaration(indExpr);
            if (indVar != null)
            {
                if (!constantLoopVars.Contains(indVar))
                {
                    int loopIndex = rlc.loopVariables.IndexOf(indVar);
                    if (loopIndex != -1)
                    {
                        // indVar is a loop variable
                        constantLoopVars.Add(rlc.loopVariables[loopIndex]);
                    }
                    else
                    {
                        // indVar is not a loop variable
                        // NOTE(review): unlike `lc` above, `lc2` is not null-checked; an index variable
                        // without a LoopContext attribute would throw here — confirm one is always present.
                        LoopContext lc2 = context.InputAttributes.Get<LoopContext>(indVar);
                        foreach (var ivd in lc2.loopVariables)
                        {
                            if (!constantLoopVars.Contains(ivd))
                            {
                                int loopIndex2 = rlc.loopVariables.IndexOf(ivd);
                                if (loopIndex2 != -1) embeddedLoopIndices.Add(loopIndex2);
                                else Error($"Index {ivd} is not in {rlc} for expression {expr}");
                            }
                        }
                    }
                }
            }
            else
            {
                // The index is not a single variable: examine every variable it mentions.
                foreach (var ivd in Recognizer.GetVariables(indExpr))
                {
                    if (!constantLoopVars.Contains(ivd))
                    {
                        // copied from above
                        LoopContext lc2 = context.InputAttributes.Get<LoopContext>(ivd);
                        foreach (var ivd2 in lc2.loopVariables)
                        {
                            if (!constantLoopVars.Contains(ivd2))
                            {
                                int loopIndex2 = rlc.loopVariables.IndexOf(ivd2);
                                if (loopIndex2 != -1) embeddedLoopIndices.Add(loopIndex2);
                                else Error($"Index {ivd2} is not in {rlc} for expression {expr}");
                            }
                        }
                    }
                }
            }
        }
    }

    // Find loop variables that must be constant due to condition statements.
    List<IStatement> ancestors = context.FindAncestors<IStatement>();
    foreach (IStatement ancestor in ancestors)
    {
        if (!(ancestor is IConditionStatement ics)) continue;
        ConditionBinding binding = new ConditionBinding(ics.Condition);
        IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(binding.lhs);
        IVariableDeclaration ivd2 = Recognizer.GetVariableDeclaration(binding.rhs);
        // A loop variable equated to something constant wrt the known-constant loops is itself constant.
        int index = rlc.loopVariables.IndexOf(ivd);
        if (index >= 0 && IsConstantWrtLoops(ivd2, constantLoopVars))
        {
            constantLoopVars.Add(ivd);
            continue;
        }
        int index2 = rlc.loopVariables.IndexOf(ivd2);
        if (index2 >= 0 && IsConstantWrtLoops(ivd, constantLoopVars))
        {
            constantLoopVars.Add(ivd2);
            continue;
        }
    }

    // Determine if this expression is being defined (is on the LHS of an assignment)
    bool isDef = Recognizer.IsBeingMutated(context, expr);
    // NOTE(review): this initial value appears never to be read; `containers` is reassigned inside
    // the loop before any use.
    Containers containers = context.InputAttributes.Get<Containers>(baseVar);
    IExpression originalExpr = expr;
    for (int currentLoop = 0; currentLoop < rlc.loopVariables.Count; currentLoop++)
    {
        IVariableDeclaration loopVar = rlc.loopVariables[currentLoop];
        if (constantLoopVars.Contains(loopVar)) continue;
        IForStatement loop = rlc.loops[currentLoop];
        // must replicate across this loop.
        if (isDef)
        {
            Error("Cannot re-define a variable in a loop. Variables on the left hand side of an assignment must be indexed by all containing loops.");
            continue;
        }
        if (embeddedLoopIndices.Contains(currentLoop))
        {
            string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
            Warning(string.Format(warningText, originalExpr, loopVar.Name));
        }
        // split expr into a target and extra indices, where target will be replicated and extra indices will be added later
        var extraIndices = new List<IEnumerable<IExpression>>();
        AddUnreplicatedIndices(rlc.loops[currentLoop], expr, extraIndices, out IExpression exprToReplicate);

        // baseVar is reassigned to the previous "_rep" variable at the end of each iteration,
        // so later iterations replicate the already-replicated array.
        VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
        IExpression loopSize = Recognizer.LoopSizeExpression(loop);
        IList<IStatement> stmts = Builder.StmtCollection();
        List<IList<IExpression>> inds = Recognizer.GetIndices(exprToReplicate);
        IVariableDeclaration newIndexVar = loopVar;
        // if loopVar is already an indexVar of varInfo, create a new variable
        if (varInfo.HasIndexVar(loopVar))
        {
            newIndexVar = VariableInformation.GenerateLoopVar(context, "_a");
            context.InputAttributes.CopyObjectAttributesTo(loopVar, context.OutputAttributes, newIndexVar);
        }
        IVariableDeclaration repVar = varInfo.DeriveArrayVariable(stmts, context, VariableInformation.GenerateName(context, varInfo.Name + "_rep"), loopSize, newIndexVar, inds, useArrays: true);
        if (!context.InputAttributes.Has<DerivedVariable>(repVar)) context.OutputAttributes.Set(repVar, new DerivedVariable());
        if (context.InputAttributes.Has<ChannelInfo>(baseVar))
        {
            VariableInformation repVarInfo = VariableInformation.GetVariableInformation(context, repVar);
            ChannelInfo ci = ChannelInfo.UseChannel(repVarInfo);
            ci.decl = repVar;
            context.OutputAttributes.Set(repVar, ci);
        }

        // Create replicate factor
        Type returnType = Builder.ToType(repVar.VariableType);
        IMethodInvokeExpression repMethod = Builder.StaticGenericMethod(
            new Func<PlaceHolder, int, PlaceHolder[]>(Clone.Replicate),
            new Type[] {returnType.GetElementType()}, exprToReplicate, loopSize);

        IExpression assignExpression = Builder.AssignExpr(Builder.VarRefExpr(repVar), repMethod);
        // Copy attributes across from variable to replication expression
        context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, repMethod);
        context.InputAttributes.CopyObjectAttributesTo<DivideMessages>(baseVar, context.OutputAttributes, repMethod);
        context.InputAttributes.CopyObjectAttributesTo<GivePriorityTo>(baseVar, context.OutputAttributes, repMethod);
        stmts.Add(Builder.ExprStatement(assignExpression));

        // add any containers missing from context.
        containers = new Containers(context);
        // RemoveUnusedLoops will also remove conditionals involving those loop variables.
        // TODO: investigate whether removing these conditionals could cause a problem, e.g. when the condition is a conjunction of many terms.
        containers = Containers.RemoveUnusedLoops(containers, context, repMethod);
        if (context.InputAttributes.Has<DoNotSendEvidence>(baseVar)) containers = Containers.RemoveStochasticConditionals(containers, context);
        //Containers shouldBeEmpty = containers.GetContainersNotInContext(context, context.InputStack.Count);
        //if (shouldBeEmpty.inputs.Count > 0) { Error("Internal: Variable is out of scope"); return expr; }
        if (containers.Contains(loop))
        {
            // Stop replicating any further loops for this expression.
            Error("Internal: invalid containers for replicating " + baseVar);
            break;
        }
        int ancIndex = containers.GetMatchingAncestorIndex(context);
        Containers missing = containers.GetContainersNotInContext(context, ancIndex);
        stmts = Containers.WrapWithContainers(stmts, missing.inputs);
        context.OutputAttributes.Set(repVar, containers);
        // The replicated variable's loop context is the ancestor loops plus any loops among the missing containers.
        List<IForStatement> loops = context.FindAncestors<IForStatement>(ancIndex);
        foreach (IStatement container in missing.inputs)
        {
            if (container is IForStatement ifs) loops.Add(ifs);
        }
        context.OutputAttributes.Set(repVar, new LoopContext(loops));
        // must convert the output since it may contain 'if' conditions
        context.AddStatementsBeforeAncestorIndex(ancIndex, stmts, true);
        baseVar = repVar;
        expr = Builder.ArrayIndex(Builder.VarRefExpr(repVar), Builder.VarRefExpr(loopVar));
        expr = Builder.JaggedArrayIndex(expr, extraIndices);
    }

    return expr;
}
/// <summary>
/// Analyses an if statement as a gate. Its case value is recorded in the <c>GateBlock</c> selected
/// by the current condition bindings and the gate key, and then both branches are analysed.
/// </summary>
/// <remarks>
/// An error is reported, and the statement returned as-is, when the condition does not compare to a
/// literal or loop counter. Errors are also reported when literal and loop-counter case values are
/// mixed within one gate block. After the branches have been analysed, if an enclosing gate block
/// exists, every variable defined or used inside this block is passed to <c>ProcessUse</c> again. Each
/// call uses the current <c>conditionContext</c>, combined with each binding collection recorded for
/// that variable.
/// </remarks>
/// <param name="ics">The condition statement to analyse.</param>
/// <returns>The original statement.</returns>
protected override IStatement ConvertCondition(IConditionStatement ics)
{
    context.SetPrimaryOutput(ics);
    ConvertExpression(ics.Condition);
    ConditionBinding binding = GateTransform.GetConditionBinding(ics.Condition, context, out IForStatement loop);
    IExpression caseValue = binding.rhs;
    if (!GateTransform.IsLiteralOrLoopVar(context, caseValue, out loop))
    {
        Error("If statement condition must compare to a literal or loop counter, was: " + ics.Condition);
        return ics;
    }

    // Stochastic conditions share a gate block per left-hand side. Deterministic conditions are keyed
    // by the whole condition, because definitions must not be unified across deterministic gate conditions.
    IExpression gateBlockKey = CodeRecognizer.IsStochastic(context, binding.lhs)
        ? binding.lhs
        : binding.GetExpression();

    Set<ConditionBinding> bindings = ConditionBinding.Copy(conditionContext);
    if (!gateBlocks.TryGetValue(bindings, out Dictionary<IExpression, GateBlock> blockMap))
    {
        // first time seeing these bindings
        blockMap = new Dictionary<IExpression, GateBlock>();
        gateBlocks[bindings] = blockMap;
    }
    if (!blockMap.TryGetValue(gateBlockKey, out GateBlock gateBlock))
    {
        // first time seeing this lhs
        gateBlock = new GateBlock();
        blockMap[gateBlockKey] = gateBlock;
    }

    bool isLoopCase = loop != null;
    if (gateBlock.hasLoopCaseValue && !isLoopCase)
    {
        Error("Cannot compare " + binding.lhs + " to a literal, since it was previously compared to a loop counter. Put this test inside the loop.");
    }
    if (!gateBlock.hasLoopCaseValue && gateBlock.caseValues.Count > 0 && isLoopCase)
    {
        Error("Cannot compare " + binding.lhs + " to a loop counter, since it was previously compared to a literal. Put the literal case inside the loop.");
    }
    gateBlock.caseValues.Add(caseValue);
    if (isLoopCase)
    {
        gateBlock.hasLoopCaseValue = true;
    }

    gateBlockContext.Add(gateBlock);
    context.OutputAttributes.Set(ics, gateBlock);
    int mark = conditionContext.Count;
    conditionContext.Add(binding);
    ConvertBlock(ics.Then);
    if (ics.Else != null)
    {
        conditionContext.RemoveRange(mark, conditionContext.Count - mark);
        conditionContext.Add(binding.FlipCondition());
        ConvertBlock(ics.Else);
    }
    conditionContext.RemoveRange(mark, conditionContext.Count - mark);
    gateBlockContext.RemoveAt(gateBlockContext.Count - 1);

    if (gateBlockContext.Count > 0)
    {
        // All variables defined/used in the inner block must be processed by the outer block.
        foreach (ExpressionWithBindings defined in gateBlock.variablesDefined.Values)
        {
            if (defined.Bindings.Count == 0)
            {
                ProcessUse(defined.Expression, true, conditionContext);
                continue;
            }
            foreach (List<ConditionBinding> extra in defined.Bindings)
            {
                ProcessUse(defined.Expression, true, Union(conditionContext, extra));
            }
        }
        foreach (List<ExpressionWithBindings> uses in gateBlock.variablesUsed.Values)
        {
            foreach (ExpressionWithBindings used in uses)
            {
                if (used.Bindings.Count == 0)
                {
                    ProcessUse(used.Expression, false, conditionContext);
                    continue;
                }
                foreach (ICollection<ConditionBinding> extra in used.Bindings)
                {
                    ProcessUse(used.Expression, false, Union(conditionContext, extra));
                }
            }
        }
    }
    return ics;
}
/// <summary>
/// Stores a copy of the current <c>conditionContext</c> as the bindings in force where
/// <paramref name="ivd"/> is declared, replacing any previous entry.
/// </summary>
/// <param name="ivd">The variable being declared.</param>
protected void RegisterDeclaration(IVariableDeclaration ivd) =>
    declarationBindings[ivd] = ConditionBinding.Copy(conditionContext);
/// <summary>
/// Records an <c>IndexInfo</c> for stochastic array-indexer expressions whose indices are not all
/// loop variables.
/// </summary>
/// <remarks>
/// The first occurrence of such an expression creates an entry in <c>indexInfoOf</c>. Later occurrences
/// merge into that entry: its containers are intersected, its count is incremented, the non-stochastic
/// condition bindings of the occurrence are recorded, and <c>IsAssignedTo</c> is set when the
/// occurrence is a mutation. The left-hand side of an array-create assignment is skipped. Indexing by
/// a stochastic variable is reported as an error.
/// </remarks>
/// <param name="iaie">The array indexer expression being analysed.</param>
/// <returns><paramref name="iaie"/>, unchanged.</returns>
protected override IExpression ConvertArrayIndexer(IArrayIndexerExpression iaie)
{
    bool isDef = Recognizer.IsBeingMutated(context, iaie);
    if (isDef)
    {
        // do not clone the lhs of an array create assignment.
        // The ancestor lookup is presumably null when no enclosing assignment exists (e.g. a mutation
        // that is not an assignment), so guard it rather than dereferencing unconditionally.
        IAssignExpression assignExpr = context.FindAncestor<IAssignExpression>();
        if (assignExpr != null && assignExpr.Expression is IArrayCreateExpression)
        {
            return iaie;
        }
    }
    base.ConvertArrayIndexer(iaie);
    IndexInfo info;
    // TODO: Instead of storing an IndexInfo for each distinct expression, we should try to unify expressions, as in GateAnalysisTransform.
    // For example, we could unify a[0,i] and a[0,0] and use the same clone array for both.
    if (indexInfoOf.TryGetValue(iaie, out info))
    {
        // Seen before: merge this occurrence into the existing record.
        Containers containers = new Containers(context);
        if (info.bindings.Count > 0)
        {
            List<ConditionBinding> bindings = GetBindings(context, containers.inputs);
            if (bindings.Count == 0)
            {
                // This occurrence has no deterministic condition bindings, so the recorded binding sets are discarded.
                info.bindings.Clear();
            }
            else
            {
                info.bindings.Add(bindings);
            }
        }
        info.containers = Containers.Intersect(info.containers, containers);
        info.count++;
        if (isDef)
        {
            info.IsAssignedTo = true;
        }
        return iaie;
    }
    CheckIndicesAreNotStochastic(iaie.Indices);
    IVariableDeclaration baseVar = Recognizer.GetVariableDeclaration(iaie);
    // If not an indexed variable reference, skip it (e.g. an indexed argument reference)
    if (baseVar == null)
    {
        return iaie;
    }
    // If the variable is not stochastic, skip it
    if (!CodeRecognizer.IsStochastic(context, baseVar))
    {
        return iaie;
    }
    // If the indices are all loop variables, skip it
    var indices = Recognizer.GetIndices(iaie);
    bool allLoopIndices = indices.All(bracket => bracket.All(indexExpr =>
        indexExpr is IVariableReferenceExpression ivre && Recognizer.GetLoopForVariable(context, ivre) != null));
    if (allLoopIndices)
    {
        return iaie;
    }
    info = new IndexInfo();
    info.containers = new Containers(context);
    List<ConditionBinding> bindings2 = GetBindings(context, info.containers.inputs);
    if (bindings2.Count > 0)
    {
        info.bindings.Add(bindings2);
    }
    info.count = 1;
    info.IsAssignedTo = isDef;
    indexInfoOf[iaie] = info;
    return iaie;

    // Returns a binding for each non-stochastic condition statement among the containers.
    List<ConditionBinding> GetBindings(BasicTransformContext context, IEnumerable<IStatement> containers)
    {
        List<ConditionBinding> bindings = new List<ConditionBinding>();
        foreach (IStatement st in containers)
        {
            if (st is IConditionStatement ics)
            {
                if (!CodeRecognizer.IsStochastic(context, ics.Condition))
                {
                    ConditionBinding binding = new ConditionBinding(ics.Condition);
                    bindings.Add(binding);
                }
            }
        }
        return bindings;
    }

    // Reports an error for every stochastic variable appearing in the given index expressions.
    void CheckIndicesAreNotStochastic(IList<IExpression> exprs)
    {
        foreach (IExpression index in exprs)
        {
            foreach (var ivd in Recognizer.GetVariables(index))
            {
                if (CodeRecognizer.IsStochastic(context, ivd))
                {
                    string msg = "Indexing by a random variable '" + ivd.Name + "'.  You must wrap this statement with Variable.Switch(" + index + ")";
                    Error(msg);
                }
            }
        }
    }
}