/// <summary>
/// Remove nodes from the cyclic schedule whose result does not reach usedNodes.
/// </summary>
/// <param name="g">The dependency graph; passed through unchanged to <c>PruneDeadNodes</c>.</param>
/// <param name="schedule">The schedule of the cycle. It is pruned repeatedly but never modified.</param>
/// <param name="usedNodes">On entry, the set of nodes whose value is needed at the end of the schedule.
/// On exit, the set of nodes whose value is needed at the beginning of the schedule.</param>
/// <param name="usedBySelf">On entry, the set of nodes whose value is used only by itself (due to a cyclic dependency).
/// On exit, the set of nodes whose value is used only by itself.</param>
/// <param name="tailSchedule">On exit, a special schedule to use for the final iteration.
/// Empty if the final schedule is the same as the regular schedule.</param>
/// <returns>A new schedule.</returns>
/// <remarks>
/// We cannot simply do a graph search because we want to compute reachability with respect to the nodes
/// that are actually on the schedule.
/// </remarks>
private static List<NodeIndex> PruneDeadNodesCyclic(DependencyGraph2 g, IList<NodeIndex> schedule, ICollection<NodeIndex> usedNodes, ICollection<NodeIndex> usedBySelf, out List<NodeIndex> tailSchedule)
{
    tailSchedule = new List<NodeIndex>();
    // The first pruning pass is the only one made against the uses at the end of the schedule;
    // its result is compared with the converged schedule below.
    List<NodeIndex> lastSchedule = PruneDeadNodes(g, schedule, usedNodes, usedBySelf);
    // everUsed accumulates node indices, so it is a set of NodeIndex.
    Set<NodeIndex> everUsed = new Set<NodeIndex>();
    everUsed.AddRange(usedNodes);
    // repeat until convergence
    while (true)
    {
        // this prevents usedNodes from getting smaller
        usedNodes.AddRange(everUsed);
        List<NodeIndex> newSchedule = PruneDeadNodes(g, schedule, usedNodes, usedBySelf);
        int usedNodeCount = everUsed.Count;
        everUsed.AddRange(usedNodes);
        if (everUsed.Count != usedNodeCount)
        {
            // the set of used nodes grew, so another pass is needed
            continue;
        }
        // converged
        // does lastSchedule have a statement not in newSchedule?
        // A hash set keeps the membership test O(1) per node; the order of lastSchedule is preserved.
        HashSet<NodeIndex> inNewSchedule = new HashSet<NodeIndex>(newSchedule);
        foreach (NodeIndex node in lastSchedule)
        {
            if (!inNewSchedule.Contains(node))
            {
                tailSchedule.Add(node);
            }
        }
        return newSchedule;
    }
}
/// <summary>
/// Walks backward through the chain of statements overwritten by <paramref name="source"/> and invokes
/// <paramref name="action"/> on each one whose node index is less than <paramref name="target"/>.
/// </summary>
/// <param name="source">The statement whose initializers are wanted.</param>
/// <param name="target">Only statements whose index in <paramref name="g"/> is below this value are reported.</param>
/// <param name="g">Supplies the node index of each statement via <c>indexOfNode</c>.</param>
/// <param name="action">Called once for every valid initializer that is found.</param>
private void ForEachInitializer(IStatement source, NodeIndex target, DependencyGraph2 g, Action<IStatement> action)
{
    Stack<IStatement> pending = new Stack<IStatement>();
    pending.Push(source);
    while (pending.Count > 0)
    {
        IStatement current = pending.Pop();
        DependencyInformation info = context.InputAttributes.Get<DependencyInformation>(current);
        if (info == null)
        {
            // report the problem and move on to the remaining work items
            context.Error("Dependency information not found for statement: " + current);
        }
        else
        {
            foreach (IStatement overwritten in info.Overwrites)
            {
                int overwrittenIndex = g.indexOfNode[overwritten];
                bool precedesTarget = overwrittenIndex < target;
                if (precedesTarget)
                {
                    // found a valid initializer
                    action(overwritten);
                }
                else
                {
                    // keep looking backward for a valid initializer
                    pending.Push(overwritten);
                }
            }
        }
    }
}
/// <summary>
/// Remove nodes from the schedule whose result does not reach usedNodes.
/// </summary>
/// <param name="g">The dependency graph that the node indices in <paramref name="schedule"/> refer to.</param>
/// <param name="schedule">An ordered list of nodes.</param>
/// <param name="usedNodes">On entry, the set of nodes whose value is needed at the end of the schedule.
/// On exit, the set of nodes whose value is needed at the beginning of the schedule.</param>
/// <param name="usedBySelf">On entry, the set of nodes whose value is used only by itself (due to a cyclic dependency).
/// On exit, the set of nodes whose value is used only by itself.</param>
/// <returns>A subset of the schedule, in the same order.</returns>
private static List<NodeIndex> PruneDeadNodes(DependencyGraph2 g, IList<NodeIndex> schedule, ICollection<NodeIndex> usedNodes, ICollection<NodeIndex> usedBySelf)
{
    List<NodeIndex> kept = new List<NodeIndex>();
    // walk the schedule from its last entry to its first
    for (int pos = schedule.Count; pos-- > 0;)
    {
        NodeIndex current = schedule[pos];
        bool isLive = usedNodes.Contains(current);
        if (usedBySelf.Contains(current))
        {
            // a node used only by itself (due to cyclic dependency) is treated as dead
            isLive = false;
            // initializers must still be considered used
            foreach (EdgeIndex incoming in g.dependencyGraph.EdgesInto(current))
            {
                NodeIndex initSource = g.dependencyGraph.SourceOf(incoming);
                if (initSource != current)
                {
                    usedNodes.Add(initSource);
                    // initSource is known to differ from current here, so only the statement identity is compared
                    if (g.nodes[initSource] == g.nodes[current])
                    {
                        usedBySelf.Add(initSource);
                    }
                }
            }
        }
        usedNodes.Remove(current);
        if (!isLive)
        {
            continue;
        }
        kept.Add(current);
        foreach (NodeIndex input in g.dependencyGraph.SourcesOf(current))
        {
            if (!usedNodes.Contains(input))
            {
                usedNodes.Add(input);
            }
            // a node is added to usedBySelf only if it is used by a copy of itself.
            // we don't need to have multiple copies of a self-loop, but we do need one instance of it.
            bool usedByCopyOfItself = input != current && g.nodes[input] == g.nodes[current];
            if (usedByCopyOfItself)
            {
                usedBySelf.Add(input);
            }
            else
            {
                usedBySelf.Remove(input);
            }
        }
    }
    // nodes were collected back to front; restore schedule order
    kept.Reverse();
    return kept;
}
/// <summary>
/// For every back edge into one of <paramref name="nodes"/> whose source is still used and lives in a loop block,
/// adds the initializers of that source to the source's loop.
/// </summary>
/// <param name="nodes">The target nodes to examine.</param>
/// <param name="usedNodes">Only sources contained in this collection are considered.</param>
/// <param name="blockOfNode">Maps each node to the block that holds it.</param>
/// <param name="g">The dependency graph providing sources and statements.</param>
private void AddLoopInitializers(IEnumerable<NodeIndex> nodes, ICollection<NodeIndex> usedNodes, Dictionary<NodeIndex, StatementBlock> blockOfNode, DependencyGraph2 g)
{
    foreach (NodeIndex node in nodes)
    {
        // when the target is also in a while loop, you need the initializer to precede target's entire loop
        NodeIndex target = node;
        if (blockOfNode[node] is Loop enclosingLoop && enclosingLoop.indices.Count != 0)
        {
            target = enclosingLoop.indices[0];
        }
        foreach (NodeIndex source in g.dependencyGraph.SourcesOf(node))
        {
            // only back edges (source at or after the target node) matter
            if (source < node)
            {
                continue;
            }
            if (!usedNodes.Contains(source))
            {
                continue;
            }
            if (blockOfNode[source] is Loop sourceLoop)
            {
                ForEachInitializer(g.nodes[source], target, g, sourceLoop.initializers.Add);
            }
        }
    }
}
// For each back edge whose source is in a while loop, add the appropriate initializer of source to the InitializerSet.
// For a Loop block, the tail and first-iteration post-processing nodes are examined in addition to the regular indices.
private void AddLoopInitializers(StatementBlock block, ICollection<NodeIndex> usedNodes, Dictionary<NodeIndex, StatementBlock> blockOfNode, DependencyGraph2 g)
{
    if (block is Loop loop)
    {
        // Schedule assigns loop.tail only when dead code is pruned (and null-checks it when emitting),
        // so it may still be null here; enumerating it unguarded would throw.
        if (loop.tail != null)
        {
            AddLoopInitializers(loop.tail, usedNodes, blockOfNode, g);
        }
        AddLoopInitializers(loop.firstIterPostBlock, usedNodes, blockOfNode, g);
    }
    AddLoopInitializers(block.indices, usedNodes, blockOfNode, g);
}
/// <summary>
/// Splits the statements into straight-line and while-loop blocks, optionally prunes dead nodes from each block,
/// collects loop initializers, and emits the converted statements with an InitializerSet attached to every loop.
/// </summary>
/// <param name="isc">The input statements.</param>
/// <returns>A new statement collection holding the converted, scheduled statements.</returns>
protected IList<IStatement> Schedule(IList<IStatement> isc)
{
    List<StatementBlock> blocks = new List<StatementBlock>();
    StatementBlock currentBlock = new StraightLine();
    Dictionary<NodeIndex, StatementBlock> blockOfNode = new Dictionary<NodeIndex, StatementBlock>();
    int firstIterPostBlockCount = 0;
    IConditionStatement firstIterPostStatement = null;
    // must include back edges for computing InitializerSets
    // NOTE(review): the five callbacks are presumably invoked on entering/leaving a while statement,
    // entering/leaving a first-iteration post-processing condition, and for each node - confirm against DependencyGraph2.
    DependencyGraph2 g = new DependencyGraph2(context, isc, DependencyGraph2.BackEdgeHandling.Include,
        delegate (IWhileStatement iws)
        {
            // close the block in progress and start a loop block
            blocks.Add(currentBlock);
            currentBlock = new Loop(iws);
        },
        delegate (IWhileStatement iws)
        {
            // close the loop block and start a straight-line block
            blocks.Add(currentBlock);
            currentBlock = new StraightLine();
        },
        delegate (IConditionStatement ics)
        {
            firstIterPostBlockCount++;
            firstIterPostStatement = ics;
        },
        delegate (IConditionStatement ics)
        {
            firstIterPostBlockCount--;
        },
        delegate (IStatement ist, int index)
        {
            if (firstIterPostBlockCount > 0)
            {
                ((Loop)currentBlock).firstIterPostBlock.Add(index);
            }
            currentBlock.indices.Add(index);
            blockOfNode[index] = currentBlock;
        });
    var dependencyGraph = g.dependencyGraph;
    blocks.Add(currentBlock);

    Set<NodeIndex> usedNodes = Set<NodeIndex>.FromEnumerable(g.outputNodes);
    Set<NodeIndex> usedBySelf = new Set<NodeIndex>();
    // loop blocks in reverse order
    for (int i = blocks.Count - 1; i >= 0; i--)
    {
        StatementBlock block = blocks[i];
        if (block is Loop loop)
        {
            if (!pruneDeadCode)
            {
                usedNodes = CollectUses(dependencyGraph, block.indices);
            }
            else
            {
                usedBySelf.Clear();
                // modifies usedNodes
                block.indices = PruneDeadNodesCyclic(g, block.indices, usedNodes, usedBySelf, out List<NodeIndex> tailStmts);
                loop.tail = tailStmts;
            }
            RemoveSuffix(block.indices, loop.firstIterPostBlock);
        }
        else
        {
            // StraightLine
            if (pruneDeadCode)
            {
                // modifies usedNodes
                block.indices = PruneDeadNodes(g, block.indices, usedNodes, usedBySelf);
            }
        }
        AddLoopInitializers(block, usedNodes, blockOfNode, g);
    }

    IList<IStatement> sc = Builder.StmtCollection();
    foreach (StatementBlock block in blocks)
    {
        if (block is Loop loop)
        {
            context.OpenStatement(loop.loopStatement);
            IWhileStatement ws = Builder.WhileStmt(loop.loopStatement);
            context.SetPrimaryOutput(ws);
            IList<IStatement> sc2 = ws.Body.Statements;
            foreach (NodeIndex i in loop.indices)
            {
                IStatement st = ConvertStatement(g.nodes[i]);
                sc2.Add(st);
            }
            context.CloseStatement(loop.loopStatement);
            context.InputAttributes.CopyObjectAttributesTo(loop.loopStatement, context.OutputAttributes, ws);
            sc.Add(ws);
            List<IStatement> initStmts = new List<IStatement>();
            initStmts.AddRange(loop.initializers);
            if (loop.firstIterPostBlock.Count > 0)
            {
                // wrap the first-iteration post-processing statements in a condition at the end of the loop body
                var firstIterPostStatements = loop.firstIterPostBlock.Select(i => g.nodes[i]);
                var thenBlock = Builder.BlockStmt();
                ConvertStatements(thenBlock.Statements, firstIterPostStatements);
                var firstIterPostStmt = Builder.CondStmt(firstIterPostStatement.Condition, thenBlock);
                context.OutputAttributes.Set(firstIterPostStmt, new FirstIterationPostProcessingBlock());
                sc2.Add(firstIterPostStmt);
                loopMergingInfo.AddNode(firstIterPostStmt);
            }
            context.OutputAttributes.Remove<InitializerSet>(ws);
            context.OutputAttributes.Set(ws, new InitializerSet(initStmts));
            if (loop.tail != null)
            {
                // The tail statements follow the loop in the output. They are converted like every other
                // emitted statement rather than being added as raw input statements.
                foreach (NodeIndex i in loop.tail)
                {
                    IStatement st = ConvertStatement(g.nodes[i]);
                    sc.Add(st);
                }
            }
        }
        else
        {
            foreach (NodeIndex i in block.indices)
            {
                IStatement st = ConvertStatement(g.nodes[i]);
                sc.Add(st);
            }
        }
    }
    return sc;
}
/// <summary>
/// For every loop variable of every outer while loop, solves a min cut problem to decide which statements
/// should run their loop over that variable backward, records the forward loops that must be reversed in
/// <c>loopVarsToReverseInStatement</c>, and then runs the base conversion.
/// </summary>
/// <param name="outputs">Passed through to the base implementation.</param>
/// <param name="inputs">The statements used to build the dependency graph; also passed to the base implementation.</param>
protected override void DoConvertMethodBody(IList<IStatement> outputs, IList<IStatement> inputs)
{
    // Per-node bookkeeping, appended to once per node callback.
    // NOTE(review): indexing these lists by node index assumes the callback fires in node-index order - confirm.
    List<int> whileNumberOfNode = new List<int>();
    List<int> fusedCountOfNode = new List<int>();
    List<List<IStatement>> containersOfNode = new List<List<IStatement>>();
    // the code may have multiple while(true) loops, however these must be disjoint.
    // therefore we treat 'while' as one container, but give each loop a different 'while number'.
    int outerWhileCount = 0;
    int currentOuterWhileNumber = 0;
    int currentFusedCount = 0;
    List<Set<IVariableDeclaration>> loopVarsOfWhileNumber = new List<Set<IVariableDeclaration>>();
    // build the dependency graph
    var g = new DependencyGraph2(context, inputs, DependencyGraph2.BackEdgeHandling.Ignore,
        delegate (IWhileStatement iws)
        {
            if (iws is IFusedBlockStatement)
            {
                if (iws.Condition is IVariableReferenceExpression)
                {
                    currentFusedCount++;
                }
            }
            else
            {
                outerWhileCount++;
                currentOuterWhileNumber = outerWhileCount;
            }
        },
        delegate (IWhileStatement iws)
        {
            if (iws is IFusedBlockStatement)
            {
                if (iws.Condition is IVariableReferenceExpression)
                {
                    currentFusedCount--;
                }
            }
            else
            {
                // while number 0 means "not inside an outer while loop"
                currentOuterWhileNumber = 0;
            }
        },
        delegate (IConditionStatement ics)
        {
        },
        delegate (IConditionStatement ics)
        {
        },
        delegate (IStatement ist, int targetIndex)
        {
            int whileNumber = currentOuterWhileNumber;
            whileNumberOfNode.Add(whileNumber);
            fusedCountOfNode.Add(currentFusedCount);
            List<IStatement> containers = new List<IStatement>();
            LoopMergingTransform.UnwrapStatement(ist, containers);
            containersOfNode.Add(containers);
            for (int i = 0; i < currentFusedCount; i++)
            {
                // The hard cast is kept on purpose: a fused container that is not a for loop throws here.
                // Only a null entry is skipped.
                IForStatement ifs = (IForStatement)containers[i];
                if (ifs != null)
                {
                    var loopVar = Recognizer.LoopVariable(ifs);
                    // grow the list until it has a slot for this while number
                    while (loopVarsOfWhileNumber.Count <= whileNumber)
                    {
                        loopVarsOfWhileNumber.Add(new Set<IVariableDeclaration>());
                    }
                    Set<IVariableDeclaration> loopVars = loopVarsOfWhileNumber[whileNumber];
                    loopVars.Add(loopVar);
                }
            }
        });
    var nodes = g.nodes;
    var dependencyGraph = g.dependencyGraph;
    // while number 0 (statements outside every outer while loop) is deliberately skipped
    for (int whileNumber = 1; whileNumber < loopVarsOfWhileNumber.Count; whileNumber++)
    {
        foreach (var loopVar in loopVarsOfWhileNumber[whileNumber])
        {
            // Any statement (in the while loop) that has a forward descendant and a backward descendant will be cloned, so we want to minimize the number of such nodes.
            // The free variables in this problem are the loop directions at the leaf statements, since all other loop directions are forced by these.
            // We find the optimal labeling of the free variables by solving a min cut problem on a special network.
            // The network is constructed so that the cost of a cut is equal to the number of statements that will be cloned.
            // The network has 2 nodes for every statement: an in-node and an out-node.
            // For a non-leaf statement, there is a capacity 1 edge from the in-node to out-node.  This edge is cut when the statement is cloned.
            // For a leaf statement, there is an infinite capacity edge in both directions, or equivalently a single node.
            // If statement A depends on statement B, then there is an infinite capacity edge from in-A to in-B, and from out-B to out-A,
            // representing the fact that cloning A requires cloning B, but not the reverse.
            // If a statement must appear with a forward loop, it is connected to the source.
            // If a statement must appear with a backward loop, it is connected to the sink.

            // construct a capacitated graph
            int inNodeStart = 0;
            int outNodeStart = inNodeStart + dependencyGraph.Nodes.Count;
            int sourceNode = outNodeStart + dependencyGraph.Nodes.Count;
            int sinkNode = sourceNode + 1;
            int cutNodeCount = sinkNode + 1;
            Func<NodeIndex, int> getInNode = node => node + inNodeStart;
            Func<NodeIndex, int> getOutNode = node => node + outNodeStart;
            IndexedGraph network = new IndexedGraph(cutNodeCount);
            const float infinity = 1000000f;
            // capacity is appended to right after every AddEdge, so its entries line up with the order edges were added
            List<float> capacity = new List<float>();
            List<NodeIndex> nodesOfInterest = new List<NodeIndex>();
            foreach (var node in dependencyGraph.Nodes)
            {
                if (whileNumberOfNode[node] != whileNumber)
                {
                    continue;
                }
                NodeIndex source = node;
                List<IStatement> containersOfSource = containersOfNode[source];
                bool hasLoopVar = containersOfSource.Any(container => container is IForStatement fs && Recognizer.LoopVariable(fs) == loopVar);
                if (!hasLoopVar)
                {
                    continue;
                }
                nodesOfInterest.Add(node);
                IStatement sourceSt = nodes[source];
                var readAfterWriteEdges = dependencyGraph.EdgesOutOf(source).Where(edge => !g.isWriteAfterRead[edge]);
                bool isLeaf = true;
                int inNode = getInNode(node);
                int outNode = getOutNode(node);
                foreach (var target in readAfterWriteEdges.Select(dependencyGraph.TargetOf))
                {
                    List<IStatement> containersOfTarget = containersOfNode[target];
                    ForEachMatchingLoopVariable(containersOfSource, containersOfTarget, (loopVar2, afs, bfs) =>
                    {
                        if (loopVar2 == loopVar)
                        {
                            int inTarget = getInNode(target);
                            int outTarget = getOutNode(target);
                            network.AddEdge(inTarget, inNode);
                            capacity.Add(infinity);
                            network.AddEdge(outNode, outTarget);
                            capacity.Add(infinity);
                            isLeaf = false;
                        }
                    });
                }
                if (isLeaf)
                {
                    if (debug)
                    {
                        log.Add($"loopVar={loopVar.Name} leaf {sourceSt}");
                    }
                    network.AddEdge(inNode, outNode);
                    capacity.Add(infinity);
                    network.AddEdge(outNode, inNode);
                    capacity.Add(infinity);
                }
                else
                {
                    network.AddEdge(inNode, outNode);
                    capacity.Add(1f);
                }
                int fusedCount = fusedCountOfNode[node];
                Direction desiredDirectionOfSource = GetDesiredDirection(loopVar, containersOfSource, fusedCount);
                if (desiredDirectionOfSource == Direction.Forward)
                {
                    if (debug)
                    {
                        log.Add($"loopVar={loopVar.Name} forward {sourceSt}");
                    }
                    network.AddEdge(sourceNode, inNode);
                    capacity.Add(infinity);
                }
                else if (desiredDirectionOfSource == Direction.Backward)
                {
                    if (debug)
                    {
                        log.Add($"loopVar={loopVar.Name} backward {sourceSt}");
                    }
                    network.AddEdge(outNode, sinkNode);
                    capacity.Add(infinity);
                }
            }
            network.IsReadOnly = true;

            // compute the min cut
            MinCut<NodeIndex, EdgeIndex> mc = new MinCut<NodeIndex, EdgeIndex>(network, e => capacity[e]);
            mc.Sources.Add(sourceNode);
            mc.Sinks.Add(sinkNode);
            Set<NodeIndex> sourceGroup = mc.GetSourceGroup();
            foreach (NodeIndex node in nodesOfInterest)
            {
                IStatement sourceSt = nodes[node];
                bool forwardIn = sourceGroup.Contains(getInNode(node));
                bool forwardOut = sourceGroup.Contains(getOutNode(node));
                if (forwardIn != forwardOut)
                {
                    // the in/out edge of this statement was cut; nothing is recorded here beyond the log entry
                    if (debug)
                    {
                        log.Add($"loopVar={loopVar.Name} will clone {sourceSt}");
                    }
                }
                else if (forwardIn)
                {
                    if (debug)
                    {
                        log.Add($"loopVar={loopVar.Name} wants forward {sourceSt}");
                    }
                }
                else
                {
                    if (debug)
                    {
                        log.Add($"loopVar={loopVar.Name} wants backward {sourceSt}");
                    }
                    var containers = containersOfNode[node];
                    // if several containers loop over loopVar, the innermost (last) one decides
                    bool isForwardLoop = true;
                    foreach (var container in containers)
                    {
                        if (container is IForStatement ifs && Recognizer.LoopVariable(ifs) == loopVar)
                        {
                            isForwardLoop = Recognizer.IsForwardLoop(ifs);
                        }
                    }
                    if (isForwardLoop)
                    {
                        // the loop currently runs forward but should run backward: mark it for reversal
                        if (!loopVarsToReverseInStatement.TryGetValue(sourceSt, out Set<IVariableDeclaration> loopVarsToReverse))
                        {
                            // TODO: re-use equivalent sets
                            loopVarsToReverse = new Set<IVariableDeclaration>();
                            loopVarsToReverseInStatement.Add(sourceSt, loopVarsToReverse);
                        }
                        loopVarsToReverse.Add(loopVar);
                    }
                }
            }
        }
    }
    base.DoConvertMethodBody(outputs, inputs);
}