protected override ICode VisitIf(StmtIf s) {
    // Distributes the conditions of nested 'if's down to the continuations they contain.
    // The outermost 'if' that reaches the distribution step owns 'this.ifInfo': it creates it on
    // entry and sets it back to null before returning. Nested 'if's reuse the owner's instance.
    if (!VisitorFindContinuations.Any(s)) {
        // No continuations inside, so there is nothing to distribute
        return s;
    }
    if (VisitorOnlyStatements.Only(s, Stmt.NodeType.If, Stmt.NodeType.Continuation) && this.ifInfo == null) {
        // Only 'if's and continuations, so nothing can be distributed; the base visitor must still
        // walk it to reach the contained continuations
        return base.VisitIf(s);
    }
    var ownsIfInfo = this.ifInfo == null;
    if (ownsIfInfo) {
        this.ifInfo = new IfInfo();
    }
    // Each branch is visited with its guarding condition on top of the condition stack
    this.ifInfo.Conditions.Push(s.Condition);
    var visitedThen = this.Visit(s.Then);
    this.ifInfo.Conditions.Pop();
    this.ifInfo.Conditions.Push(this.ctx.ExprGen.NotAutoSimplify(s.Condition));
    var visitedElse = this.Visit(s.Else);
    this.ifInfo.Conditions.Pop();

    if (visitedThen == s.Then && visitedElse == s.Else) {
        // Neither branch changed: no continuations were found, so no conditions can have been queued
        if (ownsIfInfo) {
            this.ifInfo = null;
        }
        return s;
    }
    var rebuiltIf = new StmtIf(s.Ctx, s.Condition, (Stmt)visitedThen, (Stmt)visitedElse);
    if (!ownsIfInfo || !this.ifInfo.AddToIf.Any()) {
        if (ownsIfInfo) {
            this.ifInfo = null;
        }
        return rebuiltIf;
    }
    // Owner with queued entries: after the rebuilt 'if', emit one extra 'if' per distinct continuation
    // target. Its condition ORs together every condition queued for that target.
    var pending = this.ifInfo.AddToIf;
    var trailingIfs = pending
        .GroupBy(entry => entry.Item1.To, entry => entry.Item2)
        .Select(group => new StmtIf(s.Ctx,
            group.Aggregate((a, b) => this.ctx.ExprGen.Or(a, b)),
            pending.First(entry => entry.Item1.To == group.Key).Item1,
            null));
    // Materialise before 'ifInfo' is cleared
    var allStmts = new Stmt[] { rebuiltIf }.Concat(trailingIfs).ToArray();
    this.ifInfo = null;
    return new StmtBlock(s.Ctx, allStmts);
}
public static IEnumerable<StmtContinuation> Get(ICode root) {
    // Runs a fresh VisitorFindContinuations over 'root' and hands back the continuations it collected.
    var finder = new VisitorFindContinuations();
    finder.Visit(root);
    IEnumerable<StmtContinuation> found = finder.Continuations;
    return found;
}
public static bool Any(ICode root) {
    // True when a fresh VisitorFindContinuations run over 'root' collected at least one continuation.
    // The check happens only after Visit(root) has returned.
    var finder = new VisitorFindContinuations();
    finder.Visit(root);
    var collected = finder.Continuations;
    return collected.Any();
}
protected override ICode VisitCil(StmtCil s) {
    // Converts one CIL block into statements, starting from the block-start state recorded for it.
    // It then records which stack/local/arg slots the block left untouched, propagates the end state
    // to every continuation target, and finally visits the block's EndCil.
    var startInfo = this.blockStartInfos[s];
    // Working copies of the start state; the CilProcessor updates these as instructions are processed
    var evalStack = new Stack<Expr>(startInfo.Stack.Reverse());
    var localVars = startInfo.Locals.Cast<Expr>().ToArray();
    var argVars = startInfo.Args.Cast<Expr>().ToArray();
    // Snapshots of the start state, compared slot-by-slot after processing
    var stackBefore = evalStack.ToArray();
    var localsBefore = localVars.ToArray();
    var argsBefore = argVars.ToArray();
    var processor = new CilProcessor(this.ctx, evalStack, localVars, argVars, this.instResults);
    var result = new List<Stmt>();
    if (s.BlockType == StmtCil.SpecialBlock.Normal) {
        // Instructions that produce no statement yield null and are skipped
        result.AddRange(s.Insts.Select(inst => processor.Process(inst)).Where(stmt => stmt != null));
    } else if (s.BlockType == StmtCil.SpecialBlock.End) {
        result.Add(processor.ProcessReturn());
    } else if (s.BlockType != StmtCil.SpecialBlock.Start) {
        // 'Start' blocks intentionally produce nothing; anything else is unexpected
        throw new InvalidOperationException("Invalid block type: " + s.BlockType);
    }
    // Per slot: true when the value after the block is '==' to the value at block start
    this.stmtVarsChanged.Add(s, new StmtVarChanged {
        Stack = evalStack.Zip(stackBefore, (after, before) => after == before).ToArray(),
        Locals = localVars.Zip(localsBefore, (after, before) => after == before).ToArray(),
        Args = argVars.Zip(argsBefore, (after, before) => after == before).ToArray(),
    });
    // Merge phi's: hand the end-of-block state to every continuation target
    foreach (var exit in VisitorFindContinuations.Get(s)) {
        this.CreateOrMergeBsi(exit.To, evalStack.ToArray(), localVars, argVars);
    }
    // End
    result.Add((Stmt)this.Visit(s.EndCil));
    return new StmtBlock(this.ctx, result);
}
private Tuple<Stmt, Stmt> RemoveContinuation(Stmt s) {
    // Splits the single protected-region-leaving continuation off the end of 's'.
    // Returns:
    //   Item1 = 's' without that continuation. This is never null: a StmtEmpty is used when nothing
    //           remains, so 'try' statements can still tell a 'catch' from a 'finally'.
    //   Item2 = the block that continuation went to, or null when 's' has no continuation at all.
    // Returns null (no split possible) when 's' has more than one continuation, or when its only
    // continuation is not the final statement, or does not leave the protected region.
    var continuationCount = VisitorFindContinuations.Get(s).Count();
    if (continuationCount == 0) {
        // Blocks with no continuations must end with a 'throw'
        return Tuple.Create(s, (Stmt)null);
    }
    if (continuationCount != 1) {
        return null;
    }
    if (s.StmtType == Stmt.NodeType.Continuation) {
        var lone = (StmtContinuation)s;
        return lone.LeaveProtectedRegion ? Tuple.Create((Stmt)new StmtEmpty(s.Ctx), lone.To) : null;
    }
    if (s.StmtType != Stmt.NodeType.Block) {
        return null;
    }
    var items = ((StmtBlock)s).Statements.ToArray();
    var lastIndex = items.Length - 1;
    if (lastIndex < 0 || items[lastIndex].StmtType != Stmt.NodeType.Continuation) {
        return null;
    }
    var exit = (StmtContinuation)items[lastIndex];
    if (!exit.LeaveProtectedRegion) {
        return null;
    }
    var remainder = lastIndex == 0 ? (Stmt)new StmtEmpty(s.Ctx) : new StmtBlock(s.Ctx, items.Take(lastIndex));
    return Tuple.Create(remainder, exit.To);
}
protected override ICode VisitCil(StmtCil s) {
    // Computes this block's end stack size and stores it. That size then becomes the start stack size
    // of every continuation target.
    var stackSizeAtEnd = this.StackSizeAnalysis(s.Insts, s.StartStackSize);
    s.EndStackSize = stackSizeAtEnd;
    var successors = VisitorFindContinuations.Get(s.EndCil);
    // Recursive so that several 'try's starting on the same instruction are all handled
    Action<Stmt, int> assignStart = null;
    assignStart = (target, size) => {
        var kind = target.StmtType;
        if (kind == Stmt.NodeType.Cil) {
            ((StmtCil)target).StartStackSize = size;
        } else if (kind == Stmt.NodeType.Try) {
            var tryStmt = (StmtTry)target;
            assignStart(tryStmt.Try, size);
            if (tryStmt.Catches != null) {
                // The (first) catch handler always starts with a stack size of 1
                assignStart(tryStmt.Catches.First().Stmt, 1);
            }
            // 'Finally' stack sizes do not need setting, as they will have defaulted to 0
            // and this will always be correct
        } else if (kind != Stmt.NodeType.Return) {
            // 'Return' targets need nothing; any other node type is unexpected
            throw new NotSupportedException("Should not be seeing: " + kind);
        }
    };
    foreach (var successor in successors) {
        assignStart(successor.To, stackSizeAtEnd);
    }
    return base.VisitCil(s);
}
public static bool Any(ICode root) {
    // Reports whether visiting 'root' with a new VisitorFindContinuations collected any continuation.
    // Visit(root) runs to completion before the collection is inspected.
    var visitor = new VisitorFindContinuations();
    visitor.Visit(root);
    var results = visitor.Continuations;
    return results.Any();
}
public static IEnumerable<StmtContinuation> Get(ICode root) {
    // Visits 'root' with a new VisitorFindContinuations and returns the continuations it gathered.
    var visitor = new VisitorFindContinuations();
    visitor.Visit(root);
    IEnumerable<StmtContinuation> results = visitor.Continuations;
    return results;
}
protected override ICode VisitTry(StmtTry s) {
    // Tries to pull the single exiting continuation out of the 'try' body and its catch/finally.
    // If that works, the 'try' statement can be followed by one ordinary continuation.
    // RemoveContinuation yields null when a part cannot be split. Otherwise Item1 is the part without
    // its continuation, and Item2 is where that continuation went (null if the part has none).
    var @try = this.RemoveContinuation(s.Try);
    if (@try != null) {
        if (s.Catches != null) {
            if (s.Catches.Count() != 1) {
                throw new InvalidOperationException("Should only ever see 1 catch here");
            }
            var sCatch = s.Catches.First();
            var @catch = this.RemoveContinuation(sCatch.Stmt);
            if (@catch != null) {
                // Both parts must leave to the same place, or one of them must not leave at all.
                // At least one part has to leave.
                if ((@try.Item2 == null || @catch.Item2 == null || @try.Item2 == @catch.Item2) && (@try.Item2 != null || @catch.Item2 != null)) {
                    var newTry = new StmtTry(s.Ctx, @try.Item1, new[] { new StmtTry.Catch(@catch.Item1, sCatch.ExceptionVar) }, null);
                    var newCont = new StmtContinuation(s.Ctx, @try.Item2 ?? @catch.Item2, false);
                    return new StmtBlock(s.Ctx, newTry, newCont);
                }
                // Special case
                // When 'leave' CIL branch to different instructions, allow specific code to be
                // moved inside the 'try' or 'catch' block. It should be impossible for this code to throw an exception
                // This needs a split catch ('@catch' non-null, guaranteed by the enclosing 'if') and a try
                // target ('@try.Item2' is null when the try body has no continuation); both are dereferenced.
                if (@try.Item2 != null) {
                    var tryTos = VisitorFindContinuations.Get(@try.Item2).ToArray();
                    if (tryTos.Length == 1 && tryTos[0].To == @catch.Item2 && @try.Item2.StmtType == Stmt.NodeType.Block) {
                        var try2Stmts = ((StmtBlock)@try.Item2).Statements.ToArray();
                        var s0 = try2Stmts.Take(try2Stmts.Length - 1);
                        // Only plain assignments may be moved into the protected region
                        if (s0.All(x => x.StmtType == Stmt.NodeType.Assignment)) {
                            var sN = try2Stmts.Last();
                            if (sN.StmtType == Stmt.NodeType.Continuation) {
                                var newTry = new StmtTry(s.Ctx,
                                    new StmtBlock(s.Ctx, @try.Item1, new StmtBlock(s.Ctx, s0), new StmtContinuation(s.Ctx, ((StmtContinuation)sN).To, true)),
                                    s.Catches, null);
                                return newTry;
                            }
                        }
                    }
                }
            }
        }
        if (s.Finally != null) {
            var @finally = this.RemoveContinuation(s.Finally);
            if (@finally != null) {
                // Same pairing rule as for the catch above
                if ((@try.Item2 == null || @finally.Item2 == null || @try.Item2 == @finally.Item2) && (@try.Item2 != null || @finally.Item2 != null)) {
                    var newTry = new StmtTry(s.Ctx, @try.Item1, null, @finally.Item1);
                    var newCont = new StmtContinuation(s.Ctx, @try.Item2 ?? @finally.Item2, false);
                    return new StmtBlock(s.Ctx, newTry, newCont);
                }
            }
        }
        // TODO: This is a hack for badly handling fault handlers. They are ignored at the moment
        if (s.Catches == null && s.Finally == null) {
            var cont = @try.Item2 == null ? null : new StmtContinuation(s.Ctx, @try.Item2, false);
            return new StmtBlock(s.Ctx, @try.Item1, cont);
        }
    }
    return base.VisitTry(s);
}
protected override ICode VisitContinuation(StmtContinuation s) {
    // Looks for a block target that continues back to itself and, if found, records (in 'this.replaces')
    // a do-loop to replace that block with. Each target block is examined at most once ('this.seen').
    // TODO: Why is this de-recursing blocks in continuations? Why doesn't it just derecurse itself (if possible)???
    if (s.To.StmtType != Stmt.NodeType.Block) {
        return base.VisitContinuation(s);
    }
    if (!this.seen.Add(s.To)) {
        // This target block has already been examined
        return base.VisitContinuation(s);
    }
    var block = (StmtBlock)s.To;
    foreach (var stmt in block.Statements) {
        if (stmt.StmtType == Stmt.NodeType.Continuation) {
            // Continuation not inside 'if'
            var sCont = (StmtContinuation)stmt;
            if (sCont.To == block) {
                // Recursive, so derecurse with no condition on loop.
                // Statements after this continuation are unreachable and are not kept.
                var body = new StmtBlock(s.Ctx, block.Statements.TakeWhile(x => x != stmt).ToArray());
                var replaceWith = new StmtDoLoop(s.Ctx, body, new ExprLiteral(s.Ctx, true, s.Ctx.Boolean));
                this.replaces.Add(s.To, replaceWith);
                return base.VisitContinuation(s);
            }
        }
        if (stmt.StmtType == Stmt.NodeType.If) {
            // Continuation only statement within 'if' with only a 'then' clause
            var sIf = (StmtIf)stmt;
            if (sIf.Else == null && sIf.Then.StmtType == Stmt.NodeType.Continuation) {
                var sThen = (StmtContinuation)sIf.Then;
                if (sThen.To == block) {
                    // Recursive, so derecurse: statements before the 'if' become the loop body,
                    // the 'if' condition becomes the loop condition, and the rest follows the loop
                    var afterLoop = block.Statements.SkipWhile(x => x != stmt).Skip(1).ToArray();
                    if (VisitorFindContinuations.Get(new StmtBlock(s.Ctx, afterLoop)).Any(x => x.To == block)) {
                        // Cannot de-recurse yet, must wait for continuations to be merged
                        return base.VisitContinuation(s);
                    }
                    var bodyStmts = block.Statements.TakeWhile(x => x != stmt).ToArray();
                    var body = new StmtBlock(s.Ctx, bodyStmts);
                    var loop = new StmtDoLoop(s.Ctx, body, sIf.Condition);
                    Stmt replaceWith;
                    if (afterLoop.Any()) {
                        var loopAndAfter = new[] { loop }.Concat(afterLoop).ToArray();
                        replaceWith = new StmtBlock(s.Ctx, loopAndAfter);
                    } else {
                        replaceWith = loop;
                    }
                    this.replaces.Add(s.To, replaceWith);
                    return base.VisitContinuation(s);
                }
            }
        }
        if (VisitorFindContinuations.Any(stmt)) {
            // Another continuation present, cannot derecurse
            break;
        }
    }
    return base.VisitContinuation(s);
}
protected override ICode VisitSwitch(StmtSwitch s) {
    // Simplifies a switch whose cases end in continuations. It can drop cases equal to the default,
    // make cases with the same target consecutive, or hoist a common final continuation after the
    // switch. On the last chance it rewrites the switch as a chain of 'if's.
    // If switch statement contains no continuations then it doesn't need processing
    if (!VisitorFindContinuations.Any(s)) {
        return base.VisitSwitch(s);
    }
    var ctx = s.Ctx;
    // A case whose 'Stmt' is null has no body of its own: it shares the body of the next case that has
    // one (this is how combined cases are built below, and how the 'lastChance' conversion reads them).
    // Cases are therefore grouped into runs, each ending with the case that owns the shared body.
    // Runs are removed or moved as a whole, so fall-through cases keep their target. Null-bodied cases
    // after the last body are kept in 'trailingCases' and stay at the end.
    var caseRuns = new List<StmtSwitch.Case[]>();
    var pendingCases = new List<StmtSwitch.Case>();
    foreach (var @case in s.Cases) {
        pendingCases.Add(@case);
        if (@case.Stmt != null) {
            caseRuns.Add(pendingCases.ToArray());
            pendingCases.Clear();
        }
    }
    var trailingCases = pendingCases.ToArray();
    // Target of a run's body when that body is a bare continuation, otherwise null
    Func<StmtSwitch.Case[], Stmt> runContTo = run => {
        var body = run[run.Length - 1].Stmt;
        return body.StmtType == Stmt.NodeType.Continuation ? ((StmtContinuation)body).To : null;
    };
    // If any cases go to the same continuation as the default case, remove them
    // (with the cases falling through into them; those values then reach the same target via the default)
    if (s.Default != null && s.Default.StmtType == Stmt.NodeType.Continuation) {
        var defaultTo = ((StmtContinuation)s.Default).To;
        Func<StmtSwitch.Case[], bool> sameAsDefault = run => {
            var to = runContTo(run);
            return to != null && to == defaultTo;
        };
        if (caseRuns.Any(sameAsDefault)) {
            var cases = caseRuns.Where(run => !sameAsDefault(run)).SelectMany(run => run).Concat(trailingCases).ToArray();
            return new StmtSwitch(ctx, s.Expr, cases, s.Default);
        }
    }
    // If multiple case statements all go the same continuation, then put them consecutively
    var groupedByTo = caseRuns
        .Where(run => runContTo(run) != null)
        .GroupBy(runContTo)
        .Where(x => x.Count() >= 2)
        .ToArray();
    if (groupedByTo.Any()) {
        var groupedRuns = new HashSet<StmtSwitch.Case[]>(groupedByTo.SelectMany(x => x));
        var cases = caseRuns.Where(run => !groupedRuns.Contains(run)).SelectMany(run => run);
        var combinedCases = groupedByTo.SelectMany(x => {
            var same = x.SelectMany(run => run).ToArray();
            var last = same.Last();
            var sameCases = same.Take(same.Length - 1).Select(y => new StmtSwitch.Case(y.Value, null))
                .Concat(new StmtSwitch.Case(last.Value, last.Stmt));
            return sameCases;
        });
        var allCases = cases.Concat(combinedCases).Concat(trailingCases).ToArray();
        return new StmtSwitch(ctx, s.Expr, allCases, s.Default);
    }
    // For a case body: empty when it has no continuation (or is null), the single final continuation
    // when there is exactly one and it is last, otherwise a single null marker.
    Func<Stmt, IEnumerable<StmtContinuation>> getSingleFinalContinuation = stmt => {
        if (stmt == null) {
            return Enumerable.Empty<StmtContinuation>();
        }
        var contCount = VisitorFindContinuations.Get(stmt).Count();
        if (contCount == 0) {
            // Case contains return or throw
            return Enumerable.Empty<StmtContinuation>();
        }
        if (contCount == 1) {
            if (stmt.StmtType == Stmt.NodeType.Continuation) {
                return new[] { (StmtContinuation)stmt };
            }
            if (stmt.StmtType == Stmt.NodeType.Block) {
                var stmtBlock = (StmtBlock)stmt;
                var last = stmtBlock.Statements.LastOrDefault();
                if (last != null && last.StmtType == Stmt.NodeType.Continuation) {
                    return new[] { (StmtContinuation)last };
                }
            }
        }
        return new StmtContinuation[] { null };
    };
    var conts = s.Cases.Select(x => x.Stmt).Concat(s.Default).SelectMany(x => getSingleFinalContinuation(x)).ToArray();
    if (conts.All(x => x != null)) {
        // If all cases end with a continuation to the same stmt, then put that stmt after the switch and remove all continuations
        if (conts.AllSame(x => x.To)) {
            Func<Stmt, Stmt> removeCont = stmt => {
                if (stmt == null) {
                    return null;
                }
                switch (stmt.StmtType) {
                case Stmt.NodeType.Continuation:
                    return new StmtBreak(ctx);
                case Stmt.NodeType.Block:
                    var sBlock = (StmtBlock)stmt;
                    var stmts = sBlock.Statements.ToArray();
                    // Length check: an empty block (e.g. a case with no continuation) has no last statement
                    if (stmts.Length > 0 && stmts.Last().StmtType == Stmt.NodeType.Continuation) {
                        stmts = stmts.Take(stmts.Length - 1).Concat(new StmtBreak(ctx)).ToArray();
                        return new StmtBlock(ctx, stmts);
                    } else {
                        return stmt;
                    }
                default:
                    return stmt;
                }
            };
            var cases = s.Cases.Select(x => new StmtSwitch.Case(x.Value, removeCont(x.Stmt))).ToArray();
            var @switch = new StmtSwitch(ctx, s.Expr, cases, removeCont(s.Default));
            return new StmtBlock(ctx, @switch, conts[0]);
        } else if (this.lastChance) {
            // HACK: Change it into multiple if statements
            // Null-bodied cases accumulate their values into the condition of the next case with a body
            var multiValues = new List<int>();
            var converted = s.Cases.Aggregate(s.Default, (@else, @case) => {
                multiValues.Add(@case.Value);
                if (@case.Stmt == null) {
                    return @else;
                } else {
                    var cond = multiValues.Aggregate((Expr)ctx.Literal(false), (expr, caseValue) => {
                        return ctx.ExprGen.Or(expr, ctx.ExprGen.Equal(s.Expr, ctx.Literal(caseValue)));
                    });
                    multiValues.Clear();
                    var @if = new StmtIf(ctx, cond, @case.Stmt, @else);
                    return @if;
                }
            });
            return converted;
        }
        // If some cases end in a continuation that itself ends in a continuation that other cases end with
        // then use an extra variable to store whether to execute the intermediate code
        // TODO: This is too specific, need a more general-purpose solution to the problem where cases
        // don't all end by going to the same place
        //var contTos = conts.Select(x => x.To).Distinct().ToArray();
        //var finalContTos = contTos.Select(x => getSingleFinalContinuation(x).Select(y => y.NullThru(z => z.To))).SelectMany(x => x).ToArray();
        //if (!finalContTos.Any(x => x == null)) {
        //    // All continuations are fully substituted
        //    var distinctFinalContTos = finalContTos.Distinct().ToArray();
        //    if (distinctFinalContTos.Length == 1) {
        //        var selector = ctx.Local(ctx.Int32);
        //        var inIfCont = contTos.Single(x => x != distinctFinalContTos[0]);
        //        var inIf = new StmtContinuation(ctx, inIfCont, false);
        //        var afterIf = new StmtContinuation(ctx, distinctFinalContTos[0], false);
        //        var allCasesTo = new StmtBlock(ctx,
        //            new StmtIf(ctx, ctx.ExprGen.Equal(selector, ctx.Literal(1)), inIf, null),
        //            afterIf);
        //        Func<Stmt, Stmt> adjustCont = stmt => {
        //            var cont = VisitorFindContinuations.Get(stmt).Single();
        //            var newCont = new StmtContinuation(ctx, allCasesTo, false);
        //            var contChanged = (Stmt)VisitorReplace.V(stmt, cont, newCont);
        //            var sValue = cont.To == inIf.To ? 1 : 0;
        //            var withSelectorSet = new StmtBlock(ctx,
        //                new StmtAssignment(ctx, selector, ctx.Literal(sValue)),
        //                contChanged);
        //            return withSelectorSet;
        //        };
        //        var cases = s.Cases.Select(x => new StmtSwitch.Case(x.Value, adjustCont(x.Stmt))).ToArray();
        //        var @switch = new StmtSwitch(ctx, s.Expr, cases, adjustCont(s.Default));
        //        return @switch;
        //    } else {
        //        throw new NotImplementedException();
        //    }
        //}
    }
    return base.VisitSwitch(s);
}
// Records (or merges into) the stack/locals/args state on entry to statement 's'.
// - 'try' statements: the state is recorded for the first catch handler (stack = just its exception
//   variable), for the finally handler (empty stack) and for the try body (the given stack).
// - Any other non-CIL statement throws.
// - A CIL block seen for the first time gets one ExprVarPhi per stack/local/arg slot.
// - A CIL block seen before has the new values merged into its existing phis. The merge is then
//   pushed forward into the successors of already-processed blocks (see 'forwardMerge').
private void CreateOrMergeBsi(Stmt s, Expr[] stack, Expr[] locals, Expr[] args) {
    if (s.StmtType == Stmt.NodeType.Try) {
        var sTry = (StmtTry)s;
        // It is fine to use 'locals' and 'args' in catch/finally because the phi clustering performed later
        // will conglomerate all the necessary variables
        if (sTry.Catches != null) {
            // Only the first catch is given a block-start state
            var catch0 = sTry.Catches.First();
            this.CreateOrMergeBsi(catch0.Stmt, new Expr[] { catch0.ExceptionVar }, locals, args);
        }
        if (sTry.Finally != null) {
            this.CreateOrMergeBsi(sTry.Finally, new Expr[0], locals, args);
        }
        this.CreateOrMergeBsi(sTry.Try, stack, locals, args);
        return;
    }
    if (s.StmtType != Stmt.NodeType.Cil) {
        // NOTE(review): InvalidCastException is an unusual type for "unexpected node type"; kept as-is
        throw new InvalidCastException("Should not be seeing: " + s.StmtType);
    }
    // Perform create/merge
    // Recursively expands phis into the expressions they contain; a non-phi expression yields itself
    Func<Expr, IEnumerable<Expr> > flattenPhiExprs = null;
    flattenPhiExprs = e => {
        if (e.ExprType == Expr.NodeType.VarPhi) {
            return(((ExprVarPhi)e).Exprs.SelectMany(x => flattenPhiExprs(x)));
        }
        return(new[] { e });
    };
    // Pairs each phi in 'bsiVars' with the same-index entry of 'thisVars' (Zip stops at the shorter one).
    // For each non-null entry, the flattened expressions are prepended to the phi's Exprs. Nulls, the phi
    // itself and duplicates are then removed.
    Action<ExprVarPhi[], IEnumerable<Expr> > merge = (bsiVars, thisVars) => {
        foreach (var v in bsiVars.Zip(thisVars, (a, b) => new { phi = a, add = b })) {
            if (v.add != null) {
                v.phi.Exprs = flattenPhiExprs(v.add).Concat(v.phi.Exprs).Where(x => x != null && x != v.phi).Distinct().ToArray();
            }
        }
    };
    BlockInitInfo bsi;
    if (!this.blockStartInfos.TryGetValue(s, out bsi)) {
        // First visit: wrap every slot in a phi. A null slot gets an empty phi, and an existing phi is
        // reused as-is (not copied).
        Func<IEnumerable<Expr>, ExprVarPhi[]> create = exprs => exprs.Select(x => {
            if (x == null) {
                return(new ExprVarPhi(this.ctx) { Exprs = new Expr[0] });
            }
            if (x.ExprType == Expr.NodeType.VarPhi) {
                return((ExprVarPhi)x);
            }
            return(new ExprVarPhi(this.ctx) { Exprs = new[] { x } });
        }).ToArray();
        bsi = new BlockInitInfo {
            Stack = create(stack),
            Locals = create(locals),
            Args = create(args),
        };
        this.blockStartInfos.Add(s, bsi);
    } else {
        merge(bsi.Stack, stack);
        merge(bsi.Locals, locals);
        merge(bsi.Args, args);
        // Forward-merge through already-processed nodes for vars that are not changed in a node
        // 'fmSeen' stops the recursion from revisiting a statement (e.g. around loops).
        var fmSeen = new HashSet<Stmt>();
        Action<Stmt> forwardMerge = null;
        forwardMerge = (stmt) => {
            if (fmSeen.Add(stmt)) {
                var fmBsi = this.blockStartInfos.ValueOrDefault(stmt);
                var fmChanges = this.stmtVarsChanged.ValueOrDefault(stmt);
                // Only statements that already have both a start state and change flags take part
                if (fmBsi != null && fmChanges != null) {
                    // Keep a slot's start phi where its flag is true, otherwise null (not propagated).
                    // NOTE(review): VisitCil fills these flags with 'a == b', i.e. true means "unchanged",
                    // which matches the comment above despite the 'stmtVarsChanged' name. Confirm.
                    var fmStack = fmBsi.Stack.Take(fmChanges.Stack.Length).Select((x, i) => fmChanges.Stack[i] ? x : null).ToArray();
                    var fmLocals = fmBsi.Locals.Take(fmChanges.Locals.Length).Select((x, i) => fmChanges.Locals[i] ? x : null).ToArray();
                    var fmArgs = fmBsi.Args.Take(fmChanges.Args.Length).Select((x, i) => fmChanges.Args[i] ? x : null).ToArray();
                    if (fmStack.Any(x => x != null) || fmLocals.Any(x => x != null) || fmArgs.Any(x => x != null)) {
                        // Push the pass-through phis into each successor that already has a start state
                        var fmConts = VisitorFindContinuations.Get(stmt);
                        foreach (var fmCont in fmConts) {
                            var fmContBsi = this.blockStartInfos.ValueOrDefault(fmCont.To);
                            if (fmContBsi != null) {
                                merge(fmContBsi.Stack, fmStack);
                                merge(fmContBsi.Locals, fmLocals);
                                merge(fmContBsi.Args, fmArgs);
                                forwardMerge(fmCont.To);
                            }
                        }
                    }
                }
            }
        };
        forwardMerge(s);
    }
}