/// <summary>
/// Produces a copy of the document in which the class flagged by <paramref name="diagnostic"/>
/// derives from the AsyncPackage type instead of its current base type.
/// </summary>
/// <remarks>
/// The following edits are applied to the syntax root, in order:
/// <list type="number">
/// <item><description>The flagged base type is replaced with <c>Types.AsyncPackage.TypeSyntax</c>.</description></item>
/// <item><description>If the class carries the package registration attribute, its
/// <c>AllowsBackgroundLoading</c> named argument is set to <c>true</c> (added when missing).</description></item>
/// <item><description>If the class overrides <c>Initialize</c>, the override is renamed to
/// <c>InitializeAsync</c>, made <c>async</c>, given a task return type plus cancellation token and
/// progress parameters, and a switch to the main thread is inserted.</description></item>
/// <item><description><c>GetService</c> / <c>this.GetService</c> invocations inside that override
/// are replaced with parenthesized, awaited <c>GetServiceAsync</c> invocations.</description></item>
/// </list>
/// </remarks>
/// <param name="context">The code fix context that supplies the document to rewrite.</param>
/// <param name="diagnostic">The diagnostic whose location identifies the base type to replace.</param>
/// <param name="cancellationToken">A token used to cancel the asynchronous Roslyn queries.</param>
/// <returns>The updated document.</returns>
private async Task<Document> ConvertToAsyncPackageAsync(CodeFixContext context, Diagnostic diagnostic, CancellationToken cancellationToken)
{
    SemanticModel semanticModel = await context.Document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
    Compilation? compilation = await context.Document.Project.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
    Assumes.NotNull(compilation);
    SyntaxNode root = await context.Document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
    BaseTypeSyntax baseTypeSyntax = root.FindNode(diagnostic.Location.SourceSpan).FirstAncestorOrSelf<BaseTypeSyntax>();
    ClassDeclarationSyntax classDeclarationSyntax = baseTypeSyntax.FirstAncestorOrSelf<ClassDeclarationSyntax>();
    MethodDeclarationSyntax initializeMethodSyntax = classDeclarationSyntax.DescendantNodes()
        .OfType<MethodDeclarationSyntax>()
        .FirstOrDefault(method => method.Modifiers.Any(modifier => modifier.IsKind(SyntaxKind.OverrideKeyword)) && method.Identifier.Text == Types.Package.Initialize);

    // The `base.Initialize(...)` call inside the override, if any. Only block bodies are searched.
    InvocationExpressionSyntax? baseInitializeInvocationSyntax = initializeMethodSyntax?.Body?.DescendantNodes()
        .OfType<InvocationExpressionSyntax>()
        .FirstOrDefault(ies => ies.Expression is MemberAccessExpressionSyntax memberAccess && memberAccess.Name?.Identifier.Text == Types.Package.Initialize && memberAccess.Expression is BaseExpressionSyntax);
    var getServiceInvocationsSyntax = new List<InvocationExpressionSyntax>();
    AttributeSyntax? packageRegistrationSyntax = null;

    {
        // Locate the syntax of the package registration attribute applied to the class, if present.
        INamedTypeSymbol userClassSymbol = semanticModel.GetDeclaredSymbol(classDeclarationSyntax, cancellationToken);
        INamedTypeSymbol packageRegistrationType = compilation.GetTypeByMetadataName(Types.PackageRegistrationAttribute.FullName);
        AttributeData? packageRegistrationInstance = userClassSymbol?.GetAttributes().FirstOrDefault(a => Equals(a.AttributeClass, packageRegistrationType));
        if (packageRegistrationInstance?.ApplicationSyntaxReference != null)
        {
            packageRegistrationSyntax = (AttributeSyntax)await packageRegistrationInstance.ApplicationSyntaxReference.GetSyntaxAsync(cancellationToken).ConfigureAwait(false);
        }
    }

    if (initializeMethodSyntax != null)
    {
        // Collect `GetService(...)` and `this.GetService(...)` invocations.
        // An invocation's expression may be neither an identifier nor a member access
        // (for example a generic name), in which case both casts yield null and the
        // invocation must simply be skipped rather than dereferenced.
        getServiceInvocationsSyntax.AddRange(
            from invocation in initializeMethodSyntax.DescendantNodes().OfType<InvocationExpressionSyntax>()
            let memberBinding = invocation.Expression as MemberAccessExpressionSyntax
            let identifierName = invocation.Expression as IdentifierNameSyntax
            where identifierName?.Identifier.Text == Types.Package.GetService ||
                  (memberBinding != null &&
                   memberBinding.Name.Identifier.Text == Types.Package.GetService &&
                   memberBinding.Expression.IsKind(SyntaxKind.ThisExpression))
            select invocation);
    }

    // Make it easier to track nodes across changes.
    var nodesToTrack = new List<SyntaxNode?>
    {
        baseTypeSyntax,
        initializeMethodSyntax,
        baseInitializeInvocationSyntax,
        packageRegistrationSyntax,
    };
    nodesToTrack.AddRange(getServiceInvocationsSyntax);
    nodesToTrack.RemoveAll(n => n == null);
    SyntaxNode updatedRoot = root.TrackNodes(nodesToTrack);

    // Replace the Package base type with AsyncPackage
    baseTypeSyntax = updatedRoot.GetCurrentNode(baseTypeSyntax);
    SimpleBaseTypeSyntax asyncPackageBaseTypeSyntax = SyntaxFactory.SimpleBaseType(Types.AsyncPackage.TypeSyntax.WithAdditionalAnnotations(Simplifier.Annotation))
        .WithLeadingTrivia(baseTypeSyntax.GetLeadingTrivia())
        .WithTrailingTrivia(baseTypeSyntax.GetTrailingTrivia());
    updatedRoot = updatedRoot.ReplaceNode(baseTypeSyntax, asyncPackageBaseTypeSyntax);

    // Update the PackageRegistration attribute
    if (packageRegistrationSyntax != null)
    {
        LiteralExpressionSyntax trueExpression = SyntaxFactory.LiteralExpression(SyntaxKind.TrueLiteralExpression);
        packageRegistrationSyntax = updatedRoot.GetCurrentNode(packageRegistrationSyntax);

        // An attribute written without parentheses has no argument list at all,
        // so the lookup must tolerate a null ArgumentList.
        AttributeArgumentSyntax? allowsBackgroundLoadingSyntax = packageRegistrationSyntax.ArgumentList?.Arguments.FirstOrDefault(
            a => a.NameEquals?.Name?.Identifier.Text == Types.PackageRegistrationAttribute.AllowsBackgroundLoading);
        if (allowsBackgroundLoadingSyntax != null)
        {
            updatedRoot = updatedRoot.ReplaceNode(
                allowsBackgroundLoadingSyntax,
                allowsBackgroundLoadingSyntax.WithExpression(trueExpression));
        }
        else
        {
            updatedRoot = updatedRoot.ReplaceNode(
                packageRegistrationSyntax,
                packageRegistrationSyntax.AddArgumentListArguments(
                    SyntaxFactory.AttributeArgument(trueExpression).WithNameEquals(SyntaxFactory.NameEquals(Types.PackageRegistrationAttribute.AllowsBackgroundLoading))));
        }
    }

    // Find the Initialize override, if present, and update it to InitializeAsync
    if (initializeMethodSyntax != null)
    {
        IdentifierNameSyntax cancellationTokenLocalVarName = SyntaxFactory.IdentifierName("cancellationToken");
        IdentifierNameSyntax progressLocalVarName = SyntaxFactory.IdentifierName("progress");
        initializeMethodSyntax = updatedRoot.GetCurrentNode(initializeMethodSyntax);

        // NOTE(review): Body is dereferenced below without a null check, so an
        // expression-bodied Initialize override does not appear to be supported here.
        BlockSyntax newBody = initializeMethodSyntax.Body;

        SyntaxTriviaList leadingTrivia = SyntaxFactory.TriviaList(
            SyntaxFactory.Comment(@"// When initialized asynchronously, we *may* be on a background thread at this point."),
            SyntaxFactory.CarriageReturnLineFeed,
            SyntaxFactory.Comment(@"// Do any initialization that requires the UI thread after switching to the UI thread."),
            SyntaxFactory.CarriageReturnLineFeed,
            SyntaxFactory.Comment(@"// Otherwise, remove the switch to the UI thread if you don't need it."),
            SyntaxFactory.CarriageReturnLineFeed);

        // Builds: await this.<JoinableTaskFactory>.<SwitchToMainThreadAsync>(cancellationToken);
        ExpressionStatementSyntax switchToMainThreadStatement = SyntaxFactory.ExpressionStatement(
            SyntaxFactory.AwaitExpression(
                SyntaxFactory.InvocationExpression(
                    SyntaxFactory.MemberAccessExpression(
                        SyntaxKind.SimpleMemberAccessExpression,
                        SyntaxFactory.MemberAccessExpression(
                            SyntaxKind.SimpleMemberAccessExpression,
                            SyntaxFactory.ThisExpression(),
                            SyntaxFactory.IdentifierName(Types.ThreadHelper.JoinableTaskFactory)),
                        SyntaxFactory.IdentifierName(Types.JoinableTaskFactory.SwitchToMainThreadAsync)))
                .AddArgumentListArguments(SyntaxFactory.Argument(cancellationTokenLocalVarName))))
            .WithLeadingTrivia(leadingTrivia)
            .WithTrailingTrivia(SyntaxFactory.CarriageReturnLineFeed);

        if (baseInitializeInvocationSyntax != null)
        {
            // Turn `base.Initialize(...)` into `await base.InitializeAsync(..., cancellationToken, progress)`
            // and place the main-thread switch right after the statement containing it.
            var baseInitializeAsyncInvocationBookmark = new SyntaxAnnotation();
            AwaitExpressionSyntax baseInitializeAsyncInvocationSyntax = SyntaxFactory.AwaitExpression(
                baseInitializeInvocationSyntax
                    .WithLeadingTrivia()
                    .WithExpression(
                        SyntaxFactory.MemberAccessExpression(
                            SyntaxKind.SimpleMemberAccessExpression,
                            SyntaxFactory.BaseExpression(),
                            SyntaxFactory.IdentifierName(Types.AsyncPackage.InitializeAsync)))
                    .AddArgumentListArguments(
                        SyntaxFactory.Argument(cancellationTokenLocalVarName),
                        SyntaxFactory.Argument(progressLocalVarName)))
                .WithLeadingTrivia(baseInitializeInvocationSyntax.GetLeadingTrivia())
                .WithAdditionalAnnotations(baseInitializeAsyncInvocationBookmark);
            newBody = newBody.ReplaceNode(initializeMethodSyntax.GetCurrentNode(baseInitializeInvocationSyntax), baseInitializeAsyncInvocationSyntax);
            StatementSyntax baseInvocationStatement = newBody.GetAnnotatedNodes(baseInitializeAsyncInvocationBookmark).First().FirstAncestorOrSelf<StatementSyntax>();

            newBody = newBody.InsertNodesAfter(
                baseInvocationStatement,
                new[] { switchToMainThreadStatement.WithLeadingTrivia(switchToMainThreadStatement.GetLeadingTrivia().Insert(0, SyntaxFactory.LineFeed)) });
        }
        else
        {
            // No base call to anchor on: the switch becomes the first statement.
            newBody = newBody.WithStatements(
                newBody.Statements.Insert(0, switchToMainThreadStatement));
        }

        MethodDeclarationSyntax initializeAsyncMethodSyntax = initializeMethodSyntax
            .WithIdentifier(SyntaxFactory.Identifier(Types.AsyncPackage.InitializeAsync))
            .WithReturnType(Types.Task.TypeSyntax.WithAdditionalAnnotations(Simplifier.Annotation))
            .AddModifiers(SyntaxFactory.Token(SyntaxKind.AsyncKeyword))
            .AddParameterListParameters(
                SyntaxFactory.Parameter(cancellationTokenLocalVarName.Identifier).WithType(Types.CancellationToken.TypeSyntax.WithAdditionalAnnotations(Simplifier.Annotation)),
                SyntaxFactory.Parameter(progressLocalVarName.Identifier).WithType(Types.IProgress.TypeSyntaxOf(Types.ServiceProgressData.TypeSyntax).WithAdditionalAnnotations(Simplifier.Annotation)))
            .WithBody(newBody);
        updatedRoot = updatedRoot.ReplaceNode(initializeMethodSyntax, initializeAsyncMethodSyntax);

        // Replace GetService calls with GetServiceAsync
        getServiceInvocationsSyntax = updatedRoot.GetCurrentNodes<InvocationExpressionSyntax>(getServiceInvocationsSyntax).ToList();
        updatedRoot = updatedRoot.ReplaceNodes(
            getServiceInvocationsSyntax,
            (orig, node) =>
            {
                InvocationExpressionSyntax invocation = node;
                if (invocation.Expression is IdentifierNameSyntax)
                {
                    invocation = invocation.WithExpression(SyntaxFactory.IdentifierName(Types.AsyncPackage.GetServiceAsync));
                }
                else if (invocation.Expression is MemberAccessExpressionSyntax memberAccess)
                {
                    invocation = invocation.WithExpression(
                        memberAccess.WithName(SyntaxFactory.IdentifierName(Types.AsyncPackage.GetServiceAsync)));
                }

                // Parenthesized so that member accesses chained on the original call
                // still bind to the awaited result; the annotation marks them for simplification.
                return SyntaxFactory.ParenthesizedExpression(SyntaxFactory.AwaitExpression(invocation))
                    .WithAdditionalAnnotations(Simplifier.Annotation);
            });

        updatedRoot = await Utils.AddUsingTaskEqualsDirectiveAsync(updatedRoot, cancellationToken).ConfigureAwait(false);
    }

    Document newDocument = context.Document.WithSyntaxRoot(updatedRoot);
    newDocument = await ImportAdder.AddImportsAsync(newDocument, Simplifier.Annotation, cancellationToken: cancellationToken).ConfigureAwait(false);

    return newDocument;
}
/// <summary>
/// Adds cancellation guard statements to <paramref name="methodBody"/>.
/// </summary>
/// <remarks>
/// One guard is inserted right after the preconditions. An additional guard is prepended to every
/// statement that contains an invocation converted to async which does not pass the cancellation
/// token, provided that statement's direct parent is a block. The first statement following the
/// preconditions is excluded from the additional guards because the start guard already precedes it.
/// </remarks>
/// <param name="transformResult">Supplies the transformed function references and the trivia used to format guards.</param>
/// <param name="methodBody">The body to add guards to; returned unchanged when null.</param>
/// <param name="methodResult">Analysis result of the method (preconditions, faulted and omit-async flags).</param>
/// <param name="cancellationTokenParamName">Name of the cancellation token parameter the guards refer to.</param>
/// <returns>The body with guards added, or the original body when guards are disabled or not applicable.</returns>
private BlockSyntax AddGuards(IFunctionTransformationResult transformResult, BlockSyntax methodBody,
    IFunctionAnalyzationResult methodResult, string cancellationTokenParamName)
{
    var guardsApply = _configuration.Guards && methodBody != null && !methodResult.Faulted;
    if (!guardsApply)
    {
        return methodBody;
    }

    // Remember where the first non-precondition statement starts, so that it does not
    // receive a second guard directly after the start guard.
    int? firstBodyStatementStart = null;
    if (methodBody.Statements.Count > methodResult.Preconditions.Count)
    {
        var firstBodyStatementIndex = methodResult.Preconditions.Count;
        var firstBodyStatement = methodBody.Statements[firstBodyStatementIndex];
        var markedStatement = firstBodyStatement.WithAdditionalAnnotations(new SyntaxAnnotation("AfterGuardStatement"));

        // Replacing the node re-roots the tree at methodBody, which gives consistent spans
        // for the comparisons performed below.
        methodBody = methodBody.ReplaceNode(firstBodyStatement, markedStatement);
        firstBodyStatementStart = methodBody.Statements[firstBodyStatementIndex].SpanStart;
    }

    // Span start -> annotation kind of each statement that needs an extra guard in front of it,
    // i.e. a statement with an async invocation that is not given the cancellation token.
    var guardTargets = new Dictionary<int, string>();
    foreach (var reference in transformResult.TransformedFunctionReferences)
    {
        if (reference.AnalyzationResult is IBodyFunctionReferenceAnalyzationResult bodyReference
            && bodyReference.GetConversion() == ReferenceConversion.ToAsync
            && !bodyReference.PassCancellationToken)
        {
            var owningStatement = methodBody
                .GetAnnotatedNodes(reference.Annotation)
                .First()
                .Ancestors()
                .OfType<StatementSyntax>()
                .First();
            var statementStart = owningStatement.SpanStart;

            var alreadyGuarded = guardTargets.ContainsKey(statementStart) || firstBodyStatementStart == statementStart;
            if (alreadyGuarded)
            {
                continue;
            }

            var markerKind = Guid.NewGuid().ToString();
            var annotatedStatement = owningStatement.WithAdditionalAnnotations(new SyntaxAnnotation(markerKind));
            methodBody = methodBody.ReplaceNode(owningStatement, annotatedStatement);
            guardTargets.Add(statementStart, markerKind);
        }
    }

    // Guard placed right after the preconditions.
    var startGuard = methodResult.OmitAsync
        ? GetSyncGuard(methodResult, cancellationTokenParamName, transformResult.BodyLeadingWhitespaceTrivia, transformResult.EndOfLineTrivia, transformResult.IndentTrivia)
        : GetAsyncGuard(cancellationTokenParamName, transformResult.BodyLeadingWhitespaceTrivia, transformResult.EndOfLineTrivia);
    var statementsWithStartGuard = methodBody.Statements.Insert(methodResult.Preconditions.Count, startGuard);
    methodBody = methodBody.WithStatements(statementsWithStartGuard);

    // Locate every marked statement again (by annotation, since the tree has changed)
    // and insert a guard in front of it inside its parent block.
    // TODO: Add support when the parent is not a block syntax
    foreach (var target in guardTargets)
    {
        var markedStatement = methodBody.GetAnnotatedNodes(target.Value).OfType<StatementSyntax>().First();
        if (markedStatement.Parent is BlockSyntax enclosingBlock)
        {
            var position = enclosingBlock.Statements.IndexOf(markedStatement);
            var extraGuard = GetAsyncGuard(cancellationTokenParamName, markedStatement.GetLeadingWhitespace(), transformResult.EndOfLineTrivia);
            var guardedBlock = enclosingBlock.WithStatements(enclosingBlock.Statements.Insert(position, extraGuard));
            methodBody = methodBody.ReplaceNode(enclosingBlock, guardedBlock);
        }

        // Statements whose parent is not a block are currently not supported and left untouched.
    }

    return methodBody;
}