/// <summary>
/// Extends the declaration of <c>_constructorCandidate.Constructor</c> with <c>_missingParameters</c>
/// and the statements produced by <c>CreateAssignStatements</c>, then returns the solution in which the
/// document declaring that constructor has been updated.
/// </summary>
protected override async Task<Solution?> GetChangedSolutionAsync(CancellationToken cancellationToken)
            {
                // Resolve the syntax of the constructor to extend. Only the first declaration reported by
                // the declaration service is used; First() throws if there is none.
                var declarationService = _document.GetRequiredLanguageService<ISymbolDeclarationService>();
                var constructor = declarationService
                    .GetDeclarations(_constructorCandidate.Constructor)
                    .First()
                    .GetSyntax(cancellationToken);

                var codeGenerator = _document.GetRequiredLanguageService<ICodeGenerationService>();
                var options = await CodeGenerationOptions.FromDocumentAsync(CodeGenerationContext.Default, _document, cancellationToken).ConfigureAwait(false);

                // Add the missing parameters first, then the assignments that store them; mark the
                // result for formatting.
                var newConstructor = codeGenerator.AddParameters(constructor, _missingParameters, options, cancellationToken);
                newConstructor = codeGenerator
                    .AddStatements(newConstructor, CreateAssignStatements(_constructorCandidate), options, cancellationToken)
                    .WithAdditionalAnnotations(Formatter.Annotation);

                var syntaxTree = constructor.SyntaxTree;
                var newRoot = syntaxTree.GetRoot(cancellationToken).ReplaceNode(constructor, newConstructor);

                // The constructor may be declared in a different document than _document, so update the
                // document that owns the constructor's syntax tree.
                var constructorDocument = _document.Project.GetDocument(syntaxTree);
                Contract.ThrowIfNull(constructorDocument);

                return constructorDocument.WithSyntaxRoot(newRoot).Project.Solution;
            }
Example #2
0
            protected override async Task<Document> GetChangedDocumentAsync(CancellationToken cancellationToken)
            {
                // Gather the generated members in the order they will be handed to the code generator.
                using var _ = ArrayBuilder<IMethodSymbol>.GetInstance(out var generatedMembers);

                if (_generateEquals)
                {
                    var equalsMethod = await CreateEqualsMethodAsync(cancellationToken).ConfigureAwait(false);
                    generatedMembers.Add(equalsMethod);
                }

                // When non-null, this constructed type is added to the type's interface list below and a
                // matching strongly-typed Equals is generated for it.
                var interfaceToImplement = await GetConstructedTypeToImplementAsync(cancellationToken).ConfigureAwait(false);
                if (interfaceToImplement is not null)
                {
                    var typedEqualsMethod = await CreateIEquatableEqualsMethodAsync(interfaceToImplement, cancellationToken).ConfigureAwait(false);
                    generatedMembers.Add(typedEqualsMethod);
                }

                if (_generateGetHashCode)
                {
                    var getHashCodeMethod = await CreateGetHashCodeMethodAsync(cancellationToken).ConfigureAwait(false);
                    generatedMembers.Add(getHashCodeMethod);
                }

                if (_generateOperators)
                {
                    // Operators are appended straight into the shared builder.
                    await AddOperatorsAsync(generatedMembers, cancellationToken).ConfigureAwait(false);
                }

                var codeGenerator = _document.GetRequiredLanguageService<ICodeGenerationService>();
                var codeGenOptions = await CodeGenerationOptions.FromDocumentAsync(CodeGenerationContext.Default, _document, cancellationToken).ConfigureAwait(false);
                var formattingOptions = await SyntaxFormattingOptions.FromDocumentAsync(_document, cancellationToken).ConfigureAwait(false);

                var updatedTypeDeclaration = codeGenerator.AddMembers(_typeDeclaration, generatedMembers, codeGenOptions, cancellationToken);
                if (interfaceToImplement is not null)
                {
                    var syntaxGenerator = _document.GetRequiredLanguageService<SyntaxGenerator>();
                    var interfaceTypeSyntax = syntaxGenerator.TypeExpression(interfaceToImplement);
                    updatedTypeDeclaration = syntaxGenerator.AddInterfaceType(updatedTypeDeclaration, interfaceTypeSyntax);
                }

                var documentWithMembers = await UpdateDocumentAndAddImportsAsync(
                    _typeDeclaration, updatedTypeDeclaration, cancellationToken).ConfigureAwait(false);

                // Final formatting pass is delegated to the equals/GetHashCode service.
                var formattingService = _document.GetRequiredLanguageService<IGenerateEqualsAndGetHashCodeService>();
                return await formattingService.FormatDocumentAsync(
                    documentWithMembers, formattingOptions, cancellationToken).ConfigureAwait(false);
            }
Example #3
0
        /// <summary>
        /// Generates a declaration for <paramref name="newType"/> and inserts it directly before the node
        /// tagged with <c>symbolMapping.TypeNodeAnnotation</c> in <paramref name="document"/>.
        /// </summary>
        /// <returns>
        /// The updated document, plus a fresh annotation attached to the inserted declaration so callers
        /// can locate it in that document.
        /// </returns>
        public static async Task<(Document containingDocument, SyntaxAnnotation typeAnnotation)> AddTypeToExistingFileAsync(Document document, INamedTypeSymbol newType, AnnotatedSymbolMapping symbolMapping, CancellationToken cancellationToken)
        {
            var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

            // Exactly one node is expected to carry the annotation; Single() throws otherwise.
            var anchorDeclaration = root.GetAnnotatedNodes(symbolMapping.TypeNodeAnnotation).Single();
            var editor = new SyntaxEditor(root, symbolMapping.AnnotatedSolution.Workspace.Services);

            var codeGenOptions = await CodeGenerationOptions.FromDocumentAsync(
                new CodeGenerationContext(generateMethodBodies: true),
                document,
                cancellationToken).ConfigureAwait(false);

            var codeGenService = document.GetRequiredLanguageService<ICodeGenerationService>();
            var generatedType = codeGenService
                .CreateNamedTypeDeclaration(newType, CodeGenerationDestination.Unspecified, codeGenOptions, cancellationToken)
                .WithAdditionalAnnotations(SimplificationHelpers.SimplifyModuleNameAnnotation);

            // Tag the new declaration so it can be found again in the returned document.
            var typeAnnotation = new SyntaxAnnotation();
            generatedType = generatedType.WithAdditionalAnnotations(typeAnnotation);

            editor.InsertBefore(anchorDeclaration, generatedType);

            return (document.WithSyntaxRoot(editor.GetChangedRoot()), typeAnnotation);
        }
        /// <summary>
        /// Records on <paramref name="syntaxEditor"/> the edits that make <paramref name="localFunction"/> static.
        /// <list type="bullet">
        /// <item>Each entry of <paramref name="captures"/> becomes an extra parameter of the local function.</item>
        /// <item>Every invocation of the local function in <paramref name="document"/> is given matching extra arguments.</item>
        /// <item>References inside the local function to a capture whose new parameter name differs are renamed.</item>
        /// </list>
        /// The edits are only recorded; this method does not produce the changed root itself.
        /// </summary>
        /// <remarks>
        /// References to the local function that are not plain invocations (for example, creating a delegate)
        /// cannot be fixed. When any are found, the rewritten declaration carries a warning annotation.
        /// </remarks>
        public static async Task MakeLocalFunctionStaticAsync(
            Document document,
            LocalFunctionStatementSyntax localFunction,
            ImmutableArray<ISymbol> captures,
            SyntaxEditor syntaxEditor,
            CancellationToken cancellationToken)
        {
            // The "required" accessors throw immediately with a clear message if the document has no
            // syntax tree / semantic model, instead of deferring a null into a later dereference.
            var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
            var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
            var localFunctionSymbol = semanticModel.GetDeclaredSymbol(localFunction, cancellationToken);

            Contract.ThrowIfNull(localFunctionSymbol, "We should have gotten a method symbol for a local function.");

            // All reference searches below are restricted to this single document.
            var documentImmutableSet = ImmutableHashSet.Create(document);

            // Finds all the call sites of the local function.
            var referencedSymbols = await SymbolFinder.FindReferencesAsync(
                localFunctionSymbol, document.Project.Solution, documentImmutableSet, cancellationToken).ConfigureAwait(false);

            // Collect the references we can fix (invocations); anything else sets shouldWarn.
            var shouldWarn = false;

            using var builderDisposer = ArrayBuilder<InvocationExpressionSyntax>.GetInstance(out var invocations);

            foreach (var referencedSymbol in referencedSymbols)
            {
                foreach (var location in referencedSymbol.Locations)
                {
                    // We limited the search scope to the single document,
                    // so all references should be in the same tree.
                    var referenceNode = root.FindNode(location.Location.SourceSpan);
                    if (referenceNode is not IdentifierNameSyntax identifierNode)
                    {
                        // Unexpected scenario, skip and warn.
                        shouldWarn = true;
                        continue;
                    }

                    if (identifierNode.Parent is InvocationExpressionSyntax invocation)
                    {
                        invocations.Add(invocation);
                    }
                    else
                    {
                        // We won't be able to fix non-invocation references,
                        // e.g. creating a delegate.
                        shouldWarn = true;
                    }
                }
            }

            var parameterAndCapturedSymbols = CreateParameterSymbols(captures);

            // Parameter count of the declaration as written (before new parameters are added); invariant
            // across all invocations, so compute it once.
            var originalParameterCount = localFunction.ParameterList.Parameters.Count;

            // Fix all invocations by passing in additional arguments.
            foreach (var invocation in invocations)
            {
                syntaxEditor.ReplaceNode(
                    invocation,
                    (node, generator) =>
                    {
                        var currentInvocation = (InvocationExpressionSyntax)node;
                        var currentArguments = currentInvocation.ArgumentList.Arguments;

                        // If the call already uses a named argument, or omits some existing (optional)
                        // parameters, positional arguments appended at the end would not bind to the new
                        // parameters, so name them instead.
                        var seenNamedArgument = currentArguments.Any(a => a.NameColon != null);
                        var seenDefaultArgumentValue = currentArguments.Count < originalParameterCount;
                        var useNamedArguments = seenNamedArgument || seenDefaultArgumentValue;

                        var newArguments = parameterAndCapturedSymbols.Select(
                            p => (ArgumentSyntax)generator.Argument(
                                useNamedArguments ? p.symbol.Name : null,
                                p.symbol.RefKind,
                                p.capture.Name.ToIdentifierName()));

                        var newArgList = currentInvocation.ArgumentList.WithArguments(currentArguments.AddRange(newArguments));
                        return currentInvocation.WithArgumentList(newArgList);
                    });
            }

            // A generated parameter name can differ from its captured variable's name (e.g. when the
            // capture isn't camel-cased); rename references to such captures inside the local function.
            foreach (var (parameter, capture) in parameterAndCapturedSymbols)
            {
                if (parameter.Name == capture.Name)
                {
                    continue;
                }

                var referencedCaptureSymbols = await SymbolFinder.FindReferencesAsync(
                    capture, document.Project.Solution, documentImmutableSet, cancellationToken).ConfigureAwait(false);

                foreach (var referencedSymbol in referencedCaptureSymbols)
                {
                    foreach (var location in referencedSymbol.Locations)
                    {
                        var referenceSpan = location.Location.SourceSpan;
                        if (!localFunction.FullSpan.Contains(referenceSpan))
                        {
                            continue;
                        }

                        var referenceNode = root.FindNode(referenceSpan);
                        if (referenceNode is IdentifierNameSyntax identifierNode)
                        {
                            syntaxEditor.ReplaceNode(
                                identifierNode,
                                (node, generator) => generator.IdentifierName(parameter.Name.ToIdentifierToken()).WithTriviaFrom(node));
                        }
                    }
                }
            }

            var codeGenerator = document.GetRequiredLanguageService<ICodeGenerationService>();
            var options = await CodeGenerationOptions.FromDocumentAsync(CodeGenerationContext.Default, document, cancellationToken).ConfigureAwait(false);

            // Updates the local function declaration: append the captured variables as parameters,
            // attach the warning annotation if some reference could not be fixed, then apply
            // AddStaticModifier.
            syntaxEditor.ReplaceNode(
                localFunction,
                (node, generator) =>
                {
                    var localFunctionWithNewParameters = codeGenerator.AddParameters(
                        node,
                        parameterAndCapturedSymbols.SelectAsArray(p => p.symbol),
                        options,
                        cancellationToken);

                    if (shouldWarn)
                    {
                        var annotation = WarningAnnotation.Create(CSharpCodeFixesResources.Warning_colon_Adding_parameters_to_local_function_declaration_may_produce_invalid_code);
                        localFunctionWithNewParameters = localFunctionWithNewParameters.WithAdditionalAnnotations(annotation);
                    }

                    return AddStaticModifier(localFunctionWithNewParameters, CSharpSyntaxGenerator.Instance);
                });
        }
        /// <summary>
        /// Replaces the anonymous object found at <paramref name="span"/>, and matching anonymous types in
        /// the same containing member, with a newly generated named type. The type is a record when
        /// <paramref name="isRecord"/> is true and a class otherwise. The new type is inserted into the
        /// nearest enclosing namespace declaration, or the root if there is none.
        /// </summary>
        private async Task<Document> ConvertAsync(Document document, TextSpan span, bool isRecord, CancellationToken cancellationToken)
        {
            var (anonymousObject, anonymousType) = await TryGetAnonymousObjectAsync(document, span, cancellationToken).ConfigureAwait(false);

            Debug.Assert(anonymousObject != null);
            Debug.Assert(anonymousType != null);

            var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
            var semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

            // Choose a type name that does not clash with any symbol visible at the start of the span.
            var className = NameGenerator.GenerateUniqueName(
                isRecord ? "NewRecord" : "NewClass",
                candidate => semanticModel.LookupSymbols(span.Start, name: candidate).IsEmpty);

            // Derive the new type's properties from the anonymous type's properties, along with a map
            // from each original property to its new (possibly PascalCased) name.
            var (properties, propertyMap) = GenerateProperties(document, anonymousType);

            // Build the complete named type that will replace instances of the anonymous type.
            var namedTypeSymbol = await GenerateFinalNamedTypeAsync(
                document, className, isRecord, properties, cancellationToken).ConfigureAwait(false);

            var editor = new SyntaxEditor(root, SyntaxGenerator.GetGenerator(document));

            // Scope the rewrites to the enclosing method-level member, or to the anonymous object
            // itself when there is no such member.
            var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>();
            var containingMember = anonymousObject.FirstAncestorOrSelf<SyntaxNode, ISyntaxFactsService>(
                (node, facts) => facts.IsMethodLevelMember(node), syntaxFacts) ?? anonymousObject;

            // Redirect references to the anonymous type's properties to the renamed properties.
            await ReplacePropertyReferencesAsync(
                document, editor, containingMember, propertyMap, cancellationToken).ConfigureAwait(false);

            // Replace each matching anonymous object in the member with construction of the new type.
            await ReplaceMatchingAnonymousTypesAsync(
                document, editor, namedTypeSymbol, containingMember, anonymousObject, anonymousType, cancellationToken).ConfigureAwait(false);

            var codeGenOptions = await CodeGenerationOptions.FromDocumentAsync(
                new CodeGenerationContext(generateMembers: true, sortMembers: false, autoInsertionLocation: false),
                document,
                cancellationToken).ConfigureAwait(false);
            var codeGenService = document.GetRequiredLanguageService<ICodeGenerationService>();

            // Insert the new type into the nearest enclosing namespace declaration, else the root.
            var container = anonymousObject.GetAncestor<TNamespaceDeclarationSyntax>() ?? root;
            editor.ReplaceNode(
                container,
                (currentContainer, _) => codeGenService.AddNamedType(currentContainer, namedTypeSymbol, codeGenOptions, cancellationToken));

            var updatedDocument = document.WithSyntaxRoot(editor.GetChangedRoot());

            // Final formatting goes through the equals/GetHashCode service so the generated methods
            // follow any formatting rules specific to them.
            var equalsAndGetHashCodeService = document.GetRequiredLanguageService<IGenerateEqualsAndGetHashCodeService>();
            return await equalsAndGetHashCodeService.FormatDocumentAsync(updatedDocument, cancellationToken).ConfigureAwait(false);
        }