/// <inheritdoc />
public override void Initialize(AnalysisContext context)
{
    context.EnableConcurrentExecution();
    context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze);

    context.RegisterCompilationStartAction(startContext =>
    {
        var analyzerOptions = startContext.Options;
        var startToken = startContext.CancellationToken;

        // Load the three configured member lists through CommonInterest, once per compilation.
        var assertingMethods = CommonInterest
            .ReadMethods(analyzerOptions, CommonInterest.FileNamePatternForMethodsThatAssertMainThread, startToken)
            .ToImmutableArray();
        var switchingMethods = CommonInterest
            .ReadMethods(analyzerOptions, CommonInterest.FileNamePatternForMethodsThatSwitchToMainThread, startToken)
            .ToImmutableArray();
        var mainThreadMembers = CommonInterest
            .ReadTypesAndMembers(analyzerOptions, CommonInterest.FileNamePatternForMembersRequiringMainThread, startToken)
            .ToImmutableArray();

        // Each diagnostic carries the asserting and switching method lists as newline-joined
        // properties, keyed by the file name pattern they were read with.
        var properties = ImmutableDictionary<string, string>.Empty;
        properties = properties.Add(
            CommonInterest.FileNamePatternForMethodsThatAssertMainThread.ToString(),
            string.Join("\n", assertingMethods));
        properties = properties.Add(
            CommonInterest.FileNamePatternForMethodsThatSwitchToMainThread.ToString(),
            string.Join("\n", switchingMethods));

        // State shared by every code block analyzed in this compilation.
        var declaringMethods = new HashSet<IMethodSymbol>();
        var assertingCallers = new HashSet<IMethodSymbol>();
        var callerToCalleeMap = new Dictionary<IMethodSymbol, List<CallInfo>>();

        startContext.RegisterCodeBlockStartAction<SyntaxKind>(blockContext =>
        {
            var analyzer = new MethodAnalyzer
            {
                MainThreadAssertingMethods = assertingMethods,
                MainThreadSwitchingMethods = switchingMethods,
                MembersRequiringMainThread = mainThreadMembers,
                MethodsDeclaringUIThreadRequirement = declaringMethods,
                MethodsAssertingUIThreadRequirement = assertingCallers,
                DiagnosticProperties = properties,
            };

            blockContext.RegisterSyntaxNodeAction(
                Utils.DebuggableWrapper(analyzer.AnalyzeInvocation),
                SyntaxKind.InvocationExpression);
            blockContext.RegisterSyntaxNodeAction(
                Utils.DebuggableWrapper(analyzer.AnalyzeMemberAccess),
                SyntaxKind.SimpleMemberAccessExpression);
            blockContext.RegisterSyntaxNodeAction(
                Utils.DebuggableWrapper(analyzer.AnalyzeCast),
                SyntaxKind.CastExpression);

            // 'as' and 'is' expressions share the AnalyzeAs handler.
            blockContext.RegisterSyntaxNodeAction(
                Utils.DebuggableWrapper(analyzer.AnalyzeAs),
                SyntaxKind.AsExpression);
            blockContext.RegisterSyntaxNodeAction(
                Utils.DebuggableWrapper(analyzer.AnalyzeAs),
                SyntaxKind.IsExpression);
            blockContext.RegisterSyntaxNodeAction(
                Utils.DebuggableWrapper(analyzer.AnalyzeIsPattern),
                SyntaxKind.IsPatternExpression);
        });

        // Record caller -> callee edges for these three node kinds, one registration per kind.
        var callGraphKinds = new[]
        {
            SyntaxKind.InvocationExpression,
            SyntaxKind.SimpleMemberAccessExpression,
            SyntaxKind.IdentifierName,
        };
        foreach (var kind in callGraphKinds)
        {
            startContext.RegisterSyntaxNodeAction(
                Utils.DebuggableWrapper(c => AddToCallerCalleeMap(c, callerToCalleeMap)),
                kind);
        }

        startContext.RegisterCompilationEndAction(endContext =>
        {
            var calleeToCallerMap = CreateCalleeToCallerMap(callerToCalleeMap);
            var mainThreadClosure = GetTransitiveClosureOfMainThreadRequiringMethods(assertingCallers, calleeToCallerMap);

            // Only methods in the closure that were not recorded in the declaring set are reported.
            foreach (var implicitUserMethod in mainThreadClosure.Except(declaringMethods))
            {
                // One report per distinct callee that is itself in the closure,
                // located at the first recorded invocation of that callee.
                var reportSites = callerToCalleeMap[implicitUserMethod]
                    .Where(info => mainThreadClosure.Contains(info.MethodSymbol))
                    .GroupBy(info => info.MethodSymbol)
                    .Select(bySymbol => new
                    {
                        Location = bySymbol.First().InvocationSyntax.GetLocation(),
                        CalleeMethod = bySymbol.Key,
                    });

                foreach (var site in reportSites)
                {
                    bool isAsync = Utils.IsAsyncReady(implicitUserMethod);
                    string calleeName = Utils.GetFullName(site.CalleeMethod);

                    DiagnosticDescriptor descriptor;
                    object[] formattingArgs;
                    if (isAsync)
                    {
                        descriptor = DescriptorAsync;
                        formattingArgs = new object[] { calleeName };
                    }
                    else
                    {
                        // The sync message additionally receives the first configured asserting method.
                        descriptor = DescriptorSync;
                        formattingArgs = new object[] { calleeName, assertingMethods.FirstOrDefault() };
                    }

                    endContext.ReportDiagnostic(
                        Diagnostic.Create(descriptor, site.Location, properties, formattingArgs));
                }
            }
        });
    });
}
/// <summary>
/// Registers code fixes that replace the statement containing each diagnosed expression with an
/// awaited call to one of the methods read via
/// <c>CommonInterest.FileNamePatternForMethodsThatSwitchToMainThread</c>.
/// </summary>
/// <param name="context">The code fix context supplying the document and the diagnostics to fix.</param>
/// <returns>A task that completes once every applicable fix has been registered.</returns>
public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
    // The syntax root and semantic model belong to the document rather than to any single
    // diagnostic, so fetch them once instead of once per loop iteration.
    var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
    var semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false);
    if (root == null || semanticModel == null)
    {
        return;
    }

    foreach (var diagnostic in context.Diagnostics)
    {
        // An ineligible diagnostic is skipped with 'continue' so that the remaining
        // diagnostics still get their fixes offered.
        if (!(root.FindNode(diagnostic.Location.SourceSpan, getInnermostNodeForTie: true) is ExpressionSyntax syntaxNode))
        {
            continue;
        }

        var container = Utils.GetContainingFunction(syntaxNode);
        if (container.BlockOrExpression == null)
        {
            continue;
        }

        if (!container.IsAsync)
        {
            if (!(container.Function is MethodDeclarationSyntax || container.Function is AnonymousFunctionExpressionSyntax))
            {
                // We don't support converting whatever this is into an async method.
                continue;
            }
        }

        // The fix works by swapping out the enclosing statement. Resolve it before offering
        // anything so we never register an action that would later fail for lack of one.
        var statementToReplace = syntaxNode.FirstAncestorOrSelf<StatementSyntax>();
        if (statementToReplace == null)
        {
            continue;
        }

        int positionForLookup = diagnostic.Location.SourceSpan.Start;
        var enclosingSymbol = semanticModel.GetEnclosingSymbol(positionForLookup, context.CancellationToken);
        if (enclosingSymbol == null)
        {
            continue;
        }

        var options = await CommonInterest.ReadMethodsAsync(context, CommonInterest.FileNamePatternForMethodsThatSwitchToMainThread, context.CancellationToken).ConfigureAwait(false);
        ISymbol cancellationTokenSymbol = Utils.FindCancellationToken(semanticModel, positionForLookup, context.CancellationToken).FirstOrDefault();
        foreach (var option in options)
        {
            // We're looking for methods that either require no parameters,
            // or (if we have one to give) that have just one parameter that is a CancellationToken.
            var proposedMethod = Utils.FindMethodGroup(semanticModel, option)
                .FirstOrDefault(m => !m.Parameters.Any(p => !p.HasExplicitDefaultValue)
                    || (cancellationTokenSymbol != null && m.Parameters.Length == 1 && Utils.IsCancellationTokenParameter(m.Parameters[0])));
            if (proposedMethod == null)
            {
                // We can't find it, so don't offer to use it.
                continue;
            }

            if (proposedMethod.IsStatic)
            {
                OfferFix(option.ToString());
            }
            else
            {
                foreach (var candidate in Utils.FindInstanceOf(proposedMethod.ContainingType, semanticModel, positionForLookup, context.CancellationToken))
                {
                    if (candidate.Item1)
                    {
                        OfferFix($"{candidate.Item2.Name}.{proposedMethod.Name}");
                    }
                    else
                    {
                        OfferFix($"{candidate.Item2.ContainingNamespace}.{candidate.Item2.ContainingType.Name}.{candidate.Item2.Name}.{proposedMethod.Name}");
                    }
                }
            }

            void OfferFix(string fullyQualifiedMethod)
            {
                // The action only rewrites the statement found for this diagnostic,
                // so register it against this diagnostic alone.
                context.RegisterCodeFix(
                    CodeAction.Create($"Use 'await {fullyQualifiedMethod}'", ct => Fix(fullyQualifiedMethod, proposedMethod, ct), fullyQualifiedMethod),
                    diagnostic);
            }
        }

        // Builds the solution in which the located statement is replaced by
        // 'await <fullyQualifiedMethod>(...)' and, when the containing function
        // is not already async, the function is converted via Utils.MakeMethodAsync.
        async Task<Solution> Fix(string fullyQualifiedMethod, IMethodSymbol methodSymbol, CancellationToken cancellationToken)
        {
            // Split "<qualifier>.<method>" at the last dot.
            int typeAndMethodDelimiterIndex = fullyQualifiedMethod.LastIndexOf('.');
            IdentifierNameSyntax methodName = SyntaxFactory.IdentifierName(fullyQualifiedMethod.Substring(typeAndMethodDelimiterIndex + 1));
            ExpressionSyntax invokedMethod = Utils.MemberAccess(fullyQualifiedMethod.Substring(0, typeAndMethodDelimiterIndex).Split('.'), methodName)
                .WithAdditionalAnnotations(Simplifier.Annotation);
            var invocationExpression = SyntaxFactory.InvocationExpression(invokedMethod);

            // Pass along a cancellation token when the method accepts one and one was found at the lookup position.
            var cancellationTokenParameter = methodSymbol.Parameters.FirstOrDefault(Utils.IsCancellationTokenParameter);
            if (cancellationTokenParameter != null && cancellationTokenSymbol != null)
            {
                var arg = SyntaxFactory.Argument(SyntaxFactory.IdentifierName(cancellationTokenSymbol.Name));
                if (methodSymbol.Parameters.IndexOf(cancellationTokenParameter) > 0)
                {
                    // The token is not the first parameter, so it is passed as a named argument.
                    arg = arg.WithNameColon(SyntaxFactory.NameColon(SyntaxFactory.IdentifierName(cancellationTokenParameter.Name)));
                }

                invocationExpression = invocationExpression.AddArgumentListArguments(arg);
            }

            ExpressionSyntax awaitExpression = SyntaxFactory.AwaitExpression(invocationExpression);
            var addedStatement = SyntaxFactory.ExpressionStatement(awaitExpression)
                .WithAdditionalAnnotations(Simplifier.Annotation, Formatter.Annotation);

            // Annotate the rewritten function so it can be found again in the new document's tree.
            var methodAnnotation = new SyntaxAnnotation();
            CSharpSyntaxNode methodSyntax = container.Function.ReplaceNode(statementToReplace, addedStatement)
                .WithAdditionalAnnotations(methodAnnotation);
            Document newDocument = context.Document.WithSyntaxRoot(root.ReplaceNode(container.Function, methodSyntax));
            var newSyntaxRoot = await newDocument.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
            methodSyntax = (CSharpSyntaxNode)newSyntaxRoot.GetAnnotatedNodes(methodAnnotation).Single();
            if (!container.IsAsync)
            {
                switch (methodSyntax)
                {
                    case AnonymousFunctionExpressionSyntax anonFunc:
                    {
                        // Use a local model for the updated document; the captured outer
                        // model still describes the original document and must not be overwritten.
                        var updatedSemanticModel = await newDocument.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
                        methodSyntax = Utils.MakeMethodAsync(anonFunc, updatedSemanticModel, cancellationToken);
                        newDocument = newDocument.WithSyntaxRoot(newSyntaxRoot.ReplaceNode(anonFunc, methodSyntax));
                        break;
                    }

                    case MethodDeclarationSyntax methodDecl:
                        (newDocument, methodSyntax) = await Utils.MakeMethodAsync(methodDecl, newDocument, cancellationToken).ConfigureAwait(false);
                        break;
                }
            }

            return newDocument.Project.Solution;
        }
    }
}