/// <summary>
/// Builds the state list for <c>node</c>. The walk starts in an entry state whose successor is a terminal
/// state, goto statements are rewritten afterwards when any labels were recorded, and the last state in
/// <c>states</c> is terminated with a break unless it already ends in a break or return.
/// </summary>
public void GenerateStates()
{
    // Terminal state: its only statement is ReturnOutOfState().
    var exitState = new State(this);
    exitState.Statements.Add(ReturnOutOfState());

    // Entry state; by default it flows into the terminal state.
    currentState = new State(this) { NextState = exitState };
    node.Accept(this);

    // Post-process goto statements. This runs after the walk, presumably because label states are
    // only fully known once the whole node has been visited.
    if (labelStates.Any())
    {
        var substituter = new GotoSubstituter(compilation, labelStates);
        foreach (var generated in states)
        {
            generated.Statements = generated.Statements
                .Select(statement => (StatementSyntax)statement.Accept(substituter))
                .ToList();
        }
    }

    // Terminate the final state with a break unless it already ends in break/return (or is empty → break).
    var finalState = states.Last();
    var tail = finalState.Statements.LastOrDefault();
    var endsWithJump = tail is BreakStatementSyntax || tail is ReturnStatementSyntax;
    if (!endsWithJump)
    {
        finalState.Statements.Add(Cs.Break());
    }
}
/// <summary>
/// If the current state already holds statements, starts a new state after it. The new state inherits
/// the current state's successor, the current state is closed so it jumps to the new one, and the new
/// state becomes current. An empty current state is reused as-is.
/// </summary>
protected void MaybeCreateNewState()
{
    if (!currentState.Statements.Any())
        return;

    // NOTE(review): unlike GetNextState, the new state does not copy currentState.Germ — confirm that
    // is intended for states created inside a try body.
    var successor = new State(this) { NextState = currentState.NextState };
    currentState.NextState = successor;
    Close(currentState);
    currentState = successor;
}
/// <summary>
/// Walks <paramref name="node"/> into a fresh batch state and returns the statements collected in that batch.
/// </summary>
/// <param name="node">The syntax to visit.</param>
/// <param name="nextState">Successor of the batch state.</param>
/// <param name="breakState">A state the caller owns; it is never closed here.</param>
/// <returns>The statements accumulated in the batch state.</returns>
/// <remarks>
/// <c>currentState</c> is NOT restored on return; callers are expected to reassign it.
/// </remarks>
protected StatementSyntax[] CaptureState(CSharpSyntaxNode node, State nextState, State breakState)
{
    var batch = new State(this, true) { NextState = nextState };
    var callerState = currentState;
    currentState = batch;

    node.Accept(this);

    // The walk may have moved currentState to a new state (e.g. after a yield). If it ended on a state
    // other than the batch, the caller's state or one of the supplied targets, close it here so it is
    // not left open.
    var landedOn = currentState;
    var leftOpen = landedOn != breakState
        && landedOn != batch
        && landedOn != nextState
        && landedOn != callerState;
    if (leftOpen)
    {
        Close(landedOn);
    }

    return batch.Statements.ToArray();
}
/// <summary>
/// Terminates <paramref name="fromState"/>. When <paramref name="toState"/> is non-null, first switches the
/// state index to it; in either case a GotoTop() jump is appended and the state is marked closed.
/// Already-closed states are left untouched.
/// </summary>
/// <param name="fromState">State to terminate.</param>
/// <param name="toState">Target state, or null to jump without changing the state index.</param>
public void CloseTo(State fromState, State toState)
{
    if (fromState.IsClosed)
        return;

    // NOTE(review): with a null target the state index is not changed before the jump — confirm the
    // caller has already set it, otherwise the same state would run again.
    if (toState != null)
        fromState.Add(ChangeState(toState));

    fromState.Add(GotoTop());
    SetClosed(fromState);
}
/// <summary>
/// Builds the statement <c>this.&lt;state&gt; = newState.Index;</c>, which selects <paramref name="newState"/>
/// as the next state to run.
/// </summary>
/// <param name="newState">The state to switch to; must not be null.</param>
/// <returns>An expression statement performing the assignment.</returns>
public static StatementSyntax ChangeState(State newState)
{
    var stateMember = Cs.This().Member(state);
    var assignment = stateMember.Assign(Cs.Integer(newState.Index));
    return Cs.Express(assignment);
}
/// <summary>
/// Marks <paramref name="state"/> as closed, then invokes its germ callback (if any) with the state so
/// the germ can post-process the final statement list.
/// </summary>
/// <param name="state">The state being closed.</param>
protected void SetClosed(State state)
{
    state.IsClosed = true;

    // e.g. VisitTryStatement installs a germ that wraps the state's statements in a try/catch.
    var germ = state.Germ;
    if (germ != null)
    {
        germ(state);
    }
}
/// <summary>
/// Terminates <paramref name="state"/> by jumping to its own <c>NextState</c> (see <see cref="CloseTo"/>).
/// </summary>
/// <param name="state">The state to close.</param>
public void Close(State state)
{
    var successor = state.NextState;
    CloseTo(state, successor);
}
/// <summary>
/// Chooses the state control should reach after <paramref name="node"/>. If another non-empty statement
/// follows it, a new state is created that inherits the current state's successor and germ. Otherwise
/// the current state's successor is reused.
/// </summary>
/// <param name="node">The statement whose continuation is needed.</param>
/// <returns>The continuation state.</returns>
protected State GetNextState(StatementSyntax node)
{
    var following = node.GetNextStatement();
    var hasFollowing = following != null && !(following is EmptyStatementSyntax);
    if (!hasFollowing)
        return currentState.NextState;

    return new State(this) { NextState = currentState.NextState, Germ = currentState.Germ };
}
/// <summary>
/// Lowers a try statement. A try without yield/special statements is copied verbatim. Otherwise only
/// try/finally is supported, lowered as follows:
/// <list type="bullet">
/// <item>The try-body states get a germ that, when each state is closed, wraps its statements in
/// <c>try { ... } catch (System.Exception _exN) { this._exN = _exN; &lt;switch to finally&gt; }</c>.</item>
/// <item>A dedicated finally state runs the original finally statements and then rethrows
/// <c>this._exN</c> if it is set.</item>
/// </list>
/// </summary>
/// <exception cref="Exception">The try statement has catch clauses and contains a yield/special statement.</exception>
public override void VisitTryStatement(TryStatementSyntax node)
{
    if (!YieldChecker.HasSpecialStatement(node))
    {
        currentState.Add(StateMachineThisFixer.Fix(node));
        return;
    }

    if (node.Catches.Any())
        throw new Exception("Yield statements cannot be contained inside try/catch blocks.");

    var nextState = GetNextState(node);
    MaybeCreateNewState();
    var tryState = currentState;
    tryState.NextState = nextState;

    var exceptionName = SyntaxFactory.Identifier("_ex" + exceptionNameCounter++);

    // Finally state: the user's finally statements, then rethrow any exception captured in the try body.
    var finallyState = new State(this) { NextState = nextState, BreakState = nextState };
    foreach (var finallyStatement in node.Finally.Block.Statements)
    {
        // Verbatim user code must go through StateMachineThisFixer, as it does everywhere else in this
        // visitor (including the non-special branch above). These statements were previously added unfixed.
        finallyState.Add(StateMachineThisFixer.Fix(finallyStatement));
    }
    finallyState.Add(Cs.If(Cs.This().Member(exceptionName).NotEqualTo(Cs.Null()), Cs.Throw(Cs.This().Member(exceptionName))));
    Close(finallyState);

    node = (TryStatementSyntax)HoistVariable(node, exceptionName, SyntaxFactory.ParseTypeName("System.Exception"));

    // Normal completion of the try body flows into the finally state.
    tryState.NextState = finallyState;

    // Applied by SetClosed when each try-body state is closed: wrap its statements so that an exception is
    // stored in the hoisted field and control is sent to the finally state.
    tryState.Germ = yieldState =>
    {
        var gotoFinally = SyntaxFactory.Block(
            Cs.Express(Cs.This().Member(exceptionName).Assign(SyntaxFactory.IdentifierName(exceptionName))),
            ChangeState(finallyState),
            GotoTop());
        var statements = yieldState.Statements.ToArray();
        yieldState.Statements.Clear();
        yieldState.Statements.Add(Cs.Try().WithBlock(Cs.Block(statements)).WithCatches(SyntaxFactory.List(new[]
        {
            SyntaxFactory.CatchClause(SyntaxFactory.CatchDeclaration(SyntaxFactory.ParseTypeName("System.Exception"), exceptionName), null, gotoFinally)
        })));
    };

    node.Block.Accept(this);
    if (!tryState.IsClosed)
    {
        CloseTo(tryState, finallyState);
    }
    currentState = nextState;
}
/// <summary>
/// Lowers a foreach statement. Loops without yield/special statements are copied verbatim. Otherwise the
/// loop variable and the enumerator are hoisted into fields, and the enumerator is obtained in the current
/// state. The loop then becomes <c>while (e.MoveNext()) { item = e.Current; body }</c> inside an iteration
/// state that loops back to itself and breaks to the continuation state.
/// </summary>
public override void VisitForEachStatement(ForEachStatementSyntax node)
{
    if (!YieldChecker.HasSpecialStatement(node))
    {
        currentState.Add(StateMachineThisFixer.Fix(node));
        return;
    }

    var model = compilation.GetSemanticModel(node.SyntaxTree);
    var loopVariable = (ILocalSymbol)ModelExtensions.GetDeclaredSymbol(model, node);
    var collectionInfo = ModelExtensions.GetTypeInfo(model, node.Expression);

    // The loop variable becomes a field.
    node = (ForEachStatementSyntax)HoistVariable(node, node.Identifier, loopVariable.Type.ToTypeSyntax());

    // The enumerator becomes a field too: IEnumerator<T> when an IEnumerable<T> element type is found,
    // otherwise the non-generic IEnumerator.
    var enumerator = SyntaxFactory.Identifier(node.Identifier + "_enumerator");
    var genericEnumerable = compilation.FindType("System.Collections.Generic.IEnumerable`1");
    var elementType = collectionInfo.ConvertedType.GetGenericArgument(genericEnumerable, 0);
    TypeSyntax enumeratorType;
    if (elementType == null)
    {
        enumeratorType = SyntaxFactory.ParseTypeName("System.Collections.IEnumerator");
    }
    else
    {
        enumeratorType = SyntaxFactory.ParseTypeName("System.Collections.Generic.IEnumerator<" + elementType.ToDisplayString() + ">");
    }
    node = (ForEachStatementSyntax)HoistVariable(node, enumerator, enumeratorType);

    // enumerator = <collection>.GetEnumerator();
    var getEnumeratorCall = SyntaxFactory.InvocationExpression(
        SyntaxFactory.MemberAccessExpression(
            SyntaxKind.SimpleMemberAccessExpression,
            StateMachineThisFixer.Fix(node.Expression),
            SyntaxFactory.IdentifierName("GetEnumerator")));
    currentState.Add(Cs.Express(SyntaxFactory.IdentifierName(enumerator).Assign(getEnumeratorCall)));

    // From here on this mirrors the while-loop lowering, plus the per-item assignment.
    // NOTE(review): GetNextState receives the node returned by HoistVariable; if that node is detached from
    // the original tree, GetNextStatement() may not see the following statement — confirm.
    MaybeCreateNewState();
    var nextState = GetNextState(node);
    var loopState = currentState;
    loopState.NextState = loopState;
    loopState.BreakState = nextState;

    var loopBody = new State(this, true) { NextState = loopState.NextState };

    // item = enumerator.Current;
    var currentItem = SyntaxFactory.MemberAccessExpression(
        SyntaxKind.SimpleMemberAccessExpression,
        SyntaxFactory.IdentifierName(enumerator),
        SyntaxFactory.IdentifierName("Current"));
    loopBody.Add(Cs.Express(SyntaxFactory.IdentifierName(node.Identifier).Assign(currentItem)));

    currentState = loopBody;
    node.Statement.Accept(this);

    // while (enumerator.MoveNext()) { ...loopBody... }
    // NOTE(review): nothing in this block disposes the enumerator, unlike a compiler-generated foreach.
    var moveNextCall = SyntaxFactory.InvocationExpression(
        SyntaxFactory.MemberAccessExpression(
            SyntaxKind.SimpleMemberAccessExpression,
            SyntaxFactory.IdentifierName(enumerator),
            SyntaxFactory.IdentifierName("MoveNext")));
    var whileLoop = SyntaxFactory.WhileStatement(moveNextCall, SyntaxFactory.Block(loopBody.Statements));
    loopState.Statements.Add(whileLoop);
    CloseTo(loopState, nextState);

    if (currentState != nextState)
    {
        Close(currentState);
    }
    currentState = nextState;
}
/// <summary>
/// Lowers a for statement. Loops without yield/special statements are copied verbatim. Otherwise:
/// <list type="bullet">
/// <item>The declared loop variables are hoisted into fields.</item>
/// <item>Declarator and statement initializers are emitted into the current state.</item>
/// <item>The loop becomes <c>while (condition) { body }</c> inside its own iteration state.</item>
/// <item>Incrementors, if any, run in a separate state that jumps back to the iteration state.</item>
/// </list>
/// </summary>
public override void VisitForStatement(ForStatementSyntax node)
{
    if (!YieldChecker.HasSpecialStatement(node))
    {
        currentState.Add(StateMachineThisFixer.Fix(node));
        return;
    }

    var semanticModel = compilation.GetSemanticModel(node.SyntaxTree);

    // Convert the variable declarations in the for loop. The emitted while loop is built only from the
    // condition and the body, so the declaration itself is dropped. EVERY declared variable must therefore
    // be hoisted into a field, not only the initialized ones: for `for (int i; (i = Next()) > 0; )` the
    // variable `i` would otherwise end up undeclared. Only declarators with an initializer get an assignment.
    if (node.Declaration != null)
    {
        foreach (var variable in node.Declaration.Variables)
        {
            if (variable.Initializer != null)
            {
                var assignment = SyntaxFactory.IdentifierName(variable.Identifier.ToString())
                    .Assign(StateMachineThisFixer.Fix(variable.Initializer.Value));
                currentState.Add(Cs.Express(assignment));
            }

            var symbol = (ILocalSymbol)ModelExtensions.GetDeclaredSymbol(semanticModel, variable);

            // Hoist the variable into a field
            node = (ForStatementSyntax)HoistVariable(node, variable.Identifier, SyntaxFactory.ParseTypeName(symbol.Type.GetFullName()));
        }
    }

    foreach (var initializer in node.Initializers)
    {
        currentState.Add(Cs.Express(StateMachineThisFixer.Fix(initializer)));
    }

    // NOTE(review): GetNextState receives the node returned by HoistVariable; if that node is detached from
    // the original tree, GetNextStatement() may not see the following statement — confirm.
    MaybeCreateNewState();
    var nextState = GetNextState(node);
    var iterationState = currentState;

    // postState determines where the flow goes after each iteration. With incrementors it is a state that
    // runs them and then returns to the iteration state. Without incrementors the iteration state simply
    // loops back to itself, as in a while loop.
    State postState;
    if (node.Incrementors.Any())
    {
        postState = new State(this);
        postState.NextState = iterationState;
        foreach (var incrementor in node.Incrementors)
        {
            postState.Add(Cs.Express(StateMachineThisFixer.Fix(incrementor)));
        }
        Close(postState);
    }
    else
    {
        postState = iterationState;
    }

    iterationState.NextState = nextState;
    iterationState.BreakState = nextState;

    // A missing condition (`for (;;)`) means loop forever.
    var condition = StateMachineThisFixer.Fix(node.Condition) ?? Cs.True();
    var body = CaptureState(node.Statement, postState, nextState);
    var whileStatement = SyntaxFactory.WhileStatement(condition, SyntaxFactory.Block(body));
    iterationState.Statements.Add(StateMachineThisFixer.Fix(whileStatement));
    Close(iterationState);
    currentState = nextState;
}
/// <summary>
/// Lowers a do-while statement. Loops without yield/special statements are copied verbatim. Otherwise the
/// body runs starting in the current (iteration) state and flows into a separate condition state. That
/// state switches back to the iteration state when the condition holds, and to the continuation state
/// otherwise.
/// </summary>
public override void VisitDoStatement(DoStatementSyntax node)
{
    if (!YieldChecker.HasSpecialStatement(node))
    {
        currentState.Add(StateMachineThisFixer.Fix(node));
        return;
    }

    MaybeCreateNewState();
    var nextState = GetNextState(node);
    var conditionState = new State(this) { BreakState = nextState };
    var bodyState = currentState;

    // if (condition) <switch to body> else <switch to continuation>; then jump.
    var condition = StateMachineThisFixer.Fix(node.Condition);
    conditionState.Add(Cs.If(condition, ChangeState(bodyState), ChangeState(nextState)));
    conditionState.Add(GotoTop());

    // The explicit jump above already terminates this state, so it is only marked closed (no extra exit).
    SetClosed(conditionState);

    // Once the body finishes, control reaches the condition check.
    bodyState.NextState = conditionState;
    node.Statement.Accept(this);

    if (currentState != nextState)
    {
        Close(currentState);
    }
    currentState = nextState;
}