/// <summary>
/// Runs the base processing on <paramref name="applyTo"/> and then, when the type takes part in a
/// recursive hierarchy (<c>this.applyTo.IsRecursiveParentOrDerivative</c>), threads a lookup-table
/// argument through the generated type:
/// 1. appends a lookup-table parameter to the type's single constructor;
/// 2. for types other than the recursive parent itself, forwards that parameter as a named argument
///    to the chained base constructor call;
/// 3. appends a named lookup-table argument to every object-creation expression in the type that
///    instantiates this type.
/// </summary>
/// <param name="applyTo">The generated class declaration to rewrite.</param>
/// <returns>The rewritten class declaration (or the base result unchanged when the type is not part of a recursive hierarchy).</returns>
/// <remarks>
/// <c>Single()</c> throws <see cref="System.InvalidOperationException"/> unless the class declares exactly one constructor.
/// </remarks>
public override ClassDeclarationSyntax ProcessApplyToClassDeclaration(ClassDeclarationSyntax applyTo)
{
    applyTo = base.ProcessApplyToClassDeclaration(applyTo);

    if (this.applyTo.IsRecursiveParentOrDerivative)
    {
        // Add the lookupTable parameter to the constructor's signature.
        // Its type is built by Syntax.OptionalOf(this.lookupTableType) -- presumably Optional<lookupTableType>; helper is defined elsewhere.
        var origCtor = applyTo.Members.OfType<ConstructorDeclarationSyntax>().Single();
        var alteredCtor = origCtor.AddParameterListParameters(SyntaxFactory.Parameter(LookupTableFieldName.Identifier).WithType(Syntax.OptionalOf(this.lookupTableType)));

        // If this type isn't itself the recursive parent then we derive from it. And we must propagate the value to the chained base type.
        if (!this.applyTo.IsRecursiveParent)
        {
            Assumes.NotNull(alteredCtor.Initializer); // we expect a chained call to the base constructor.
            // Pass the new parameter through by name (e.g. `lookupTable: lookupTable`).
            // NOTE(review): NoneToken is assumed to be an empty ref-kind token (no `ref`/`out`) -- confirm at its declaration.
            alteredCtor = alteredCtor.WithInitializer(
                alteredCtor.Initializer.AddArgumentListArguments(
                    SyntaxFactory.Argument(SyntaxFactory.NameColon(LookupTableFieldName), NoneToken, LookupTableFieldName)));
        }

        // Apply the updated constructor back to the generated type.
        applyTo = applyTo.ReplaceNode(origCtor, alteredCtor);

        // Search for invocations of the constructor that we now have to update.
        // Nodes that are not ObjectCreationExpressionSyntax produce a null name and therefore never match.
        // A creation matches when its type text equals either the full type syntax or the bare symbol name.
        var invocations = (
            from n in applyTo.DescendantNodes()
            let ctorInvocation = n as ObjectCreationExpressionSyntax
            let instantiatedTypeName = ctorInvocation?.Type?.ToString()
            where instantiatedTypeName == this.applyTo.TypeSyntax.ToString() || instantiatedTypeName == this.applyTo.TypeSymbol.Name
            select ctorInvocation).ToImmutableArray();

        // Track the matches so each one can still be located after earlier replacements produce new trees.
        var trackedTree = applyTo.TrackNodes(invocations);
        var recursiveField = this.applyTo.RecursiveParent.RecursiveField;
        foreach (var ctorInvocation in invocations)
        {
            var currentInvocation = trackedTree.GetCurrentNode(ctorInvocation);

            // Default: pass `null` for the lookup table.
            ExpressionSyntax lookupTableValue = SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression);

            // Only ordinary methods are inspected; creations inside constructors, property accessors, etc.
            // have no MethodDeclarationSyntax ancestor and keep the null literal.
            var containingMethod = currentInvocation.FirstAncestorOrSelf<MethodDeclarationSyntax>();
            if (containingMethod != null)
            {
                if (containingMethod.ParameterList.Parameters.Any(p => p.Identifier.ToString() == recursiveField.Name))
                {
                    // We're in a method that accepts the recursive field as a parameter.
                    // The value we want to pass in for the lookup table is:
                    // (children.IsDefined && children.Value != this.Children) ? default(Optional<ImmutableDictionary<uint, KeyValuePair<RecursiveType, uint>>>) : Optional.For(this.lookupTable);
                    // i.e. the existing table is reused only when the recursive collection was not replaced.
                    lookupTableValue = SyntaxFactory.ConditionalExpression(
                        SyntaxFactory.ParenthesizedExpression(
                            SyntaxFactory.BinaryExpression(
                                SyntaxKind.LogicalAndExpression,
                                Syntax.OptionalIsDefined(recursiveField.NameAsField),
                                SyntaxFactory.BinaryExpression(
                                    SyntaxKind.NotEqualsExpression,
                                    Syntax.OptionalValue(recursiveField.NameAsField),
                                    Syntax.ThisDot(recursiveField.NameAsProperty)))),
                        SyntaxFactory.DefaultExpression(Syntax.OptionalOf(this.lookupTableType)),
                        Syntax.OptionalFor(Syntax.ThisDot(LookupTableFieldName)));
                }
            }

            // Append `lookupTable: <value>` to the creation expression and swap it into the tracked tree.
            var alteredInvocation = currentInvocation.AddArgumentListArguments(
                SyntaxFactory.Argument(SyntaxFactory.NameColon(LookupTableFieldName), NoneToken, lookupTableValue));
            trackedTree = trackedTree.ReplaceNode(currentInvocation, alteredInvocation);
        }

        applyTo = trackedTree;
    }

    return applyTo;
}
/// <summary>
/// Returns a copy of <paramref name="classOld"/> in which <paramref name="methodOld"/> has had its
/// parameter list replaced by a single parameter of type <paramref name="className"/>, named after
/// that type with its first letter lower-cased (e.g. <c>Foo foo</c>).
/// </summary>
/// <param name="className">Name of the type the method's single parameter should have.</param>
/// <param name="classOld">The class declaration containing <paramref name="methodOld"/>.</param>
/// <param name="methodOld">The method whose parameter list is replaced.</param>
/// <returns>The updated class declaration; the new parameter list and method carry <see cref="Formatter.Annotation"/>.</returns>
private static ClassDeclarationSyntax NewClassFactory(string className, ClassDeclarationSyntax classOld, MethodDeclarationSyntax methodOld)
{
    // Build the parameter as a real type + identifier pair rather than a single identifier token
    // containing "Type name" (which printed correctly but produced a parameter with no Type node).
    var newParameter = SyntaxFactory.Parameter(SyntaxFactory.Identifier(FirstLetteToLower(className)))
        .WithType(SyntaxFactory.ParseTypeName(className));

    var parameters = SyntaxFactory.ParameterList(SyntaxFactory.SingletonSeparatedList(newParameter))
        .WithAdditionalAnnotations(Formatter.Annotation);

    // Derive from the original method so that attributes, type parameters, constraint clauses,
    // explicit interface specifiers, expression bodies and leading trivia (doc comments) survive;
    // rebuilding via SyntaxFactory.MethodDeclaration dropped all of them.
    var newMethod = methodOld
        .WithParameterList(parameters)
        .WithAdditionalAnnotations(Formatter.Annotation);

    return classOld.ReplaceNode(methodOld, newMethod);
}
/// <summary>
/// Produces a copy of <paramref name="classOld"/> where <paramref name="methodOld"/> accepts exactly
/// one parameter: its type is <paramref name="className"/> and its name is <paramref name="className"/>
/// passed through <c>FirstLetteToLower</c>. All other parts of the method are left untouched.
/// </summary>
/// <param name="className">Name of the parameter-object type the method should now take.</param>
/// <param name="classOld">The class declaration that contains <paramref name="methodOld"/>.</param>
/// <param name="methodOld">The method whose parameter list is replaced.</param>
/// <returns>The class declaration with the rewritten method swapped in.</returns>
private static ClassDeclarationSyntax UpdateClassToUseNewParameterClass(string className, ClassDeclarationSyntax classOld, MethodDeclarationSyntax methodOld)
{
    var parameterName = SyntaxFactory.Identifier(FirstLetteToLower(className));
    var parameterType = SyntaxFactory.ParseTypeName(className);

    var singleParameterList = SyntaxFactory.ParameterList(
        SyntaxFactory.SingletonSeparatedList(
            SyntaxFactory.Parameter(parameterName).WithType(parameterType)));

    // Only the parameter list changes; the method is otherwise reused as-is.
    return classOld.ReplaceNode(methodOld, methodOld.WithParameterList(singleParameterList));
}