/// <summary>
/// Marks <paramref name="node"/> with the <c>Annotations.AsyncCallWithTokenOrGuard</c> annotation when the reference is a
/// body function reference that passes a cancellation token, or when the containing function gets cancellation token guards.
/// </summary>
/// <param name="node">The node of the function reference.</param>
/// <param name="funcResult">The analyzation result of the function containing the reference.</param>
/// <param name="funcReferenceResult">The analyzation result of the reference.</param>
/// <param name="namespaceMetadata">Namespace metadata (not read by this transformer).</param>
/// <returns>The annotated node, or <c>null</c> when the node does not need to be marked.</returns>
public SyntaxNode TransformFunctionReference(SyntaxNode node, IFunctionAnalyzationResult funcResult,
                                             IFunctionReferenceAnalyzationResult funcReferenceResult, INamespaceTransformationMetadata namespaceMetadata)
{
    // Non-body references never get the mark
    if (!(funcReferenceResult is IBodyFunctionReferenceAnalyzationResult bodyReference))
    {
        return null;
    }

    // The containing function is queried only when the reference itself does not pass a token
    var passesTokenOrHasGuard = bodyReference.PassCancellationToken ||
                                funcResult.GetMethodOrAccessor().AddCancellationTokenGuards;
    if (!passesTokenOrHasGuard)
    {
        return null;
    }

    // Mark the invocation node in order to add the OperationCanceledException catch block only if there is at least one async invocation
    // with a cancellation token passed as an argument or there will be a cancellation token guard added
    return node.WithAdditionalAnnotations(new SyntaxAnnotation(Annotations.AsyncCallWithTokenOrGuard));
}
        /// <summary>
        /// Rewrites a single function reference inside <paramref name="node"/> so that it uses the async counterpart name
        /// (<c>AsyncCounterpartName</c>), and adjusts the surrounding syntax depending on the kind of reference:
        /// non-body references (cref / nameof), method groups passed as delegate arguments, non-awaited invocations
        /// or accessors, and awaited invocations or accessors (including ones reached through a <c>?.</c> access).
        /// </summary>
        /// <typeparam name="T">The type of the node being transformed.</typeparam>
        /// <param name="node">The node that contains the reference name, located through <c>transfromReference.Annotation</c>.</param>
        /// <param name="funcResult">The analyzation result of the function that contains the reference.</param>
        /// <param name="transfromReference">The reference to transform; its <c>Transformed</c> property is set to the renamed name node.</param>
        /// <param name="typeMetadata">Type metadata, passed on when rewriting <c>?.</c> invocations.</param>
        /// <param name="namespaceMetadata">Namespace metadata, used for type syntax creation and <c>TaskConflict</c>.</param>
        /// <returns>The transformed node.</returns>
        private T TransformFunctionReference <T>(T node, IFunctionAnalyzationResult funcResult, FunctionReferenceTransformationResult transfromReference,
                                                 ITypeTransformationMetadata typeMetadata,
                                                 INamespaceTransformationMetadata namespaceMetadata)
            where T : SyntaxNode
        {
            // Locate the reference name through its annotation; First() throws if no annotated SimpleNameSyntax exists in node
            var nameNode                = node.GetAnnotatedNodes(transfromReference.Annotation).OfType <SimpleNameSyntax>().First();
            var funReferenceResult      = transfromReference.AnalyzationResult;
            // null for non-body references; the cref and nameof cases of those are handled by the first branch below
            var bodyFuncReferenceResult = funReferenceResult as IBodyFunctionReferenceAnalyzationResult;
            // The same simple name renamed to the async counterpart, keeping the original trivia
            var newNameNode             = nameNode
                                          .WithIdentifier(Identifier(funReferenceResult.AsyncCounterpartName))
                                          .WithTriviaFrom(nameNode);

            // Expose the renamed name node to the caller through the transformation result
            transfromReference.Transformed = newNameNode;

            // Token argument name handed to AddCancellationTokenArgumentIf; null when the containing function does not require a cancellation token
            var cancellationTokenParamName = funcResult.GetMethodOrAccessor().CancellationTokenRequired ? "cancellationToken" : null;             // TODO: remove

            // If we have a cref change the name to the async counterpart and add/update arguments
            if (bodyFuncReferenceResult == null)
            {
                if (funReferenceResult.IsCref)
                {
                    // The cast assumes the parent of a cref reference name is always a NameMemberCrefSyntax
                    var crefNode  = (NameMemberCrefSyntax)nameNode.Parent;
                    var paramList = new List <CrefParameterSyntax>();
                    // If the cref has already the parameters set then use them
                    if (crefNode.Parameters != null)
                    {
                        paramList.AddRange(crefNode.Parameters.Parameters);
                        // If the external async counterpart has a cancellation token, add it
                        // NOTE(review): an external counterpart (AsyncCounterpartFunction == null) with more parameters than the
                        // referenced method is assumed to differ only by a trailing CancellationToken parameter — confirm
                        if (funReferenceResult.AsyncCounterpartFunction == null &&
                            funReferenceResult.ReferenceSymbol.Parameters.Length <
                            funReferenceResult.AsyncCounterpartSymbol.Parameters.Length)
                        {
                            paramList.Add(CrefParameter(IdentifierName(nameof(CancellationToken))));
                        }
                    }
                    else
                    {
                        // We have to add the parameters to avoid ambiguity
                        // (every parameter type of the async counterpart symbol is listed)
                        var asyncSymbol = funReferenceResult.AsyncCounterpartSymbol;
                        paramList.AddRange(asyncSymbol.Parameters
                                           .Select(o => CrefParameter(o.Type
                                                                      .CreateTypeSyntax(true, namespaceMetadata.AnalyzationResult.IsIncluded(o.Type.ContainingNamespace?.ToString())))));
                    }

                    // If the async counterpart is internal and a token is required add a token parameter
                    if (funReferenceResult.AsyncCounterpartFunction?.GetMethodOrAccessor()?.CancellationTokenRequired == true)
                    {
                        paramList.Add(CrefParameter(IdentifierName(nameof(CancellationToken))));
                    }

                    // Replace the cref (renamed, with the new parameter list) inside its QualifiedCrefSyntax parent, if any.
                    // ReplaceNestedNodes is defined elsewhere; judging from the callbacks, rootNode is the qualified cref whose
                    // Container type gets updated and childNode is the replaced cref — confirm
                    node = node.ReplaceNestedNodes(
                        crefNode.Parent as QualifiedCrefSyntax,
                        crefNode,
                        crefNode
                        .ReplaceNode(nameNode, newNameNode)
                        .WithParameters(CrefParameterList(SeparatedList(paramList))),
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                          (type, fullName) => rootNode.WithContainer(type.CreateTypeSyntax(true, fullName).WithTriviaFrom(rootNode.Container))),
                        childNode => RunReferenceTransformers(childNode, funcResult, funReferenceResult, namespaceMetadata)
                        );
                }
                else if (funReferenceResult.IsNameOf)
                {
                    // nameof(...): rename the member inside its member access parent (if any) and update the accessed type expression
                    node = node.ReplaceNestedNodes(
                        nameNode.Parent as MemberAccessExpressionSyntax,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                          (type, fullName) => rootNode.WithExpression(type.CreateTypeSyntax(false, fullName).WithTriviaFrom(rootNode.Expression))),
                        childNode => RunReferenceTransformers(childNode, funcResult, funReferenceResult, namespaceMetadata)
                        );
                }
                // Any other non-body reference is returned unchanged
                return(node);
            }
            // If we have a method passed as an argument we need to check if we have to wrap it inside a function
            if (bodyFuncReferenceResult.AsyncDelegateArgument != null)
            {
                if (bodyFuncReferenceResult.WrapInsideFunction)
                {
                    // TODO: move to analyze step
                    var  argumentNode  = nameNode.Ancestors().OfType <ArgumentSyntax>().First();
                    // The cast assumes the delegate's return type is a named type (e.g. Task / Task<T>)
                    var  delReturnType = (INamedTypeSymbol)bodyFuncReferenceResult.AsyncDelegateArgument.ReturnType;
                    var  returnType    = bodyFuncReferenceResult.AsyncCounterpartSymbol.ReturnType;
                    // Tells WrapInsideFunction whether the async counterpart's return type differs from the one the delegate expects
                    bool returnTypeMismatch;
                    if (bodyFuncReferenceResult.ReferenceFunction != null)
                    {
                        var refMethod = bodyFuncReferenceResult.ReferenceFunction as IMethodAnalyzationResult;
                        if (refMethod != null && refMethod.PreserveReturnType)
                        {
                            returnTypeMismatch = !delReturnType.Equals(returnType);   // TODO Generics
                        }
                        else if (delReturnType.IsGenericType)                         // Generic Task
                        {
                            returnTypeMismatch = delReturnType.TypeArguments.First().IsAwaitRequired(returnType);
                        }
                        else
                        {
                            returnTypeMismatch = delReturnType.IsAwaitRequired(returnType);
                        }
                    }
                    else
                    {
                        returnTypeMismatch = !delReturnType.Equals(returnType);                         // TODO Generics
                    }

                    // Rename the method group inside the argument, then wrap it into a function that may append the token argument
                    var newArgumentExpression = argumentNode.Expression
                                                .ReplaceNestedNodes(
                        nameNode.Parent as MemberAccessExpressionSyntax,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                          (type, fullName) => rootNode.WithExpression(type.CreateTypeSyntax(false, fullName))),
                        childNode => RunReferenceTransformers(childNode, funcResult, funReferenceResult, namespaceMetadata)
                        )
                                                .WrapInsideFunction(bodyFuncReferenceResult.AsyncDelegateArgument, returnTypeMismatch,
                                                                    namespaceMetadata.TaskConflict,
                                                                    invocation => invocation.AddCancellationTokenArgumentIf(cancellationTokenParamName, bodyFuncReferenceResult));
                    node = node
                           .ReplaceNode(argumentNode.Expression, newArgumentExpression);
                }
                else
                {
                    // The method group can be passed as-is; only the name (and the accessed type) are updated
                    node = node.ReplaceNestedNodes(
                        nameNode.Parent as MemberAccessExpressionSyntax,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                          (type, fullName) => rootNode.WithExpression(type.CreateTypeSyntax(false, fullName))),
                        childNode => RunReferenceTransformers(childNode, funcResult, funReferenceResult, namespaceMetadata)
                        );
                }
                return(node);
            }

            // Only set for a non-accessor reference whose reference node is an invocation expression
            InvocationExpressionSyntax invokeNode = null;
            var isAccessor = bodyFuncReferenceResult.ReferenceSymbol.IsAccessor();

            if (!isAccessor && funReferenceResult.ReferenceNode.IsKind(SyntaxKind.InvocationExpression))
            {
                invokeNode = nameNode.Ancestors().OfType <InvocationExpressionSyntax>().First();
            }

            // Non-awaited reference: rewrite without adding await, optionally turning the statement into a return statement
            if (!bodyFuncReferenceResult.AwaitInvocation)
            {
                // An arrow method does not have a statement
                var statement = nameNode.Ancestors().OfType <StatementSyntax>().FirstOrDefault();
                // True when a function node (per IsFunction) lies between the name and that statement
                var statementInParentFunction = nameNode.Ancestors().TakeWhile(o => !o.Equals(statement)).Any(o => o.IsFunction());
                // Work on the enclosing statement when there is one, otherwise on the whole node
                var newNode = (SyntaxNode)statement ?? node;

                if (invokeNode != null)
                {
                    newNode = newNode.ReplaceNestedNodes(
                        invokeNode,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode
                                                                          .AddCancellationTokenArgumentIf(cancellationTokenParamName, bodyFuncReferenceResult),
                                                                          funcResult, funReferenceResult, namespaceMetadata,
                                                                          (memberNode, type, fullName) => memberNode.WithExpression(type.CreateTypeSyntax(true, fullName).WithTriviaFrom(memberNode.Expression)))
                        );
                }
                else if (isAccessor)
                {
                    newNode = ConvertAccessor(newNode, nameNode, newNameNode, cancellationTokenParamName, bodyFuncReferenceResult,
                                              invNode => UpdateTypeAndRunReferenceTransformers(invNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                                               (memberNode, type, fullName) => memberNode.WithExpression(type.CreateTypeSyntax(true, fullName).WithTriviaFrom(memberNode.Expression))));
                }
                else
                {
                    newNode = newNode.ReplaceNestedNodes(
                        nameNode.Parent as MemberAccessExpressionSyntax,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                          (type, fullName) => rootNode.WithExpression(type.CreateTypeSyntax(false, fullName).WithTriviaFrom(rootNode.Expression))),
                        childNode => RunReferenceTransformers(childNode, funcResult, funReferenceResult, namespaceMetadata)
                        );
                }

                if (statement != null && !statement.IsKind(SyntaxKind.LocalFunctionStatement))
                {
                    // Skip adding return statement for arrow functions
                    if (bodyFuncReferenceResult.UseAsReturnValue && !statementInParentFunction)
                    {
                        newNode = ((StatementSyntax)newNode).ToReturnStatement();
                    }
                    node = node
                           .ReplaceNode(statement, newNode);
                }
                else
                {
                    // NOTE(review): when statement is a LocalFunctionStatement, newNode was built from that statement rather than
                    // from node, so this cast replaces node with it — presumably node is that local function here; confirm
                    node = (T)newNode;
                }
            }
            else
            {
                // We need to annotate the invocation node because of the AddAwait method as it needs the parent node
                // (the replacements below produce new node instances, so the invocation is found again through this annotation)
                var invokeAnnotation = Guid.NewGuid().ToString();
                if (isAccessor)
                {
                    node = ConvertAccessor(node, nameNode, newNameNode, cancellationTokenParamName, bodyFuncReferenceResult, invNode =>
                                           UpdateTypeAndRunReferenceTransformers(invNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                                 (memberNode, type, fullName) => memberNode.WithExpression(type.CreateTypeSyntax(true, fullName).WithTriviaFrom(memberNode.Expression)))
                                           .WithAdditionalAnnotations(new SyntaxAnnotation(invokeAnnotation))
                                           );
                }
                else
                {
                    node = node.ReplaceNestedNodes(
                        invokeNode,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode
                                                                          .AddCancellationTokenArgumentIf(cancellationTokenParamName, bodyFuncReferenceResult),
                                                                          funcResult, funReferenceResult, namespaceMetadata,
                                                                          (memberNode, type, fullName) => memberNode.WithExpression(type.CreateTypeSyntax(true, fullName).WithTriviaFrom(memberNode.Expression)))
                        .WithAdditionalAnnotations(new SyntaxAnnotation(invokeAnnotation))
                        );
                }

                // Assumes the annotated node is an invocation in both branches (including the ConvertAccessor result)
                invokeNode = node.GetAnnotatedNodes(invokeAnnotation).OfType <InvocationExpressionSyntax>().First();

                // Check if the invocation has a ?.
                // (nearest ConditionalAccessExpression within the enclosing statement whose WhenNotNull part contains the invocation)
                var conditionalAccessNode = invokeNode.Ancestors()
                                            .TakeWhile(o => !(o is StatementSyntax))
                                            .OfType <ConditionalAccessExpressionSyntax>()
                                            .FirstOrDefault(o => o.WhenNotNull.Contains(invokeNode));
                if (conditionalAccessNode != null)                 // ?. syntax
                {
                    // We have to find out which strategy to use, if we have a non assignable expression, we are force to use if statements
                    // otherwise a ternary condition will be used
                    // If statements are used only when the ?. expression is a whole expression statement and the invocation is its entire WhenNotNull part
                    if (!conditionalAccessNode.Parent.IsKind(SyntaxKind.ExpressionStatement) || !invokeNode.Equals(conditionalAccessNode.WhenNotNull))
                    {
                        node = TransformConditionalAccessToConditionalExpressions(node, nameNode, funReferenceResult, typeMetadata,
                                                                                  conditionalAccessNode, invokeNode);
                    }
                    else
                    {
                        node = TransformConditionalAccessToIfStatements(node, nameNode, typeMetadata, conditionalAccessNode, invokeNode);
                    }
                }
                else
                {
                    // Plain invocation: add await, passing the configured ConfigureAwait argument
                    node = node.ReplaceNode(invokeNode, invokeNode.AddAwait(_configuration.ConfigureAwaitArgument));
                }
            }
            return(node);
        }
Пример #3
0
        private T TransformFunctionReference <T>(T node, IFunctionAnalyzationResult funcResult, FunctionReferenceTransformationResult transfromReference,
                                                 ITypeTransformationMetadata typeMetadata,
                                                 INamespaceTransformationMetadata namespaceMetadata)
            where T : SyntaxNode
        {
            var nameNode                = node.GetAnnotatedNodes(transfromReference.Annotation).OfType <SimpleNameSyntax>().First();
            var funReferenceResult      = transfromReference.AnalyzationResult;
            var bodyFuncReferenceResult = funReferenceResult as IBodyFunctionReferenceAnalyzationResult;
            var newNameNode             = nameNode
                                          .WithIdentifier(Identifier(funReferenceResult.AsyncCounterpartName))
                                          .WithTriviaFrom(nameNode);

            transfromReference.Transformed = newNameNode;

            var cancellationTokenParamName = funcResult.GetMethodOrAccessor().CancellationTokenRequired ? "cancellationToken" : null;             // TODO: remove

            // If we have a cref change the name to the async counterpart and add/update arguments
            if (bodyFuncReferenceResult == null)
            {
                if (funReferenceResult.IsCref)
                {
                    var crefNode  = (NameMemberCrefSyntax)nameNode.Parent;
                    var paramList = new List <CrefParameterSyntax>();
                    // If the cref has already the parameters set then use them
                    if (crefNode.Parameters != null)
                    {
                        paramList.AddRange(crefNode.Parameters.Parameters);
                        // If the external async counterpart has a cancellation token, add it
                        if (funReferenceResult.AsyncCounterpartFunction == null &&
                            funReferenceResult.ReferenceSymbol.Parameters.Length <
                            funReferenceResult.AsyncCounterpartSymbol.Parameters.Length)
                        {
                            paramList.Add(CrefParameter(IdentifierName(nameof(CancellationToken))));
                        }
                    }
                    else
                    {
                        // We have to add the parameters to avoid ambiguity
                        var asyncSymbol = funReferenceResult.AsyncCounterpartSymbol;
                        paramList.AddRange(asyncSymbol.Parameters
                                           .Select(o => CrefParameter(o.Type
                                                                      .CreateTypeSyntax(true, namespaceMetadata.AnalyzationResult.IsIncluded(o.Type.ContainingNamespace?.ToString())))));
                    }

                    // If the async counterpart is internal and a token is required add a token parameter
                    if (funReferenceResult.AsyncCounterpartFunction?.GetMethodOrAccessor()?.CancellationTokenRequired == true)
                    {
                        paramList.Add(CrefParameter(IdentifierName(nameof(CancellationToken))));
                    }

                    node = node.ReplaceNestedNodes(
                        crefNode.Parent as QualifiedCrefSyntax,
                        crefNode,
                        crefNode
                        .ReplaceNode(nameNode, newNameNode)
                        .WithParameters(CrefParameterList(SeparatedList(paramList))),
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                          (type, fullName) => rootNode.WithContainer(type.CreateTypeSyntax(true, fullName).WithTriviaFrom(rootNode.Container))),
                        childNode => RunReferenceTransformers(childNode, funcResult, funReferenceResult, namespaceMetadata)
                        );
                }
                else if (funReferenceResult.IsNameOf)
                {
                    node = node.ReplaceNestedNodes(
                        nameNode.Parent as MemberAccessExpressionSyntax,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                          (type, fullName) => rootNode.WithExpression(type.CreateTypeSyntax(false, fullName).WithTriviaFrom(rootNode.Expression))),
                        childNode => RunReferenceTransformers(childNode, funcResult, funReferenceResult, namespaceMetadata)
                        );
                }
                return(node);
            }
            // If we have a method passed as an argument we need to check if we have to wrap it inside a function
            if (bodyFuncReferenceResult.AsyncDelegateArgument != null)
            {
                if (bodyFuncReferenceResult.WrapInsideFunction)
                {
                    // TODO: move to analyze step
                    var  argumentNode  = nameNode.Ancestors().OfType <ArgumentSyntax>().First();
                    var  delReturnType = (INamedTypeSymbol)bodyFuncReferenceResult.AsyncDelegateArgument.ReturnType;
                    var  returnType    = bodyFuncReferenceResult.AsyncCounterpartSymbol.ReturnType;
                    bool returnTypeMismatch;
                    if (bodyFuncReferenceResult.ReferenceFunction != null)
                    {
                        var refMethod = bodyFuncReferenceResult.ReferenceFunction as IMethodAnalyzationResult;
                        if (refMethod != null && refMethod.PreserveReturnType)
                        {
                            returnTypeMismatch = !delReturnType.Equals(returnType);   // TODO Generics
                        }
                        else if (delReturnType.IsGenericType)                         // Generic Task
                        {
                            returnTypeMismatch = delReturnType.TypeArguments.First().IsAwaitRequired(returnType);
                        }
                        else
                        {
                            returnTypeMismatch = delReturnType.IsAwaitRequired(returnType);
                        }
                    }
                    else
                    {
                        returnTypeMismatch = !delReturnType.Equals(returnType);                         // TODO Generics
                    }

                    var newArgumentExpression = argumentNode.Expression
                                                .ReplaceNestedNodes(
                        nameNode.Parent as MemberAccessExpressionSyntax,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                          (type, fullName) => rootNode.WithExpression(type.CreateTypeSyntax(false, fullName))),
                        childNode => RunReferenceTransformers(childNode, funcResult, funReferenceResult, namespaceMetadata)
                        )
                                                .WrapInsideFunction(bodyFuncReferenceResult.AsyncDelegateArgument, returnTypeMismatch,
                                                                    namespaceMetadata.TaskConflict,
                                                                    invocation => invocation.AddCancellationTokenArgumentIf(cancellationTokenParamName, bodyFuncReferenceResult));
                    node = node
                           .ReplaceNode(argumentNode.Expression, newArgumentExpression);
                }
                else
                {
                    node = node.ReplaceNestedNodes(
                        nameNode.Parent as MemberAccessExpressionSyntax,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                          (type, fullName) => rootNode.WithExpression(type.CreateTypeSyntax(false, fullName))),
                        childNode => RunReferenceTransformers(childNode, funcResult, funReferenceResult, namespaceMetadata)
                        );
                }
                return(node);
            }

            InvocationExpressionSyntax invokeNode = null;
            var isAccessor = bodyFuncReferenceResult.ReferenceSymbol.IsAccessor();

            if (!isAccessor && funReferenceResult.ReferenceNode.IsKind(SyntaxKind.InvocationExpression))
            {
                invokeNode = nameNode.Ancestors().OfType <InvocationExpressionSyntax>().First();
            }

            if (!bodyFuncReferenceResult.AwaitInvocation)
            {
                // An arrow method does not have a statement
                var statement = nameNode.Ancestors().OfType <StatementSyntax>().FirstOrDefault();
                var newNode   = (SyntaxNode)statement ?? node;

                if (invokeNode != null)
                {
                    newNode = newNode.ReplaceNestedNodes(
                        invokeNode,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode
                                                                          .AddCancellationTokenArgumentIf(cancellationTokenParamName, bodyFuncReferenceResult),
                                                                          funcResult, funReferenceResult, namespaceMetadata,
                                                                          (memberNode, type, fullName) => memberNode.WithExpression(type.CreateTypeSyntax(true, fullName).WithTriviaFrom(memberNode.Expression)))
                        );
                }
                else if (isAccessor)
                {
                    newNode = ConvertAccessor(newNode, nameNode, newNameNode, cancellationTokenParamName, bodyFuncReferenceResult,
                                              invNode => UpdateTypeAndRunReferenceTransformers(invNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                                               (memberNode, type, fullName) => memberNode.WithExpression(type.CreateTypeSyntax(true, fullName).WithTriviaFrom(memberNode.Expression))));
                }
                else
                {
                    newNode = newNode.ReplaceNestedNodes(
                        nameNode.Parent as MemberAccessExpressionSyntax,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                          (type, fullName) => rootNode.WithExpression(type.CreateTypeSyntax(false, fullName).WithTriviaFrom(rootNode.Expression))),
                        childNode => RunReferenceTransformers(childNode, funcResult, funReferenceResult, namespaceMetadata)
                        );
                }

                if (statement != null && !statement.IsKind(SyntaxKind.LocalFunctionStatement))
                {
                    if (bodyFuncReferenceResult.UseAsReturnValue)
                    {
                        newNode = ((StatementSyntax)newNode).ToReturnStatement();
                    }
                    node = node
                           .ReplaceNode(statement, newNode);
                }
                else
                {
                    node = (T)newNode;
                }
            }
            else
            {
                // We need to annotate the invocation node because of the AddAwait method as it needs the parent node
                var invokeAnnotation = Guid.NewGuid().ToString();
                if (isAccessor)
                {
                    node = ConvertAccessor(node, nameNode, newNameNode, cancellationTokenParamName, bodyFuncReferenceResult, invNode =>
                                           UpdateTypeAndRunReferenceTransformers(invNode, funcResult, funReferenceResult, namespaceMetadata,
                                                                                 (memberNode, type, fullName) => memberNode.WithExpression(type.CreateTypeSyntax(true, fullName).WithTriviaFrom(memberNode.Expression)))
                                           .WithAdditionalAnnotations(new SyntaxAnnotation(invokeAnnotation))
                                           );
                }
                else
                {
                    node = node.ReplaceNestedNodes(
                        invokeNode,
                        nameNode,
                        newNameNode,
                        rootNode => UpdateTypeAndRunReferenceTransformers(rootNode
                                                                          .AddCancellationTokenArgumentIf(cancellationTokenParamName, bodyFuncReferenceResult),
                                                                          funcResult, funReferenceResult, namespaceMetadata,
                                                                          (memberNode, type, fullName) => memberNode.WithExpression(type.CreateTypeSyntax(true, fullName).WithTriviaFrom(memberNode.Expression)))
                        .WithAdditionalAnnotations(new SyntaxAnnotation(invokeAnnotation))
                        );
                }

                invokeNode = node.GetAnnotatedNodes(invokeAnnotation).OfType <InvocationExpressionSyntax>().First();

                var conditionalAccessNode = invokeNode.Ancestors()
                                            .TakeWhile(o => !(o is StatementSyntax))
                                            .OfType <ConditionalAccessExpressionSyntax>()
                                            .FirstOrDefault();
                if (conditionalAccessNode != null)                 // ?. syntax
                {
                    var statement = (StatementSyntax)invokeNode.Ancestors().FirstOrDefault(o => o is StatementSyntax);
                    var block     = statement?.Parent as BlockSyntax;
                    if (statement == null || block == null)
                    {
                        // TODO: convert arrow method/property/function to a normal one
                        // TODO: convert to block if there is no block
                        node = node.ReplaceNode(conditionalAccessNode,
                                                conditionalAccessNode.AddAwait(_configuration.ConfigureAwaitArgument));
                    }
                    else
                    {
                        var fnName = nameNode.Identifier.ValueText;
                        // TODO: handle name collisions
                        var variableName             = $"{char.ToLowerInvariant(fnName[0])}{fnName.Substring(1)}Task";
                        var leadingTrivia            = statement.GetLeadingTrivia();
                        var newConditionalAccessNode = ConditionalAccessExpression(
                            conditionalAccessNode.Expression,
                            invokeNode)
                                                       .WithTriviaFrom(conditionalAccessNode);
                        var localVar = LocalDeclarationStatement(
                            VariableDeclaration(
                                IdentifierName(Identifier(leadingTrivia, "var", TriviaList(Space))),
                                SingletonSeparatedList(
                                    VariableDeclarator(
                                        Identifier(TriviaList(), variableName, TriviaList(Space)))
                                    .WithInitializer(
                                        EqualsValueClause(newConditionalAccessNode.WithoutTrivia())
                                        .WithEqualsToken(Token(TriviaList(), SyntaxKind.EqualsToken, TriviaList(Space)))
                                        )
                                    )))
                                       .WithSemicolonToken(Token(TriviaList(), SyntaxKind.SemicolonToken, TriviaList(typeMetadata.EndOfLineTrivia)));
                        var index = block.Statements.IndexOf(statement);

                        var lastReturnNode = block.DescendantNodes()
                                             .Where(o => o.SpanStart >= statement.SpanStart)
                                             .OfType <ReturnStatementSyntax>()
                                             .LastOrDefault();

                        var variableAnnotation = Guid.NewGuid().ToString();
                        var newBlock           = block.ReplaceNode(conditionalAccessNode,
                                                                   conditionalAccessNode.WhenNotNull.ReplaceNode(invokeNode,
                                                                                                                 IdentifierName(variableName)
                                                                                                                 .WithAdditionalAnnotations(new SyntaxAnnotation(variableAnnotation))
                                                                                                                 .WithLeadingTrivia(conditionalAccessNode.GetLeadingTrivia())
                                                                                                                 .WithTrailingTrivia(conditionalAccessNode.GetTrailingTrivia())
                                                                                                                 ));

                        var variable = newBlock.GetAnnotatedNodes(variableAnnotation).OfType <IdentifierNameSyntax>().First();
                        newBlock = newBlock.ReplaceNode(variable, variable.AddAwait(_configuration.ConfigureAwaitArgument));

                        var ifBlock = Block()
                                      .WithOpenBraceToken(
                            Token(TriviaList(leadingTrivia), SyntaxKind.OpenBraceToken, TriviaList(typeMetadata.EndOfLineTrivia)))
                                      .WithCloseBraceToken(
                            Token(TriviaList(leadingTrivia), SyntaxKind.CloseBraceToken, TriviaList(typeMetadata.EndOfLineTrivia)))
                                      .WithStatements(new SyntaxList <StatementSyntax>()
                                                      .AddRange(newBlock.AppendIndent(typeMetadata.IndentTrivia.ToFullString()).Statements.Skip(index)));

                        var ifStatement = IfStatement(
                            BinaryExpression(
                                SyntaxKind.NotEqualsExpression,
                                IdentifierName(Identifier(TriviaList(), variableName, TriviaList(Space))),
                                LiteralExpression(SyntaxKind.NullLiteralExpression))
                            .WithOperatorToken(
                                Token(TriviaList(), SyntaxKind.ExclamationEqualsToken, TriviaList(Space))),
                            ifBlock
                            )
                                          .WithIfKeyword(
                            Token(TriviaList(leadingTrivia), SyntaxKind.IfKeyword, TriviaList(Space)))
                                          .WithCloseParenToken(
                            Token(TriviaList(), SyntaxKind.CloseParenToken, TriviaList(typeMetadata.EndOfLineTrivia)));

                        var statements = new SyntaxList <StatementSyntax>()
                                         .AddRange(newBlock.Statements.Take(index))
                                         .Add(localVar)
                                         .Add(ifStatement);
                        if (lastReturnNode?.Expression != null)
                        {
                            // Check if the variable is defined otherwise return default return type value
                            if (lastReturnNode.Expression is IdentifierNameSyntax idNode &&
                                statements.OfType <VariableDeclaratorSyntax>().All(o => o.Identifier.ToString() != idNode.Identifier.ValueText))
                            {
                                lastReturnNode = lastReturnNode.WithExpression(DefaultExpression(funcResult.GetNode().GetReturnType().WithoutTrivia()));
                            }
                            statements = statements.Add(lastReturnNode);
                        }
                        node = node.ReplaceNode(block, newBlock.WithStatements(statements));
                    }
                }
                else
                {
                    node = node.ReplaceNode(invokeNode, invokeNode.AddAwait(_configuration.ConfigureAwaitArgument));
                }
            }
            return(node);
        }
Пример #4
0
        /// <summary>
        /// Rewrites an invocation whose async counterpart is <c>Task.WhenAll</c> (a <c>Parallel.For</c> /
        /// <c>Parallel.ForEach</c> call) so that its arguments collapse into a single <c>source.Select(body)</c> argument.
        /// </summary>
        /// <param name="node">The reference node; only <see cref="InvocationExpressionSyntax"/> nodes are rewritten.</param>
        /// <param name="funcResult">The function containing the reference; decides whether a cancellation token is forwarded.</param>
        /// <param name="funcReferenceResult">The analyzed reference; it must be a body reference whose async counterpart
        /// is <c>Task.WhenAll</c> for the node to be rewritten.</param>
        /// <param name="namespaceMetadata">Namespace metadata; a <c>System.Linq</c> using is added when the node is rewritten.</param>
        /// <returns>The rewritten invocation, or <paramref name="node"/> unchanged when this transformer does not apply.</returns>
        /// <exception cref="InvalidOperationException">The last parameter of the referenced method is not a delegate type.</exception>
        public SyntaxNode TransformFunctionReference(SyntaxNode node, IFunctionAnalyzationResult funcResult,
                                                     IFunctionReferenceAnalyzationResult funcReferenceResult,
                                                     INamespaceTransformationMetadata namespaceMetadata)
        {
            // object.Equals is null-safe: a reference without an async counterpart is left untouched instead of throwing
            if (!object.Equals(funcReferenceResult.AsyncCounterpartSymbol, _whenAllMethod))
            {
                return(node);
            }
            if (!(node is InvocationExpressionSyntax invokeNode) ||
                !(funcReferenceResult is IBodyFunctionReferenceAnalyzationResult bodyReference))
            {
                return(node);                // Cref
            }

            // Here are some examples of expected nodes
            // Task.WhenAll(Results, ReadAsync)
            // Task.WhenAll(1, 100, ReadAsync)
            // Task.WhenAll(Enumerable.Empty<string>(), ReadAsync)
            // Task.WhenAll(GetStringList(), i =>
            // {
            //	return SimpleFile.ReadAsync();
            // })

            // For Parallel.ForEach, we need to combine the two arguments into one, using the Select Linq extension e.g.
            // Task.WhenAll(Results.Select(i => ReadAsync(i))
            // For Parallel.For, we need to move the first two parameters into Enumerable.Range and then apply the same logic as for
            // Parallel.ForEach

            var actionParam  = bodyReference.ReferenceSymbol.Parameters.Last();
            var actionType   = actionParam.Type as INamedTypeSymbol;
            var actionMethod = actionType?.DelegateInvokeMethod;

            if (actionMethod == null)
            {
                throw new InvalidOperationException(
                          $"Unable to transform Parallel.{bodyReference.ReferenceSymbol.Name} to Task.WhenAll. " +
                          $"The last Parallel.{bodyReference.ReferenceSymbol.Name} argument is not a delegate, but is {actionParam.Type}");
            }
            namespaceMetadata.AddUsing("System.Linq");
            var arguments     = invokeNode.ArgumentList.Arguments;
            var newExpression = arguments.Last().Expression;

            if (!(newExpression is AnonymousFunctionExpressionSyntax))
            {
                // Non-lambda delegates (e.g. a method group such as ReadAsync) go through WrapInsideFunction; the callback
                // may append the cancellation token argument to the wrapped invocation
                var delArgument = bodyReference.DelegateArguments.Last();
                var cancellationTokenParamName = funcResult.GetMethodOrAccessor().CancellationTokenRequired
                                        ? "cancellationToken"
                                        : null; // TODO: find a way to not have this duplicated and fix naming collision
                newExpression = newExpression.WrapInsideFunction(actionMethod, false, namespaceMetadata.TaskConflict,
                                                                 invoke => invoke.AddCancellationTokenArgumentIf(cancellationTokenParamName, delArgument.BodyFunctionReference));
            }
            ExpressionSyntax enumerableExpression;

            if (bodyReference.ReferenceSymbol.Equals(_forMethod))
            {
                // Construct an Enumerable.Range(1, 10 - 1), where 1 and 10 are the first two arguments of Parallel.For method.
                // Operands are parenthesized when needed, so Parallel.For(i + 1, n, ...) yields n - (i + 1) and not n - i + 1
                var startArg      = arguments[0];
                var endOperand    = ParenthesizeIfNeeded(arguments[1].Expression.WithoutTrivia()).WithTrailingTrivia(Space);
                var startOperand  = ParenthesizeIfNeeded(startArg.Expression.WithoutTrivia());
                enumerableExpression = InvocationExpression(
                    MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
                                           IdentifierName("Enumerable").WithLeadingTrivia(startArg.GetLeadingTrivia()),
                                           Token(SyntaxKind.DotToken),
                                           IdentifierName("Range")),
                    ArgumentList(
                        SeparatedList <ArgumentSyntax>(
                            new SyntaxNodeOrToken[]
                {
                    startArg.WithoutTrivia(),
                    Token(TriviaList(), SyntaxKind.CommaToken, TriviaList(Space)),
                    Argument(
                        BinaryExpression(SyntaxKind.SubtractExpression,
                                         endOperand,
                                         Token(TriviaList(), SyntaxKind.MinusToken, TriviaList(Space)),
                                         startOperand))
                }
                            )
                        )
                    );
            }
            else
            {
                // For ForEach take the first parameter e.g. Enumerable.Range(1, 10). It becomes the receiver of .Select,
                // so non-primary expressions (casts, ??, ?:, ...) must be parenthesized to keep their meaning
                enumerableExpression = ParenthesizeIfNeeded(arguments[0].Expression);
            }

            var memberAccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
                                                      enumerableExpression,
                                                      Token(SyntaxKind.DotToken),
                                                      IdentifierName("Select"));
            var argument = InvocationExpression(memberAccess)
                           .WithArgumentList(
                ArgumentList(SingletonSeparatedList(Argument(newExpression.WithoutTrivia())))
                .WithCloseParenToken(Token(TriviaList(), SyntaxKind.CloseParenToken, newExpression.GetTrailingTrivia()))
                );

            return(invokeNode.WithArgumentList(
                       invokeNode.ArgumentList.WithArguments(SingletonSeparatedList(Argument(argument)))));

            // Returns the expression unchanged when it is a primary expression that is safe both as a member-access
            // receiver and as a subtraction operand; otherwise wraps it in parentheses, moving its trivia outside
            ExpressionSyntax ParenthesizeIfNeeded(ExpressionSyntax expression)
            {
                switch (expression)
                {
                    case IdentifierNameSyntax _:
                    case GenericNameSyntax _:
                    case LiteralExpressionSyntax _:
                    case MemberAccessExpressionSyntax _:
                    case InvocationExpressionSyntax _:
                    case ElementAccessExpressionSyntax _:
                    case ParenthesizedExpressionSyntax _:
                    case ObjectCreationExpressionSyntax _:
                    case ThisExpressionSyntax _:
                        return expression;
                    default:
                        return ParenthesizedExpression(expression.WithoutTrivia()).WithTriviaFrom(expression);
                }
            }
        }