/// <summary>
/// Builds a <c>public virtual</c> method (named by <c>GetToTypeMethodName</c>, decorated with
/// <c>PureAttributeList</c>) that converts this instance to <paramref name="derivedType"/>.
/// The generated body has this shape:
/// <code>
/// var that = this as DerivedType;
/// if (that != null &amp;&amp; this.GetType().Equals(typeof(DerivedType)))
/// {
///     if (/* every supplied extra field matches that */) return that; // unconditional when there are no extra fields
/// }
/// return DerivedType.CreateWithIdentity(/* this type's fields */, /* identity */, /* extra fields */);
/// </code>
/// </summary>
/// <param name="derivedType">The derived type for which the conversion method is generated.</param>
/// <returns>
/// The method declaration. Its parameters are the fields that <paramref name="derivedType"/> has beyond
/// the applied-to type (see <c>GetFieldsBeyond</c>).
/// </returns>
private MemberDeclarationSyntax CreateToDerivedTypeMethod(MetaType derivedType)
            {
                var derivedTypeName = GetFullyQualifiedSymbolName(derivedType.TypeSymbol);
                var thatLocal = SyntaxFactory.IdentifierName("that");

                // var that = this as DerivedType;
                var thisAsDerived = SyntaxFactory.BinaryExpression(
                    SyntaxKind.AsExpression,
                    SyntaxFactory.ThisExpression(),
                    derivedTypeName);
                var thatDeclarator = SyntaxFactory.VariableDeclarator(thatLocal.Identifier)
                    .WithInitializer(SyntaxFactory.EqualsValueClause(thisAsDerived));
                var declareThat = SyntaxFactory.LocalDeclarationStatement(
                    SyntaxFactory.VariableDeclaration(varType, SyntaxFactory.SingletonSeparatedList(thatDeclarator)));

                // this.GetType().Equals(typeof(DerivedType)) -- an exact runtime-type match, not a subtype check.
                var getTypeAccess = SyntaxFactory.MemberAccessExpression(
                    SyntaxKind.SimpleMemberAccessExpression,
                    SyntaxFactory.ThisExpression(),
                    SyntaxFactory.IdentifierName("GetType"));
                var thisDotGetType = SyntaxFactory.InvocationExpression(getTypeAccess, SyntaxFactory.ArgumentList());
                var typeofDerivedArgument = SyntaxFactory.Argument(SyntaxFactory.TypeOfExpression(derivedTypeName));
                var exactTypeMatch = SyntaxFactory.InvocationExpression(
                    SyntaxFactory.MemberAccessExpression(
                        SyntaxKind.SimpleMemberAccessExpression,
                        thisDotGetType,
                        SyntaxFactory.IdentifierName(nameof(Type.Equals))),
                    SyntaxFactory.ArgumentList(SyntaxFactory.SingletonSeparatedList(typeofDerivedArgument)));

                var extraFields = derivedType.GetFieldsBeyond(this.generator.applyToMetaType);
                var reuseThatStatements = new List<StatementSyntax>();
                if (!extraFields.Any())
                {
                    // Nothing extra to compare: an exact-type instance is returned as-is.
                    reuseThatStatements.Add(SyntaxFactory.ReturnStatement(thatLocal));
                }
                else
                {
                    // One parenthesized "argument leaves this field unchanged" test per extra field.
                    Func<MetaField, ExpressionSyntax> isUnchanged = field =>
                    {
                        ExpressionSyntax comparison;
                        if (field.IsRequired)
                        {
                            // field == that.Field
                            comparison = SyntaxFactory.BinaryExpression(
                                SyntaxKind.EqualsExpression,
                                field.NameAsField,
                                SyntaxFactory.MemberAccessExpression(
                                    SyntaxKind.SimpleMemberAccessExpression,
                                    thatLocal,
                                    field.NameAsProperty));
                        }
                        else
                        {
                            // !field.IsDefined || field.Value == that.Field
                            var notSupplied = SyntaxFactory.PrefixUnaryExpression(
                                SyntaxKind.LogicalNotExpression,
                                Syntax.OptionalIsDefined(field.NameAsField));
                            var suppliedButEqual = SyntaxFactory.BinaryExpression(
                                SyntaxKind.EqualsExpression,
                                Syntax.OptionalValue(field.NameAsField),
                                SyntaxFactory.MemberAccessExpression(
                                    SyntaxKind.SimpleMemberAccessExpression,
                                    thatLocal,
                                    field.NameAsProperty));
                            comparison = SyntaxFactory.BinaryExpression(
                                SyntaxKind.LogicalOrExpression,
                                notSupplied,
                                suppliedButEqual);
                        }

                        return SyntaxFactory.ParenthesizedExpression(comparison);
                    };

                    var everyFieldUnchanged = extraFields.Select(isUnchanged).ChainBinaryExpressions(SyntaxKind.LogicalAndExpression);
                    reuseThatStatements.Add(SyntaxFactory.IfStatement(
                        everyFieldUnchanged,
                        SyntaxFactory.ReturnStatement(thatLocal)));
                }

                // if (that != null && this.GetType().Equals(typeof(DerivedType))) { ... }
                var thatIsNotNull = SyntaxFactory.BinaryExpression(
                    SyntaxKind.NotEqualsExpression,
                    thatLocal,
                    SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression));
                var reuseThatIfPossible = SyntaxFactory.IfStatement(
                    SyntaxFactory.BinaryExpression(SyntaxKind.LogicalAndExpression, thatIsNotNull, exactTypeMatch),
                    SyntaxFactory.Block(reuseThatStatements));

                // return DerivedType.CreateWithIdentity(<this type's fields>, <identity>, <extra fields>);
                var createWithIdentity = SyntaxFactory.MemberAccessExpression(
                    SyntaxKind.SimpleMemberAccessExpression,
                    derivedTypeName,
                    CreateWithIdentityMethodName);
                var createArguments = this.generator.CreateArgumentList(this.generator.applyToMetaType.AllFields, asOptional: OptionalStyle.WhenNotRequired)
                    .AddArguments(RequiredIdentityArgumentFromProperty)
                    .AddArguments(this.generator.CreateArgumentList(extraFields, ArgSource.Argument).Arguments.ToArray());
                var returnFreshInstance = SyntaxFactory.ReturnStatement(
                    SyntaxFactory.InvocationExpression(createWithIdentity, createArguments));

                var methodBody = new List<StatementSyntax> { declareThat, reuseThatIfPossible, returnFreshInstance };

                return SyntaxFactory.MethodDeclaration(derivedTypeName, GetToTypeMethodName(derivedType.TypeSymbol.Name).Identifier)
                    .AddModifiers(SyntaxFactory.Token(SyntaxKind.PublicKeyword), SyntaxFactory.Token(SyntaxKind.VirtualKeyword))
                    .AddAttributeLists(PureAttributeList)
                    .WithParameterList(this.generator.CreateParameterList(extraFields, ParameterStyle.OptionalOrRequired))
                    .WithBody(SyntaxFactory.Block(methodBody));
            }
Example #2
0
            /// <summary>
            /// Applies the base processing, then, when the applied-to type participates in a recursive
            /// parent/derivative relationship, appends a lookup-table parameter to its meaningful constructor
            /// and supplies a matching named argument at every site in the class that invokes that constructor.
            /// </summary>
            /// <param name="applyTo">The generated class declaration to process.</param>
            /// <returns>The processed class declaration.</returns>
            public override ClassDeclarationSyntax ProcessApplyToClassDeclaration(ClassDeclarationSyntax applyTo)
            {
                applyTo = base.ProcessApplyToClassDeclaration(applyTo);
                if (!this.applyTo.IsRecursiveParentOrDerivative)
                {
                    return applyTo;
                }

                // Append the lookupTable parameter (typed via Syntax.OptionalOf(lookupTableType)) to the constructor.
                var originalCtor = GetMeaningfulConstructor(applyTo);
                var lookupTableParameter = SyntaxFactory.Parameter(LookupTableFieldName.Identifier)
                    .WithType(Syntax.OptionalOf(this.lookupTableType));
                var rewrittenCtor = originalCtor.AddParameterListParameters(lookupTableParameter);

                // A type that is not itself the recursive parent derives from it, so the value must also be
                // forwarded through the chained base-constructor call.
                if (!this.applyTo.IsRecursiveParent)
                {
                    Assumes.NotNull(rewrittenCtor.Initializer); // a chained call to the base constructor is expected.
                    var forwardedLookupTable = SyntaxFactory.Argument(
                        SyntaxFactory.NameColon(LookupTableFieldName),
                        NoneToken,
                        LookupTableFieldName);
                    rewrittenCtor = rewrittenCtor.WithInitializer(
                        rewrittenCtor.Initializer.AddArgumentListArguments(forwardedLookupTable));
                }

                applyTo = applyTo.ReplaceNode(originalCtor, rewrittenCtor);

                // Every `new T(...)` whose type text matches the applied-to type's syntax or simple name.
                var creationInvocations = applyTo.DescendantNodes()
                    .OfType<ObjectCreationExpressionSyntax>()
                    .Where(creation =>
                    {
                        var instantiatedTypeName = creation.Type?.ToString();
                        return instantiatedTypeName == this.applyTo.TypeSyntax.ToString()
                            || instantiatedTypeName == this.applyTo.TypeSymbol.Name;
                    })
                    .ToImmutableArray();

                // Every `: this(...)` initializer found on a constructor of the applied-to type.
                var chainedInvocations = applyTo.DescendantNodes()
                    .OfType<ConstructorInitializerSyntax>()
                    .Where(chained => chained.IsKind(SyntaxKind.ThisConstructorInitializer)
                        && chained.FirstAncestorOrSelf<ConstructorDeclarationSyntax>().Identifier.ValueText == this.applyTo.TypeSymbol.Name)
                    .ToImmutableArray();

                var invocations = creationInvocations.Concat<CSharpSyntaxNode>(chainedInvocations);
                var trackedTree = applyTo.TrackNodes(invocations);

                var recursiveField = this.applyTo.RecursiveParent.RecursiveField;
                foreach (var invocation in invocations)
                {
                    var current = trackedTree.GetCurrentNode(invocation);
                    var currentCreation = current as ObjectCreationExpressionSyntax;

                    // Pass null, except for `new T(...)` inside a method that takes the recursive field as a parameter.
                    ExpressionSyntax lookupTableValue = SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression);
                    var enclosingMethod = currentCreation != null
                        ? current.FirstAncestorOrSelf<MethodDeclarationSyntax>()
                        : null;
                    var methodTakesRecursiveField = enclosingMethod != null
                        && enclosingMethod.ParameterList.Parameters.Any(p => p.Identifier.ToString() == recursiveField.Name);
                    if (methodTakesRecursiveField)
                    {
                        // (field.IsDefined && field.Value != this.Field)
                        //     ? default(Optional<LookupTable>)
                        //     : Optional.For(this.lookupTable)
                        lookupTableValue = SyntaxFactory.ConditionalExpression(
                            SyntaxFactory.ParenthesizedExpression(
                                SyntaxFactory.BinaryExpression(
                                    SyntaxKind.LogicalAndExpression,
                                    Syntax.OptionalIsDefined(recursiveField.NameAsField),
                                    SyntaxFactory.BinaryExpression(
                                        SyntaxKind.NotEqualsExpression,
                                        Syntax.OptionalValue(recursiveField.NameAsField),
                                        Syntax.ThisDot(recursiveField.NameAsProperty)))),
                            SyntaxFactory.DefaultExpression(Syntax.OptionalOf(this.lookupTableType)),
                            Syntax.OptionalFor(Syntax.ThisDot(LookupTableFieldName)));
                    }

                    var lookupTableArgument = SyntaxFactory.Argument(
                        SyntaxFactory.NameColon(LookupTableFieldName),
                        NoneToken,
                        lookupTableValue);
                    CSharpSyntaxNode replacement = currentCreation != null
                        ? (CSharpSyntaxNode)currentCreation.AddArgumentListArguments(lookupTableArgument)
                        : (current as ConstructorInitializerSyntax).AddArgumentListArguments(lookupTableArgument);
                    trackedTree = trackedTree.ReplaceNode(current, replacement);
                }

                return trackedTree;
            }
            /// <summary>
            /// Builds one public read/write property for each local field of the applied-to type, in
            /// <c>LocalFields</c> order.
            /// </summary>
            /// <remarks>
            /// For a field of a generated immutable type, the getter first fills the backing field from
            /// <c>this.immutable.field?.ToBuilder()</c> when it is not yet defined, then returns its value.
            /// Every other getter returns the backing field directly. Each setter assigns <c>value</c> and then
            /// invokes <c>OnPropertyChangedMethodName</c> with no arguments. That step is guarded by a change check
            /// when one can be expressed: <c>!this.field.IsDefined || this.field.Value != value</c> for generated
            /// immutable types, or <c>this.field != value</c> when <c>HasEqualityOperators</c> holds. Otherwise it runs
            /// unconditionally.
            /// </remarks>
            /// <returns>The generated property declarations.</returns>
            protected IReadOnlyList<MemberDeclarationSyntax> CreateMutableProperties()
            {
                var properties = new List<PropertyDeclarationSyntax>();

                foreach (var field in this.generator.applyToMetaType.LocalFields)
                {
                    var isGeneratedImmutable = field.IsGeneratedImmutableType;
                    var thisField = Syntax.ThisDot(field.NameAsField);
                    var fieldNotYetDefined = SyntaxFactory.PrefixUnaryExpression(
                        SyntaxKind.LogicalNotExpression,
                        Syntax.OptionalIsDefined(thisField));

                    BlockSyntax getterBlock;
                    if (isGeneratedImmutable)
                    {
                        // this.field = this.immutable.field?.ToBuilder();
                        var immutableFieldAccess = SyntaxFactory.MemberAccessExpression(
                            SyntaxKind.SimpleMemberAccessExpression,
                            Syntax.ThisDot(ImmutableFieldName),
                            field.NameAsField);
                        var toBuilderCall = SyntaxFactory.InvocationExpression(
                            SyntaxFactory.MemberBindingExpression(ToBuilderMethodName),
                            SyntaxFactory.ArgumentList());
                        var populateFromImmutable = SyntaxFactory.ExpressionStatement(
                            SyntaxFactory.AssignmentExpression(
                                SyntaxKind.SimpleAssignmentExpression,
                                thisField,
                                SyntaxFactory.ConditionalAccessExpression(immutableFieldAccess, toBuilderCall)));

                        // if (!this.field.IsDefined) { ... } return this.field.Value;
                        getterBlock = SyntaxFactory.Block(
                            SyntaxFactory.IfStatement(fieldNotYetDefined, SyntaxFactory.Block(populateFromImmutable)),
                            SyntaxFactory.ReturnStatement(Syntax.OptionalValue(thisField)));
                    }
                    else
                    {
                        // return this.field;
                        getterBlock = SyntaxFactory.Block(SyntaxFactory.ReturnStatement(thisField));
                    }

                    var valueKeyword = SyntaxFactory.IdentifierName("value");

                    // Left null when no change check can be expressed; the setter then always assigns and notifies.
                    BinaryExpressionSyntax changeCheck = null;
                    if (isGeneratedImmutable)
                    {
                        // !this.field.IsDefined || this.field.Value != value
                        changeCheck = SyntaxFactory.BinaryExpression(
                            SyntaxKind.LogicalOrExpression,
                            fieldNotYetDefined,
                            SyntaxFactory.BinaryExpression(
                                SyntaxKind.NotEqualsExpression,
                                Syntax.OptionalValue(thisField),
                                valueKeyword));
                    }
                    else if (HasEqualityOperators(field.Symbol.Type))
                    {
                        // this.field != value
                        changeCheck = SyntaxFactory.BinaryExpression(
                            SyntaxKind.NotEqualsExpression,
                            thisField,
                            valueKeyword);
                    }

                    // this.field = value; this.OnPropertyChanged();
                    var assignment = SyntaxFactory.ExpressionStatement(
                        SyntaxFactory.AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, thisField, valueKeyword));
                    var notification = SyntaxFactory.ExpressionStatement(
                        SyntaxFactory.InvocationExpression(
                            SyntaxFactory.MemberAccessExpression(
                                SyntaxKind.SimpleMemberAccessExpression,
                                SyntaxFactory.ThisExpression(),
                                OnPropertyChangedMethodName)));
                    var assignAndNotify = SyntaxFactory.Block(assignment, notification);

                    BlockSyntax setterBlock;
                    if (changeCheck == null)
                    {
                        setterBlock = assignAndNotify;
                    }
                    else
                    {
                        setterBlock = SyntaxFactory.Block(SyntaxFactory.IfStatement(changeCheck, assignAndNotify));
                    }

                    var propertyType = this.GetPropertyTypeForBuilder(field);
                    var propertyName = field.Name.ToPascalCase();
                    var getAccessor = SyntaxFactory.AccessorDeclaration(SyntaxKind.GetAccessorDeclaration, getterBlock);
                    var setAccessor = SyntaxFactory.AccessorDeclaration(SyntaxKind.SetAccessorDeclaration, setterBlock);

                    properties.Add(
                        SyntaxFactory.PropertyDeclaration(propertyType, propertyName)
                            .AddModifiers(SyntaxFactory.Token(SyntaxKind.PublicKeyword))
                            .WithAccessorList(SyntaxFactory.AccessorList(SyntaxFactory.List(new[] { getAccessor, setAccessor }))));
                }

                return properties;
            }