/// <summary>
/// Registers code fixes that change the return type of the enclosing method or local function
/// to <c>Task</c> and/or <c>Task&lt;T&gt;</c> (where <c>T</c> is the current return type),
/// depending on what <c>AnalyzeAwaitExpressions</c> reports for the declaration.
/// </summary>
/// <param name="context">The code fix context supplied by the host.</param>
public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
    if (!Settings.IsCodeFixEnabled(CodeFixIdentifiers.ChangeMethodReturnType))
    {
        return;
    }

    SyntaxNode root = await context.GetSyntaxRootAsync().ConfigureAwait(false);

    if (!TryFindFirstAncestorOrSelf(root, context.Span, out SyntaxNode node, predicate: f => f.IsKind(SyntaxKind.MethodDeclaration, SyntaxKind.LocalFunctionStatement)))
    {
        return;
    }

    Diagnostic diagnostic = context.Diagnostics[0];

    SemanticModel semanticModel = await context.GetSemanticModelAsync().ConfigureAwait(false);

    var methodSymbol = (IMethodSymbol)semanticModel.GetDeclaredSymbol(node, context.CancellationToken);

    Debug.Assert(methodSymbol != null, node.Kind().ToString());

    // Debug.Assert is removed from release builds; bail out rather than dereference null.
    if (methodSymbol == null)
    {
        return;
    }

    ITypeSymbol returnType = methodSymbol.ReturnType;

    if (returnType.IsErrorType())
    {
        return;
    }

    (bool containsReturnAwait, bool containsAwaitStatement) = AnalyzeAwaitExpressions(node);

    Debug.Assert(containsAwaitStatement || containsReturnAwait, node.ToString());

    if (containsAwaitStatement)
    {
        INamedTypeSymbol taskSymbol = semanticModel.GetTypeByMetadataName(MetadataNames.System_Threading_Tasks_Task);

        // The lookup yields null when the type cannot be resolved in the compilation;
        // no meaningful fix can be offered in that case.
        if (taskSymbol != null)
        {
            CodeFixRegistrator.ChangeReturnType(context, diagnostic, node, taskSymbol, semanticModel, "Task");
        }
    }

    if (containsReturnAwait)
    {
        INamedTypeSymbol taskOfTSymbol = semanticModel.GetTypeByMetadataName(MetadataNames.System_Threading_Tasks_Task_T);

        if (taskOfTSymbol != null)
        {
            // Wrap the current return type: T -> Task<T>.
            ITypeSymbol newReturnType = taskOfTSymbol.Construct(returnType);

            CodeFixRegistrator.ChangeReturnType(context, diagnostic, node, newReturnType, semanticModel, "TaskOfT");
        }
    }
}
/// <summary>
/// For each <c>BodyCannotBeIteratorBlockBecauseTypeIsNotIteratorInterfaceType</c> diagnostic,
/// registers one "change return type" fix per distinct named type found among the
/// <c>yield return</c> expressions of the enclosing method or local function body.
/// Each fix offers <c>IEnumerable&lt;T&gt;</c> constructed over that type.
/// </summary>
/// <param name="context">The code fix context supplied by the host.</param>
public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
    if (!Settings.IsCodeFixEnabled(CodeFixIdentifiers.ChangeMethodReturnType))
    {
        return;
    }

    SyntaxNode root = await context.GetSyntaxRootAsync().ConfigureAwait(false);

    if (!TryFindFirstAncestorOrSelf(root, context.Span, out SyntaxNode node, predicate: f => f.IsKind(SyntaxKind.MethodDeclaration, SyntaxKind.LocalFunctionStatement)))
    {
        return;
    }

    foreach (Diagnostic diagnostic in context.Diagnostics)
    {
        switch (diagnostic.Id)
        {
            case CompilerDiagnosticIdentifiers.BodyCannotBeIteratorBlockBecauseTypeIsNotIteratorInterfaceType:
                {
                    // The ChangeMethodReturnType setting was already checked on entry.
                    BlockSyntax body = (node is MethodDeclarationSyntax methodDeclaration)
                        ? methodDeclaration.Body
                        : ((LocalFunctionStatementSyntax)node).Body;

                    Debug.Assert(body != null, node.ToString());

                    // No block body means there are no yield statements to inspect.
                    if (body == null)
                    {
                        break;
                    }

                    SemanticModel semanticModel = await context.GetSemanticModelAsync().ConfigureAwait(false);

                    // Resolved lazily, only once a candidate element type has been found.
                    INamedTypeSymbol ienumerableOfTSymbol = null;

                    // Element types already offered, so each distinct type produces a single fix.
                    var registeredTypes = new HashSet<ITypeSymbol>();

                    // Do not descend into nested functions (per CSharpFacts.IsFunction);
                    // their yield statements do not belong to this body.
                    foreach (SyntaxNode descendant in body.DescendantNodes(descendIntoChildren: f => !CSharpFacts.IsFunction(f.Kind())))
                    {
                        if (!descendant.IsKind(SyntaxKind.YieldReturnStatement))
                        {
                            continue;
                        }

                        ExpressionSyntax expression = ((YieldStatementSyntax)descendant).Expression;

                        if (expression == null)
                        {
                            continue;
                        }

                        var namedTypeSymbol = semanticModel.GetTypeSymbol(expression, context.CancellationToken) as INamedTypeSymbol;

                        // Skip expressions whose type is not a named type or is an error type.
                        if (namedTypeSymbol?.IsErrorType() != false)
                        {
                            continue;
                        }

                        if (!registeredTypes.Add(namedTypeSymbol))
                        {
                            continue;
                        }

                        if (ienumerableOfTSymbol == null)
                        {
                            ienumerableOfTSymbol = semanticModel.GetTypeByMetadataName("System.Collections.Generic.IEnumerable`1");

                            // Unresolvable IEnumerable<T>: no fix can be constructed, stop scanning.
                            if (ienumerableOfTSymbol == null)
                            {
                                break;
                            }
                        }

                        CodeFixRegistrator.ChangeReturnType(
                            context,
                            diagnostic,
                            node,
                            ienumerableOfTSymbol.Construct(namedTypeSymbol),
                            semanticModel,
                            additionalKey: SymbolDisplay.ToMinimalDisplayString(namedTypeSymbol, semanticModel, node.SpanStart, SymbolDisplayFormats.Default));
                    }

                    break;
                }
        }
    }
}