Пример #1
0
        /// <summary>
        /// Lowers a <c>switch</c> statement into the state machine.
        /// </summary>
        /// <remarks>
        /// When the switch contains no special statements it is copied into the current state as a whole.
        /// Otherwise a new switch is built in the current state whose sections are produced by
        /// <c>CaptureState</c>, and control continues in the state returned by <c>GetNextState</c>.
        /// </remarks>
        /// <param name="node">The switch statement being visited.</param>
        public override void VisitSwitchStatement(SwitchStatementSyntax node)
        {
            if (!YieldChecker.HasSpecialStatement(node))
            {
                // Run the statement through StateMachineThisFixer like every other visitor in this
                // generator does before copying a statement verbatim into a state.
                currentState.Add(StateMachineThisFixer.Fix(node));
            }
            else
            {
                var nextState   = GetNextState(node);
                var switchState = currentState;
                switchState.NextState = nextState;

                // The governing expression ends up inside the generated state, so it gets the same
                // this-fix treatment as the conditions/expressions of the loop and yield visitors.
                var switchStatement = Cs.Switch(StateMachineThisFixer.Fix(node.Expression));

                // Each section is rebuilt with its statements passed through BreakStatementStripper
                // before being captured.
                foreach (var section in node.Sections.Select(x => x.WithStatements(SyntaxFactory.List(x.Statements.Select(BreakStatementStripper.StripStatements)))))
                {
                    switchStatement = switchStatement.AddSections(Cs.Section(section.Labels, CaptureState(section, switchState.NextState, nextState)));
                }

                switchState.Statements.Add(switchStatement);

                Close(switchState);

                currentState = nextState;
            }
        }
Пример #2
0
        /// <summary>
        /// Lowers a <c>try</c>/<c>finally</c> statement into the state machine.
        /// </summary>
        /// <remarks>
        /// A try statement without special statements is copied as a whole. Otherwise the finally block
        /// becomes its own state, and states produced for the try block are wrapped in a generated
        /// try/catch that records the exception in a hoisted field and jumps to the finally state.
        /// </remarks>
        /// <param name="node">The try statement being visited.</param>
        /// <exception cref="Exception">The try statement has catch clauses and contains special statements.</exception>
        public override void VisitTryStatement(TryStatementSyntax node)
        {
            if (!YieldChecker.HasSpecialStatement(node))
            {
                currentState.Add(StateMachineThisFixer.Fix(node));
            }
            else
            {
                if (node.Catches.Any())
                {
                    throw new Exception("Yield statements cannot be contained inside try/catch blocks.");
                }

                var nextState = GetNextState(node);

                MaybeCreateNewState();

                var tryState = currentState;
                // Provisional target; it is replaced with finallyState further down.
                tryState.NextState = nextState;

                // Each lowered try statement gets its own exception field name (_ex0, _ex1, ...).
                var exceptionName = SyntaxFactory.Identifier("_ex" + exceptionNameCounter++);
                var finallyState  = new State(this)
                {
                    NextState = nextState, BreakState = nextState
                };
                // NOTE(review): node.Finally is dereferenced without a null check; a try without catches
                // is expected to have a finally clause — confirm for malformed input.
                foreach (var finallyStatement in node.Finally.Block.Statements)
                {
                    finallyState.Add(finallyStatement);
                }
                // After the finally statements: rethrow the recorded exception, if there is one.
                finallyState.Add(Cs.If(Cs.This().Member(exceptionName).NotEqualTo(Cs.Null()), Cs.Throw(Cs.This().Member(exceptionName))));
                Close(finallyState);

                // Register the exception holder as a hoisted variable of type System.Exception.
                node = (TryStatementSyntax)HoistVariable(node, exceptionName, SyntaxFactory.ParseTypeName("System.Exception"));

                tryState.NextState = finallyState;
                // Callback that rewrites a state's statements into
                //   try { <statements> } catch (System.Exception _exN) { this._exN = _exN; <go to finally state>; <goto top>; }
                // NOTE(review): presumably invoked for the states generated from the try block — confirm in State.
                tryState.Germ      = yieldState =>
                {
                    var gotoFinally = SyntaxFactory.Block(
                        Cs.Express(Cs.This().Member(exceptionName).Assign(SyntaxFactory.IdentifierName(exceptionName))),
                        ChangeState(finallyState),
                        GotoTop()
                        );

                    // Snapshot the statements first, because the list is cleared and refilled in place.
                    var statements = yieldState.Statements.ToArray();
                    yieldState.Statements.Clear();
                    yieldState.Statements.Add(Cs.Try().WithBlock(Cs.Block(statements)).WithCatches(SyntaxFactory.List(new[] {
                        SyntaxFactory.CatchClause(SyntaxFactory.CatchDeclaration(SyntaxFactory.ParseTypeName("System.Exception"), exceptionName), null, gotoFinally)
                    })));
                };

                // Lower the body of the try block through this visitor.
                node.Block.Accept(this);

                // If visiting the body did not close the try state, close it towards the finally state.
                if (!tryState.IsClosed)
                {
                    CloseTo(tryState, finallyState);
                }

                currentState = nextState;
            }
        }
Пример #3
0
        /// <summary>
        /// Lowers a <c>do</c>/<c>while</c> loop into the state machine.
        /// </summary>
        /// <remarks>
        /// The body starts in the current state; a separate closed state evaluates the loop condition
        /// and switches either back to the body state or on to the state following the loop.
        /// </remarks>
        /// <param name="node">The do statement being visited.</param>
        public override void VisitDoStatement(DoStatementSyntax node)
        {
            // Loops without special statements are copied as a whole.
            if (!YieldChecker.HasSpecialStatement(node))
            {
                currentState.Add(StateMachineThisFixer.Fix(node));
                return;
            }

            MaybeCreateNewState();

            var afterLoop = GetNextState(node);

            // State that tests the loop condition after each pass over the body.
            var testState = new State(this);
            testState.BreakState = afterLoop;

            var bodyState = currentState;

            var fixedCondition = StateMachineThisFixer.Fix(node.Condition);
            var repeatBody     = ChangeState(bodyState);
            var leaveLoop      = ChangeState(afterLoop);
            testState.Add(Cs.If(fixedCondition, repeatBody, leaveLoop));
            testState.Add(GotoTop());
            SetClosed(testState);

            // The body flows into the condition test.
            bodyState.NextState = testState;

            node.Statement.Accept(this);

            // Close whatever state the body ended in, unless it already is the post-loop state.
            if (currentState != afterLoop)
            {
                Close(currentState);
            }

            currentState = afterLoop;
        }
Пример #4
0
        /// <summary>
        /// Builds the states for <c>node</c> by visiting it, then post-processes the result.
        /// </summary>
        /// <remarks>
        /// Goto statements are substituted when label states were recorded, and the last state is
        /// given a trailing <c>break</c> unless it already ends with a break or return statement.
        /// </remarks>
        public void GenerateStates()
        {
            // State whose only statement is the "return out of the state machine" statement.
            var terminalState = new State(this);
            terminalState.Statements.Add(ReturnOutOfState());

            // Entry state; it flows into the terminal state.
            var entryState = new State(this);
            entryState.NextState = terminalState;
            currentState = entryState;

            node.Accept(this);

            // Post-process goto statements
            if (labelStates.Any())
            {
                var substituter = new GotoSubstituter(compilation, labelStates);
                foreach (var state in states)
                {
                    var rewritten = state.Statements.Select(statement => (StatementSyntax)statement.Accept(substituter));
                    state.Statements = rewritten.ToList();
                }
            }

            var finalState    = states.Last();
            var tailStatement = finalState.Statements.LastOrDefault();

            // An empty state (null tail) or any other kind of tail statement needs an explicit break.
            var endsWithExit = tailStatement is BreakStatementSyntax || tailStatement is ReturnStatementSyntax;
            if (!endsWithExit)
            {
                finalState.Statements.Add(Cs.Break());
            }
        }
Пример #5
0
        /// <summary>
        /// Replaces a local variable declaration with plain assignments and hoists each declared variable.
        /// </summary>
        /// <param name="node">The local declaration statement being visited.</param>
        public override void VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node)
        {
            var model = compilation.GetSemanticModel(node.SyntaxTree);

            foreach (var declarator in node.Declaration.Variables)
            {
                // A declarator with an initializer turns into "name = value;" in the current state.
                var initializer = declarator.Initializer;
                if (initializer != null)
                {
                    var target = SyntaxFactory.IdentifierName(declarator.Identifier.ToString());
                    var value  = StateMachineThisFixer.Fix(initializer.Value);
                    currentState.Add(Cs.Express(target.Assign(value)));
                }

                // Every declarator is hoisted, using the type of its declared symbol.
                var local = (ILocalSymbol)ModelExtensions.GetDeclaredSymbol(model, declarator);
                node = (LocalDeclarationStatementSyntax)HoistVariable(node, declarator.Identifier, local.Type.ToTypeSyntax());
            }
        }
Пример #6
0
        /// <summary>
        /// Lowers a <c>yield return</c> or <c>yield break</c> statement.
        /// </summary>
        /// <remarks>
        /// Both forms first switch to the following state. <c>yield break</c> then returns <c>false</c>;
        /// <c>yield return</c> stores the value in <c>this.Current</c> and returns <c>true</c>.
        /// The current state is marked closed and generation continues in the following state.
        /// </remarks>
        /// <param name="node">The yield statement being visited.</param>
        public override void VisitYieldStatement(YieldStatementSyntax node)
        {
            var resumeState  = GetNextState(node);
            var isYieldBreak = node.ReturnOrBreakKeyword.IsKind(SyntaxKind.BreakKeyword);

            // Common to both forms: switch to the following state.
            currentState.Add(ChangeState(resumeState));

            if (isYieldBreak)
            {
                currentState.Add(Cs.Return(Cs.False()));
            }
            else
            {
                var currentMember = Cs.This().Member("Current");
                var yieldedValue  = StateMachineThisFixer.Fix(node.Expression);
                currentState.Add(Cs.Express(currentMember.Assign(yieldedValue)));
                currentState.Add(Cs.Return(Cs.True()));
            }

            SetClosed(currentState);
            currentState = resumeState;
        }
Пример #7
0
        /// <summary>
        /// Lowers a <c>foreach</c> loop into the state machine.
        /// </summary>
        /// <remarks>
        /// A loop without special statements is copied as a whole. Otherwise the iteration variable and
        /// an enumerator are hoisted, the enumerator is initialised from <c>GetEnumerator()</c>, and the
        /// loop is rebuilt as <c>while (enumerator.MoveNext()) { ... }</c> inside an iteration state.
        /// </remarks>
        /// <param name="node">The foreach statement being visited.</param>
        public override void    VisitForEachStatement(ForEachStatementSyntax node)
        {
            if (!YieldChecker.HasSpecialStatement(node))
            {
                currentState.Add(StateMachineThisFixer.Fix(node));
            }
            else
            {
                // Resolve the iteration variable and the type of the enumerated expression
                var semanticModel = compilation.GetSemanticModel(node.SyntaxTree);
                var symbol        = (ILocalSymbol)ModelExtensions.GetDeclaredSymbol(semanticModel, node);
                var targetType    = ModelExtensions.GetTypeInfo(semanticModel, node.Expression);

                // Hoist the variable into a field
                node = (ForEachStatementSyntax)HoistVariable(node, node.Identifier, symbol.Type.ToTypeSyntax());

                // Hoist the enumerator into a field
                // Note: despite its name, genericEnumeratorType is looked up as IEnumerable`1; it is used to
                // extract the element type from the converted type of the enumerated expression.
                var enumerator            = SyntaxFactory.Identifier(node.Identifier + "_enumerator");
                var genericEnumeratorType = compilation.FindType("System.Collections.Generic.IEnumerable`1");
                var elementType           = targetType.ConvertedType.GetGenericArgument(genericEnumeratorType, 0);
                // Without an element type the non-generic IEnumerator is used for the enumerator.
                var enumeratorType        = elementType == null?
                                            SyntaxFactory.ParseTypeName("System.Collections.IEnumerator") :
                                                SyntaxFactory.ParseTypeName("System.Collections.Generic.IEnumerator<" + elementType.ToDisplayString() + ">");

                node = (ForEachStatementSyntax)HoistVariable(node, enumerator, enumeratorType);
                // <enumerator> = <expression>.GetEnumerator();
                currentState.Add(Cs.Express(SyntaxFactory.IdentifierName(enumerator).Assign(SyntaxFactory.InvocationExpression(SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, StateMachineThisFixer.Fix(node.Expression), SyntaxFactory.IdentifierName("GetEnumerator"))))));

                // Mostly the same as while loop from here (key word, "mostly"; hence the lack of factoring here)
                MaybeCreateNewState();

                var nextState      = GetNextState(node);
                var iterationState = currentState;

                // The iteration state loops back to itself; breaking out leads to the state after the loop.
                iterationState.NextState  = iterationState;
                iterationState.BreakState = nextState;

                // Separate state object that collects the statements of the loop body.
                // NOTE(review): the meaning of the second constructor argument (true) is not visible here.
                var bodyBatch = new State(this, true)
                {
                    NextState = iterationState.NextState
                };

                // Assign current item
                bodyBatch.Add(Cs.Express(SyntaxFactory.IdentifierName(node.Identifier).Assign(SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
                                                                                                                                   SyntaxFactory.IdentifierName(enumerator), SyntaxFactory.IdentifierName("Current")))));

                // Lower the loop body with bodyBatch as the current state.
                currentState = bodyBatch;
                node.Statement.Accept(this);

                // while (<enumerator>.MoveNext()) { <statements collected in bodyBatch> }
                var moveNext     = SyntaxFactory.InvocationExpression(SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, SyntaxFactory.IdentifierName(enumerator), SyntaxFactory.IdentifierName("MoveNext")));
                var forStatement = SyntaxFactory.WhileStatement(moveNext, SyntaxFactory.Block(bodyBatch.Statements));
                iterationState.Statements.Add(forStatement);

                // When the while loop finishes, control moves on to the state after the loop.
                CloseTo(iterationState, nextState);
                if (currentState != nextState)
                {
                    Close(currentState);
                }

                currentState = nextState;
            }
        }
Пример #8
0
        /// <summary>
        /// Lowers a <c>for</c> loop into the state machine.
        /// </summary>
        /// <remarks>
        /// A loop without special statements is copied as a whole. Otherwise the declared loop variables
        /// are hoisted, initializers are emitted into the current state, and the loop is rebuilt as a
        /// <c>while</c> statement inside an iteration state. Incrementors, when present, live in a
        /// separate state that flows back into the iteration state.
        /// </remarks>
        /// <param name="node">The for statement being visited.</param>
        public override void VisitForStatement(ForStatementSyntax node)
        {
            if (!YieldChecker.HasSpecialStatement(node))
            {
                currentState.Add(StateMachineThisFixer.Fix(node));
            }
            else
            {
                var semanticModel = compilation.GetSemanticModel(node.SyntaxTree);

                // Convert the variable declarations in the for loop
                if (node.Declaration != null)
                {
                    // Every declarator is hoisted, including ones without an initializer
                    // (e.g. "for (int i = 0, n; ...)"), mirroring VisitLocalDeclarationStatement.
                    // Only declarators that have an initializer produce an assignment.
                    foreach (var variable in node.Declaration.Variables)
                    {
                        if (variable.Initializer != null)
                        {
                            var assignment = SyntaxFactory.IdentifierName(variable.Identifier.ToString()).Assign(StateMachineThisFixer.Fix(variable.Initializer.Value));
                            currentState.Add(Cs.Express(assignment));
                        }

                        var symbol = (ILocalSymbol)ModelExtensions.GetDeclaredSymbol(semanticModel, variable);

                        // Hoist the variable into a field
                        node = (ForStatementSyntax)HoistVariable(node, variable.Identifier, SyntaxFactory.ParseTypeName(symbol.Type.GetFullName()));
                    }
                }
                foreach (var initializer in node.Initializers)
                {
                    currentState.Add(Cs.Express(StateMachineThisFixer.Fix(initializer)));
                }

                MaybeCreateNewState();

                var nextState      = GetNextState(node);
                var iterationState = currentState;

                // postState determines where the flow goes after each iteration.  If there are incrementors, it goes
                // to a state that handles the incrementing, then goes back to the iteration state.  Otherwise, the
                // iteration state just points back to itself like the while loop.
                State postState;
                if (node.Incrementors.Any())
                {
                    postState           = new State(this);
                    postState.NextState = iterationState;

                    foreach (var incrementor in node.Incrementors)
                    {
                        postState.Add(Cs.Express(StateMachineThisFixer.Fix(incrementor)));
                    }
                    Close(postState);
                }
                else
                {
                    postState = iterationState;
                }
                iterationState.NextState  = nextState;
                iterationState.BreakState = nextState;

                // "for (;;)" has no condition; do not hand a null node to the fixer and fall back to
                // "true" exactly as before.
                var condition    = node.Condition != null ? StateMachineThisFixer.Fix(node.Condition) : null;
                var forStatement = SyntaxFactory.WhileStatement(condition ?? Cs.True(), SyntaxFactory.Block(CaptureState(node.Statement, postState, nextState)));
                iterationState.Statements.Add(StateMachineThisFixer.Fix(forStatement));

                Close(iterationState);

                currentState = nextState;
            }
        }
Пример #9
0
        /// <summary>
        /// Builds the class declaration of the enumerator generated for <c>method</c>.
        /// </summary>
        /// <remarks>
        /// The class gets fields for <c>_this</c> (instance members only), the started flag, the state,
        /// the method parameters and the hoisted variables; a <c>Current</c> property; a constructor; and
        /// the methods <c>GetEnumerator</c> (generic and explicit non-generic), <c>Dispose</c>,
        /// <c>Reset</c>, <c>MoveNext</c> and <c>Clone</c>. Its base list is
        /// <c>IEnumerable&lt;T&gt;</c> and <c>IEnumerator&lt;T&gt;</c>.
        /// </remarks>
        /// <returns>The generated class declaration.</returns>
        public ClassDeclarationSyntax CreateEnumerator()
        {
            var members = new List <MemberDeclarationSyntax>();
            FieldDeclarationSyntax thisField = null;

            // Instance members carry the original object in a "_this" field.
            if (!method.IsStatic)
            {
                thisField = Cs.Field(method.ContainingType.ToTypeSyntax(), "_this");
                members.Add(thisField);
            }

            // Generate the states that make up the body of MoveNext.
            var stateGenerator = new YieldStateGenerator(compilation, node);

            stateGenerator.GenerateStates();
            var states = stateGenerator.States;

            var isStartedField = Cs.Field(Cs.Bool(), isStarted);

            members.Add(isStartedField);

            var stateField = Cs.Field(Cs.Int(), StateGenerator.state);

            members.Add(stateField);


            // Element type: first type argument of the method's return type or the property's type;
            // falls back to object when there is none.
            ITypeSymbol elementType = null;

            if (method is IMethodSymbol)
            {
                elementType = ((INamedTypeSymbol)(method as IMethodSymbol).ReturnType).TypeArguments.FirstOrDefault();
            }
            else if (method is IPropertySymbol)
            {
                elementType = ((INamedTypeSymbol)(method as IPropertySymbol).Type).TypeArguments.FirstOrDefault();
            }

            if (elementType == null)
            {
                elementType = Context.Object;
            }

            var currentProperty = Cs.Property(elementType.ToTypeSyntax(), "Current");

            members.Add(currentProperty);

            // One field per method parameter.
            if (node is MethodDeclarationSyntax)
            {
                foreach (var parameter in (node as MethodDeclarationSyntax).ParameterList.Parameters)
                {
                    var parameterField = Cs.Field(parameter.Type, parameter.Identifier);
                    members.Add(parameterField);
                }
            }
            // One field per hoisted variable (Item1 is passed as the name, Item2 as the type).
            foreach (var variable in stateGenerator.HoistedVariables)
            {
                var variableField = Cs.Field(variable.Item2, variable.Item1);
                members.Add(variableField);
            }

            var className = method.GetYieldClassName();

            // Constructor parameters: "_this" (instance members only) followed by the method parameters.
            var constructorParameters = new List <ParameterSyntax>();

            if (!method.IsStatic)
            {
                constructorParameters.Add(SyntaxFactory.Parameter(SyntaxFactory.Identifier("_this")).WithType(thisField.Declaration.Type));
            }
            if (node is MethodDeclarationSyntax)
            {
                constructorParameters.AddRange((node as MethodDeclarationSyntax).ParameterList.Parameters.Select(x => SyntaxFactory.Parameter(x.Identifier).WithType(x.Type)));
            }

            // The constructor copies every parameter into the field of the same name and then sets
            // the state field to 1.
            var constructor = SyntaxFactory.ConstructorDeclaration(className)
                              .AddModifiers(Cs.Public())
                              .WithParameterList(constructorParameters.ToArray())
                              .WithBody(
                SyntaxFactory.Block(
                    // Assign fields
                    constructorParameters.Select(x => Cs.Express(Cs.Assign(Cs.This().Member(x.Identifier), SyntaxFactory.IdentifierName(x.Identifier))))
                    )
                .AddStatements(
                    Cs.Express(Cs.Assign(Cs.This().Member(StateGenerator.state), Cs.Integer(1)))
                    )
                );

            members.Add(constructor);


            var ienumerable_g = SyntaxFactory.ParseTypeName("System.Collections.Generic.IEnumerable<" + elementType.ToDisplayString() + ">");
            var ienumerator_g = SyntaxFactory.ParseTypeName("System.Collections.Generic.IEnumerator<" + elementType.ToDisplayString() + ">");

            // NOTE(review): the local "ienumerable" is not referenced anywhere else in this method.
            var ienumerable = SyntaxFactory.ParseTypeName("System.Collections.IEnumerable");
            var ienumerator = SyntaxFactory.ParseTypeName("System.Collections.IEnumerator");
            // IEnumerator IEnumerable.GetEnumerator()
            //{
            //    return GetEnumerator();
            //}

            var iegetEnumerator = SyntaxFactory.MethodDeclaration(ienumerator, "GetEnumerator")
                                  .AddModifiers(Cs.Public())
                                  .WithExplicitInterfaceSpecifier(SyntaxFactory.ExplicitInterfaceSpecifier(SyntaxFactory.ParseName("System.Collections.IEnumerable")))
                                  .WithBody(Cs.Block(
                                                Cs.Return(Cs.This().Member("GetEnumerator").Invoke())));

            members.Add(iegetEnumerator);

            // public void Dispose()
            //{
            //}


            var dispose = SyntaxFactory.MethodDeclaration(Context.Void.ToTypeSyntax(), "Dispose")
                          .AddModifiers(Cs.Public())
                          .WithBody(Cs.Block());

            members.Add(dispose);

            // public void Reset()
            // {
            //     throw new NotImplementedException();
            // }

            var reset = SyntaxFactory.MethodDeclaration(Context.Void.ToTypeSyntax(), "Reset")
                        .AddModifiers(Cs.Public())
                        .WithBody(Cs.Block(Cs.Throw(Cs.New(Context.Compilation.FindType("System.NotImplementedException").ToTypeSyntax()))));

            members.Add(reset);


            // object IEnumerator.Current
            //{
            //    get { return Current; }
            //}
            //IEnumerator
            var iecurrent =
                Cs.Property(Context.Object.ToTypeSyntax(), "Current", true, false, Cs.Block(
                                Cs.Return(Cs.This().Member("Current"))
                                )).WithExplicitInterfaceSpecifier(SyntaxFactory.ExplicitInterfaceSpecifier(SyntaxFactory.ParseName("System.Collections.IEnumerator")));


            members.Add(iecurrent);

            // Generate the GetEnumerator method, which looks something like:
            // var $isStartedLocal = $isStarted;
            // $isStarted = true;
            // if ($isStartedLocal)
            //     return this.Clone().GetEnumerator();
            // else
            //     return this;
            var getEnumerator = SyntaxFactory.MethodDeclaration(ienumerator_g, "GetEnumerator")
                                .AddModifiers(Cs.Public())
                                .WithBody(Cs.Block(
                                              Cs.Local(isStartedLocal, SyntaxFactory.IdentifierName(isStarted)),
                                              Cs.Express(SyntaxFactory.IdentifierName(isStarted).Assign(Cs.True())),
                                              Cs.If(
                                                  SyntaxFactory.IdentifierName(isStartedLocal),
                                                  Cs.Return(Cs.This().Member("Clone").Invoke().Member("GetEnumerator").Invoke()),
                                                  Cs.Return(Cs.This()))));

            members.Add(getEnumerator);

            // Generate the MoveNext method, which looks something like:
            // $top:
            // while (true)
            // {
            //     switch (state)
            //     {
            //         case 0: ...
            //         case 1: ...
            //     }
            // }
            // Each state becomes one switch section whose case label is the state's position in "states".
            var moveNextBody = SyntaxFactory.LabeledStatement("_top", Cs.While(Cs.True(),
                                                                               Cs.Switch(Cs.This().Member(StateGenerator.state), states.Select((x, i) =>
                                                                                                                                               Cs.Section(Cs.Integer(i), x.Statements.ToArray())).ToArray())));

            // The labeled loop is followed by a trailing "return false;".
            var moveNext = SyntaxFactory.MethodDeclaration(Cs.Bool(), "MoveNext")
                           .AddModifiers(Cs.Public())
                           .WithBody(SyntaxFactory.Block(moveNextBody, Cs.Return(SyntaxFactory.ParseExpression("false"))));

            members.Add(moveNext);

            // Name used by Clone to instantiate the class; generic methods get their type parameters
            // appended as type arguments.
            TypeSyntax classNameWithTypeArguments = SyntaxFactory.IdentifierName(className);

            if (method is IMethodSymbol)
            {
                if ((method as IMethodSymbol).TypeParameters.Any())
                {
                    // The second argument supplies the separators: one comma per type parameter after the first.
                    classNameWithTypeArguments = SyntaxFactory.GenericName(
                        SyntaxFactory.Identifier(className),
                        SyntaxFactory.TypeArgumentList(SyntaxFactory.SeparatedList(
                                                           (method as IMethodSymbol).TypeParameters.Select(x => SyntaxFactory.ParseTypeName(x.Name)),
                                                           (method as IMethodSymbol).TypeParameters.Select(x => x).Skip(1).Select(_ => SyntaxFactory.Token(SyntaxKind.CommaToken)))
                                                       ));
                }
            }

            // Clone creates a new instance, passing the identifiers named after the constructor parameters.
            var cloneBody = Cs.Block(
                Cs.Return(classNameWithTypeArguments.New(
                              constructorParameters.Select(x => SyntaxFactory.IdentifierName(x.Identifier)).ToArray()
                              ))
                );
            var clone = SyntaxFactory.MethodDeclaration(ienumerable_g, "Clone")
                        .AddModifiers(Cs.Public())
                        .WithBody(SyntaxFactory.Block(cloneBody));

            members.Add(clone);
            //IEnumerable<T>,IEnumerator<T>
            SimpleBaseTypeSyntax[] baseTypes = new[]
            {
                SyntaxFactory.SimpleBaseType(ienumerable_g),
                SyntaxFactory.SimpleBaseType(ienumerator_g)
            };
            ClassDeclarationSyntax result = SyntaxFactory.ClassDeclaration(className).WithBaseList(baseTypes).WithMembers(members.ToArray());

            // Generic methods pass their type parameter list on to the generated class.
            if (method is IMethodSymbol)
            {
                if ((method as IMethodSymbol).TypeParameters.Any())
                {
                    result = result.WithTypeParameterList((node as MethodDeclarationSyntax).TypeParameterList);
                }
            }

            return(result);
        }
Пример #10
0
 /// <summary>
 /// Creates the statement used to return out of a state: <c>return false;</c>.
 /// Derived generators may override this to return something else.
 /// </summary>
 /// <returns>A return statement whose value is the <c>false</c> literal.</returns>
 protected virtual StatementSyntax ReturnOutOfState()
 {
     var falseLiteral = Cs.False();
     return Cs.Return(falseLiteral);
 }
Пример #11
0
 /// <summary>
 /// Creates the statement that switches the state machine to <paramref name="newState"/>,
 /// i.e. an assignment of the state's index to the <c>this</c> member named by <c>state</c>.
 /// </summary>
 /// <param name="newState">The state to switch to; its <c>Index</c> is the assigned value.</param>
 /// <returns>An expression statement performing the assignment.</returns>
 public static StatementSyntax ChangeState(State newState)
 {
     var stateMember = Cs.This().Member(state);
     var targetIndex = Cs.Integer(newState.Index);
     return Cs.Express(stateMember.Assign(targetIndex));
 }