/// <summary>
/// Rewrites a conditional access whose <c>WhenNotNull</c> part is a member binding with a tracked
/// name (e.g. <c>x?.Name</c>) so that the binding becomes a parameterless invocation of
/// <c>_methodName</c> (e.g. <c>x?.MethodName()</c>).
/// </summary>
/// <param name="node">The conditional access expression being visited.</param>
/// <returns>
/// The rewritten conditional access when the bound name's span is contained in <c>_textSpans</c>;
/// otherwise the result of the base visitor.
/// </returns>
public override SyntaxNode VisitConditionalAccessExpression(ConditionalAccessExpressionSyntax node)
        {
            if (node.WhenNotNull is MemberBindingExpressionSyntax memberBinding)
            {
                SimpleNameSyntax name = memberBinding.Name;

                if (name != null &&
                    _textSpans.Contains(name.Span))
                {
                    // Remove the span that was actually matched. The conditional access itself
                    // (node.Span) always covers more than the bound name, so removing node.Span
                    // could never clear the entry that the check above found.
                    _textSpans.Remove(name.Span);

                    // Only the receiver is visited further; the member binding is replaced below.
                    var expression = (ExpressionSyntax)base.Visit(node.Expression);

                    // The new name keeps the old name's leading trivia, and the binding's trailing
                    // trivia moves onto the (empty) argument list that now ends the expression.
                    InvocationExpressionSyntax invocation = InvocationExpression(
                        memberBinding.WithName(IdentifierName(_methodName).WithLeadingTrivia(name.GetLeadingTrivia())),
                        ArgumentList().WithTrailingTrivia(memberBinding.GetTrailingTrivia()));

                    return node
                        .WithExpression(expression)
                        .WithWhenNotNull(invocation);
                }
            }

            return base.VisitConditionalAccessExpression(node);
        }
 /// <summary>
 /// Visits the receiver and the <c>WhenNotNull</c> part of a conditional access expression and
 /// returns <paramref name="node"/> updated with both visited results.
 /// </summary>
 /// <param name="node">The conditional access expression to visit.</param>
 /// <returns>The conditional access with its two child expressions replaced by their visited forms.</returns>
 public override ExpressionSyntax VisitConditionalAccessExpression(ConditionalAccessExpressionSyntax node)
 {
     // Visit the receiver first, then the WhenNotNull part (same order as the chained form).
     var visitedExpression = Visit(node.Expression);
     var visitedWhenNotNull = Visit(node.WhenNotNull);

     return node
         .WithExpression(visitedExpression)
         .WithWhenNotNull(visitedWhenNotNull);
 }
Ejemplo n.º 3
0
        /// <summary>
        /// Replaces a returned conditional access expression with two statements: a local
        /// declaration that stores the receiver, followed by an <c>if</c>/<c>else</c> that returns
        /// the access on the local when it is not <c>null</c> and otherwise returns the expression
        /// produced by <c>CreateCompletedTaskExpression</c>.
        /// </summary>
        /// <remarks>
        /// How the statements are spliced in depends on the parent of
        /// <paramref name="conditionalAccess"/>: a return statement is replaced directly (wrapped in
        /// a block when embedded), a lambda receives a block body, and any other parent is cast to
        /// an arrow expression clause whose owning declaration is converted to a block body or a
        /// <c>get</c> accessor.
        /// </remarks>
        /// <param name="document">The document that contains the expression.</param>
        /// <param name="conditionalAccess">The conditional access expression to expand.</param>
        /// <param name="cancellationToken">Token forwarded to the semantic-model and replace calls.</param>
        /// <returns>The document with the replacement applied.</returns>
        private static async Task<Document> RefactorAsync(
            Document document,
            ConditionalAccessExpressionSyntax conditionalAccess,
            CancellationToken cancellationToken)
        {
            ExpressionSyntax receiver = conditionalAccess.Expression;
            int spanStart = conditionalAccess.SpanStart;

            SemanticModel semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);

            string variableName = NameGenerator.Default.EnsureUniqueLocalName(DefaultNames.Variable, semanticModel, spanStart);

            ITypeSymbol receiverType = semanticModel.GetTypeSymbol(receiver, cancellationToken);

            // Use an explicit type only when the symbol supports it; otherwise declare with 'var'.
            TypeSyntax declaredType;

            if (receiverType?.SupportsExplicitDeclaration() == true)
            {
                declaredType = receiverType.ToMinimalTypeSyntax(semanticModel, spanStart);
            }
            else
            {
                declaredType = VarType();
            }

            LocalDeclarationStatementSyntax localDeclaration = LocalDeclarationStatement(
                declaredType,
                Identifier(variableName).WithRenameAnnotation(),
                receiver);

            ExpressionSyntax fallbackExpression = ReturnCompletedTaskInsteadOfNullCodeFixProvider.CreateCompletedTaskExpression(document, conditionalAccess, semanticModel, cancellationToken);

            // The original access, re-rooted on the new local.
            // NOTE(review): RemoveOperatorToken presumably strips the '?' token - confirm against the helper.
            var accessOnLocal = conditionalAccess
                .WithExpression(IdentifierName(variableName).WithTriviaFrom(conditionalAccess.Expression))
                .RemoveOperatorToken();

            IfStatementSyntax nullCheck = IfStatement(
                NotEqualsExpression(IdentifierName(variableName), NullLiteralExpression()),
                Block(ReturnStatement(accessOnLocal)),
                ElseClause(Block(ReturnStatement(fallbackExpression))));

            SyntaxList<StatementSyntax> statements = List(new StatementSyntax[] { localDeclaration, nullCheck });

            SyntaxNode parent = conditionalAccess.Parent;

            switch (parent)
            {
                case ReturnStatementSyntax returnStatement:
                {
                    statements = statements.WithTriviaFrom(returnStatement);

                    // An embedded return (e.g. the sole statement of an 'if') needs a block to hold two statements.
                    if (returnStatement.IsEmbedded())
                    {
                        return await document.ReplaceNodeAsync(returnStatement, Block(statements), cancellationToken).ConfigureAwait(false);
                    }

                    return await document.ReplaceNodeAsync(returnStatement, statements, cancellationToken).ConfigureAwait(false);
                }

                case SimpleLambdaExpressionSyntax simpleLambda:
                {
                    SimpleLambdaExpressionSyntax newLambda = simpleLambda
                        .WithBody(Block(statements))
                        .WithFormatterAnnotation();

                    return await document.ReplaceNodeAsync(simpleLambda, newLambda, cancellationToken).ConfigureAwait(false);
                }

                case ParenthesizedLambdaExpressionSyntax parenthesizedLambda:
                {
                    ParenthesizedLambdaExpressionSyntax newLambda = parenthesizedLambda
                        .WithBody(Block(statements))
                        .WithFormatterAnnotation();

                    return await document.ReplaceNodeAsync(parenthesizedLambda, newLambda, cancellationToken).ConfigureAwait(false);
                }

                default:
                {
                    // Any remaining parent is treated as an expression body ('=> x?.Foo()').
                    var arrowClause = (ArrowExpressionClauseSyntax)parent;

                    SyntaxNode owner = arrowClause.Parent;

                    SyntaxNode newOwner = CreateNewNode(owner).WithFormatterAnnotation();

                    return await document.ReplaceNodeAsync(owner, newOwner, cancellationToken).ConfigureAwait(false);
                }
            }

            // Converts an expression-bodied declaration into its block-bodied equivalent.
            SyntaxNode CreateNewNode(SyntaxNode member)
            {
                if (member is MethodDeclarationSyntax methodDeclaration)
                {
                    return methodDeclaration
                        .WithExpressionBody(null)
                        .WithBody(Block(statements));
                }

                if (member is LocalFunctionStatementSyntax localFunction)
                {
                    return localFunction
                        .WithExpressionBody(null)
                        .WithBody(Block(statements));
                }

                if (member is PropertyDeclarationSyntax propertyDeclaration)
                {
                    // Expression-bodied properties become a property with a single 'get' accessor.
                    return propertyDeclaration
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default(SyntaxToken))
                        .WithAccessorList(AccessorList(GetAccessorDeclaration(Block(statements))));
                }

                if (member is IndexerDeclarationSyntax indexerDeclaration)
                {
                    return indexerDeclaration
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default(SyntaxToken))
                        .WithAccessorList(AccessorList(GetAccessorDeclaration(Block(statements))));
                }

                if (member is AccessorDeclarationSyntax accessorDeclaration)
                {
                    return accessorDeclaration
                        .WithExpressionBody(null)
                        .WithSemicolonToken(default(SyntaxToken))
                        .WithBody(Block(statements));
                }

                // Unexpected owner kind: flag it in debug builds and leave the node unchanged.
                Debug.Fail(member.Kind().ToString());
                return member;
            }
        }
        /// <summary>
        /// Replaces a null-conditional invocation (<c>value?.Method(args)</c>) inside
        /// <paramref name="node"/> with a parenthesized conditional expression of the form
        /// <c>(value == null ? null : value.Method(args))</c>, where the invocation in the
        /// not-null branch is passed through <c>AddAwait</c>.
        /// </summary>
        /// <remarks>
        /// <para>
        /// When the receiver is not a simple name, or is a simple name starting with an upper-case
        /// letter, the receiver is first stored in a generated local variable declared immediately
        /// before the enclosing statement, and that variable is used in the conditional expression.
        /// </para>
        /// <para>
        /// When the invocation is followed by a member access or element access, the result stays a
        /// conditional access whose receiver is the generated conditional expression and whose
        /// <c>WhenNotNull</c> part is the rewritten remainder of the chain.
        /// </para>
        /// </remarks>
        /// <typeparam name="T">The type of the root node being rewritten.</typeparam>
        /// <param name="node">The root node in which the replacement is performed.</param>
        /// <param name="nameNode">Name whose identifier text is used to build the generated variable name.</param>
        /// <param name="funReferenceResult">Supplies <c>ReferenceSymbol.ReturnType</c>, which decides whether the <c>null</c> branch needs a cast.</param>
        /// <param name="typeMetadata">Supplies the end-of-line trivia placed after the generated declaration.</param>
        /// <param name="conditionalAccessNode">The conditional access expression to replace.</param>
        /// <param name="invokeNode">
        /// The invocation to rewrite; its <c>Expression</c> is cast to
        /// <see cref="MemberBindingExpressionSyntax"/> and its argument list is reused.
        /// </param>
        /// <returns><paramref name="node"/> with the conditional access (or its enclosing block) replaced.</returns>
        /// <exception cref="NotSupportedException">
        /// A temporary variable is required but the conditional access has no ancestor statement
        /// whose parent is a block.
        /// </exception>
        private T TransformConditionalAccessToConditionalExpressions <T>(
            T node,
            SimpleNameSyntax nameNode,
            IFunctionReferenceAnalyzationResult funReferenceResult,
            ITypeTransformationMetadata typeMetadata,
            ConditionalAccessExpressionSyntax conditionalAccessNode,
            InvocationExpressionSyntax invokeNode) where T : SyntaxNode
        {
            // TODO: we should check the async symbol instead
            var returnType  = funReferenceResult.ReferenceSymbol.ReturnType;
            var type        = returnType.CreateTypeSyntax();
            // The null literal can be used without a cast only when the return type is neither a
            // value type nor nullable.
            var canSkipCast = !returnType.IsValueType && !returnType.IsNullable();

            // A non-nullable value type has to be wrapped in a nullable type so that the null
            // branch can be cast to it.
            if (returnType.IsValueType && !returnType.IsNullable())
            {
                type = NullableType(type);
            }

            // Stays null unless the invocation is followed by a member or element access; in that
            // case it holds the WhenNotNull part with that access turned into a binding expression.
            ExpressionSyntax whenNotNullNode = null;

            if (invokeNode.Parent is MemberAccessExpressionSyntax memberAccessParent)
            {
                whenNotNullNode = conditionalAccessNode.WhenNotNull
                                  .ReplaceNode(memberAccessParent,
                                               MemberBindingExpression(Token(SyntaxKind.DotToken), memberAccessParent.Name));
            }
            else if (invokeNode.Parent is ElementAccessExpressionSyntax elementAccessParent)
            {
                whenNotNullNode = conditionalAccessNode.WhenNotNull
                                  .ReplaceNode(elementAccessParent,
                                               ElementBindingExpression(elementAccessParent.ArgumentList));
            }

            var             valueNode         = conditionalAccessNode.Expression.WithoutTrivia();
            StatementSyntax variableStatement = null;
            BlockSyntax     statementBlock    = null;
            var             statementIndex    = 0;

            // We have to save the value in a variable when the expression is an invocation, index accessor or property in
            // order to prevent double calls
            // TODO: find a more robust solution
            // Simple names that start with an upper-case letter take this path as well.
            if (!(conditionalAccessNode.Expression is SimpleNameSyntax simpleName) || char.IsUpper(simpleName.ToString()[0]))
            {
                // The declaration can only be inserted next to a statement that sits directly in a block.
                var statement = (StatementSyntax)conditionalAccessNode.Ancestors().FirstOrDefault(o => o is StatementSyntax);
                if (statement == null || !(statement.Parent is BlockSyntax block))
                {
                    // TODO: convert arrow method/property/function to a normal one
                    // TODO: convert to block if there is no block
                    throw new NotSupportedException(
                              $"Arrow method with null-conditional is not supported. Node: {conditionalAccessNode}");
                }

                var leadingTrivia = statement.GetLeadingTrivia();
                var fnName        = nameNode.Identifier.ValueText;
                statementIndex = block.Statements.IndexOf(statement);
                statementBlock = block;
                // TODO: handle name collisions
                // Variable name: the function name with a lower-cased first letter, followed by the
                // index of the statement inside its block.
                var variableName = $"{char.ToLowerInvariant(fnName[0])}{fnName.Substring(1)}{statementIndex}";
                // Builds "var <variableName> = <value>;" with the statement's leading trivia on the
                // 'var' keyword and the configured end-of-line trivia after the semicolon.
                variableStatement = LocalDeclarationStatement(
                    VariableDeclaration(
                        IdentifierName(Identifier(leadingTrivia, "var", TriviaList(Space))),
                        SingletonSeparatedList(
                            VariableDeclarator(
                                Identifier(TriviaList(), variableName, TriviaList(Space)))
                            .WithInitializer(
                                EqualsValueClause(valueNode)
                                .WithEqualsToken(Token(TriviaList(), SyntaxKind.EqualsToken, TriviaList(Space)))
                                )
                            )))
                                    .WithSemicolonToken(Token(TriviaList(), SyntaxKind.SemicolonToken, TriviaList(typeMetadata.EndOfLineTrivia)));

                // From here on the conditional expression refers to the variable instead of the original receiver.
                valueNode = IdentifierName(variableName);
            }

            // Unique annotation kind used to find the generated invocation again once it has been
            // embedded in the wrapped expression.
            var invocationAnnotation = Guid.NewGuid().ToString();
            var nullNode             = LiteralExpression(
                SyntaxKind.NullLiteralExpression,
                Token(TriviaList(), SyntaxKind.NullKeyword, TriviaList(Space)));
            // "<value> == null "
            var ifNullCondition = BinaryExpression(
                SyntaxKind.EqualsExpression,
                valueNode.WithTrailingTrivia(TriviaList(Space)),
                nullNode)
                                  .WithOperatorToken(
                Token(TriviaList(), SyntaxKind.EqualsEqualsToken, TriviaList(Space)));

            // "(<value> == null ? <null or (type)null> : <value>.<Name>(<args>))"
            ExpressionSyntax wrappedNode = ParenthesizedExpression(
                ConditionalExpression(
                    ifNullCondition,
                    canSkipCast
                                                        ? (ExpressionSyntax)nullNode
                                                        : CastExpression(
                        Token(SyntaxKind.OpenParenToken),
                        type,
                        Token(TriviaList(), SyntaxKind.CloseParenToken, TriviaList(Space)),
                        nullNode),
                    InvocationExpression(
                        MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
                                               valueNode,
                                               ((MemberBindingExpressionSyntax)invokeNode.Expression).Name))
                    .WithAdditionalAnnotations(new SyntaxAnnotation(invocationAnnotation))
                    .WithArgumentList(invokeNode.ArgumentList.WithoutTrailingTrivia())
                    )
                .WithColonToken(Token(TriviaList(), SyntaxKind.ColonToken, TriviaList(Space)))
                .WithQuestionToken(Token(TriviaList(), SyntaxKind.QuestionToken, TriviaList(Space)))
                );

            // Re-attach the remainder of the chain: "(<conditional>)?.Member" or "(<conditional>)?[index]".
            if (whenNotNullNode != null)
            {
                wrappedNode = conditionalAccessNode
                              .WithExpression(wrappedNode)
                              .WithWhenNotNull(whenNotNullNode);
            }

            wrappedNode = wrappedNode.WithTriviaFrom(conditionalAccessNode);
            // Locate the generated invocation through its annotation and replace it with the result of AddAwait.
            invokeNode  = (InvocationExpressionSyntax)wrappedNode.GetAnnotatedNodes(invocationAnnotation).First();
            wrappedNode = wrappedNode.ReplaceNode(invokeNode, invokeNode.AddAwait(_configuration.ConfigureAwaitArgument));
            if (statementBlock != null)
            {
                // A temporary variable was generated: replace the expression inside the block, then
                // insert the declaration at the index of the original statement so it precedes it.
                var newBlock = statementBlock.ReplaceNode(conditionalAccessNode, wrappedNode);
                newBlock = newBlock.WithStatements(newBlock.Statements.Insert(statementIndex, variableStatement));

                return(node.ReplaceNode(statementBlock, newBlock));
            }

            return(node.ReplaceNode(conditionalAccessNode, wrappedNode));
        }