Example #1
0
        /// <summary>
        /// Attempts to build a <see cref="SourceTypeInfo"/> for the type declared by
        /// <paramref name="syntax"/>.
        /// </summary>
        /// <param name="context">Generator context; supplies the compilation and the cancellation token.</param>
        /// <param name="syntax">Type declaration to inspect.</param>
        /// <param name="knownTypes">Lookup used to find the equatable, comparable and compare-by attributes.</param>
        /// <param name="options">
        /// Options to attach to the result. When <see langword="null"/>, a new
        /// <see cref="GenerateOptions"/> is created from the context, the syntax and the attribute found.
        /// </param>
        /// <param name="diagnostics">
        /// Receives every diagnostic collected during the call; it is assigned on all
        /// return paths (it is not assigned when an argument check throws).
        /// </param>
        /// <returns>
        /// The created <see cref="SourceTypeInfo"/>, or <see langword="null"/> when the declaration
        /// does not resolve to a named type, carries neither attribute, is not partial, is static,
        /// has a non-partial enclosing type, has a compare-by attribute on a member that is neither
        /// a property nor a field, or has no compare-by members at all.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="syntax"/> or <paramref name="knownTypes"/> is <see langword="null"/>.
        /// </exception>
        public static SourceTypeInfo? TryCreate(
            GeneratorExecutionContext context,
            TypeDeclarationSyntax syntax,
            KnownTypes knownTypes,
            GenerateOptions? options,
            out ImmutableArray<Diagnostic> diagnostics)
        {
            if (syntax is null)
            {
                throw new ArgumentNullException(nameof(syntax));
            }

            if (knownTypes is null)
            {
                throw new ArgumentNullException(nameof(knownTypes));
            }

            var collected = ImmutableArray.CreateBuilder<Diagnostic>();

            try
            {
                var model    = context.Compilation.GetSemanticModel(syntax.SyntaxTree);
                var declared = ModelExtensions.GetDeclaredSymbol(model, syntax, context.CancellationToken);

                if (declared is not INamedTypeSymbol typeSymbol)
                {
                    // TODO: Diagnostics ?
                    return null;
                }

                string typeName       = typeSymbol.GetFullName();
                var    typeLocation   = syntax.GetLocation();
                var    equatableAttr  = knownTypes.GetEquatableAttribute(typeSymbol);
                var    comparableAttr = knownTypes.GetComparableAttribute(typeSymbol);

                // The comparable attribute wins; the equatable one is only the fallback.
                // TODO: Diagnostic (and CodeFix) for the case where both attributes are present.
                var attribute = comparableAttr ?? equatableAttr;

                if (attribute is null)
                {
                    return null;
                }

                var modifiers = syntax.Modifiers;

                if (!modifiers.Any(SyntaxKind.PartialKeyword))
                {
                    collected.Add(DiagnosticFactory.TypeIsNotPartial(typeName, typeLocation));
                    return null;
                }

                if (modifiers.Any(SyntaxKind.StaticKeyword))
                {
                    collected.Add(DiagnosticFactory.TypeIsStatic(typeName, typeLocation));
                    return null;
                }

                // Walk outwards through the containing types; every declaration of each one
                // has to carry the partial modifier.
                var outerTypes = new List<INamedTypeSymbol>();

                for (var outer = typeSymbol.ContainingType; outer is not null; outer = outer.ContainingType)
                {
                    var nonPartialLocations = outer.DeclaringSyntaxReferences
                        .Select(reference => reference.GetSyntax(context.CancellationToken))
                        .OfType<TypeDeclarationSyntax>()
                        .Where(declaration => !declaration.Modifiers.Any(SyntaxKind.PartialKeyword))
                        .Select(declaration => declaration.GetLocation())
                        .ToArray();

                    if (nonPartialLocations.Length > 0)
                    {
                        collected.Add(
                            DiagnosticFactory.TypeIsNotPartial(
                                outer.GetFullName(),
                                nonPartialLocations[0],
                                nonPartialLocations.Skip(1)));

                        return null;
                    }

                    outerTypes.Add(outer);
                }

                // Collected innermost-first; flip to outermost-first.
                outerTypes.Reverse();

                var comparedMembers = new List<SourceMemberInfo>();

                foreach (var candidate in typeSymbol.GetMembers())
                {
                    var compareBy = knownTypes.GetCompareByAttribute(candidate);

                    if (compareBy is null)
                    {
                        continue;
                    }

                    switch (candidate)
                    {
                        case IPropertySymbol property:
                            comparedMembers.Add(new SourceMemberInfo(property, compareBy));
                            break;

                        case IFieldSymbol field:
                            comparedMembers.Add(new SourceMemberInfo(field, compareBy));
                            break;

                        default:
                            // TODO: Diagnostic?
                            return null;
                    }
                }

                if (comparedMembers.Count == 0)
                {
                    collected.Add(DiagnosticFactory.NoMembers(typeName, typeLocation));
                    return null;
                }

                options ??= new GenerateOptions(context, syntax, attribute);

                var nullableContext = model.GetNullableContext(syntax.SpanStart);

                var orderedMembers = comparedMembers
                    .OrderBy(m => m.ComparisonOrder)
                    .ThenBy(m => m.Name)
                    .ToArray();

                return new SourceTypeInfo(
                    syntax,
                    typeLocation,
                    typeSymbol,
                    equatableAttr is not null,
                    comparableAttr is not null,
                    outerTypes,
                    orderedMembers,
                    nullableContext,
                    options,
                    knownTypes);
            }
            finally
            {
                // Publish whatever was collected, regardless of which path returned.
                diagnostics = collected.ToImmutable();
            }
        }
        /// <summary>
        /// Generator entry point. Builds a <see cref="SourceTypeInfo"/> for every candidate
        /// syntax collected by the <see cref="SyntaxReceiver"/>, reports the diagnostics
        /// produced while doing so, checks the members of comparable types, and finally
        /// emits the generated source files for each accepted type.
        /// </summary>
        /// <param name="context">The execution context supplied to the generator.</param>
        public void Execute(
            GeneratorExecutionContext context)
        {
            LaunchDebugger(context);

            if (context.SyntaxReceiver is not SyntaxReceiver receiver)
            {
                return;
            }

            var knownTypes     = new KnownTypes(context.Compilation);
            var candidateTypes = new List<SourceTypeInfo>();

            // Several candidate syntaxes can resolve to the same type symbol (for example
            // more than one partial declaration of one type). Keeping only the first
            // prevents ToDictionary below from throwing ArgumentException on a duplicate
            // key and prevents the same type being generated twice under identical file names.
            var seenSymbols = new HashSet<ISymbol>(SymbolEqualityComparer.Default);

            foreach (var candidateSyntax in receiver.CandidateSyntaxes)
            {
                var sourceTypeInfo = SourceTypeInfo.TryCreate(
                    context,
                    candidateSyntax,
                    knownTypes,
                    this._options,
                    out var diagnostics);

                context.ReportDiagnostics(diagnostics);

                if (sourceTypeInfo is null)
                {
                    continue;
                }

                if (!seenSymbols.Add(sourceTypeInfo.TypeSymbol))
                {
                    continue;
                }

                candidateTypes.Add(sourceTypeInfo);
            }

            var typeMap = candidateTypes.ToDictionary(
                x => x.TypeSymbol,
                (IEqualityComparer<ITypeSymbol>)SymbolEqualityComparer.Default);

            foreach (var candidateType in candidateTypes)
            {
                if (!candidateType.HasComparableAttribute)
                {
                    continue;
                }

                var options = candidateType.GenerateOptions;

                // Member types only matter when some comparison output is requested.
                // The answer is the same for every member, so decide once per type.
                bool generatesComparison =
                    options.GenerateGenericComparable ||
                    options.GenerateNonGenericComparable ||
                    options.GenerateOperators ||
                    options.GenerateStructuralComparable;

                if (!generatesComparison)
                {
                    continue;
                }

                foreach (var member in candidateType.Members)
                {
                    var memberType = member.Type;

                    if (knownTypes.IsGenericComparable(memberType))
                    {
                        continue;
                    }

                    if (knownTypes.IsNonGenericComparable(memberType))
                    {
                        continue;
                    }

                    if (knownTypes.IsStructuralComparable(memberType))
                    {
                        continue;
                    }

                    // The member type is itself one of the types being generated as comparable.
                    if (typeMap.TryGetValue(memberType, out var targetType) &&
                        targetType.HasComparableAttribute)
                    {
                        continue;
                    }

                    // NOTE(review): a member type for which TryGetNullableUnderlyingType fails
                    // is skipped here without any diagnostic, so MemberIsNotComparable is only
                    // ever reported for types that unwrap successfully. That looks like an
                    // inverted guard - confirm whether it is intentional before changing it.
                    if (!knownTypes.TryGetNullableUnderlyingType(memberType, out var underlyingType))
                    {
                        continue;
                    }

                    if (typeMap.TryGetValue(underlyingType, out var underlyingTarget) &&
                        underlyingTarget.HasComparableAttribute)
                    {
                        continue;
                    }

                    var memberLocations = member.Symbol.DeclaringSyntaxReferences
                        .Select(x => Location.Create(x.SyntaxTree, x.Span))
                        .ToArray();

                    // Do not index into an empty array when the symbol has no syntax reference.
                    var primaryLocation = memberLocations.Length > 0
                        ? memberLocations[0]
                        : Location.None;

                    context.ReportDiagnostic(
                        DiagnosticFactory.MemberIsNotComparable(
                            candidateType.FullName,
                            member.Name,
                            member.TypeName,
                            primaryLocation,
                            memberLocations.Skip(1)));
                }
            }

            foreach (var sourceType in candidateTypes)
            {
                var    options  = sourceType.GenerateOptions;
                string fullName = sourceType.FullName;

                GenerateCode(
                    context,
                    new CommonGenerator(sourceType),
                    $"{fullName}_Common.cs");

                if (options.GenerateEqualityContract &&
                    !sourceType.HasEqualityContract &&
                    !sourceType.IsValueType)
                {
                    GenerateCode(
                        context,
                        new EqualityContractGenerator(sourceType),
                        $"{fullName}_EqualityContract.cs");
                }

                if (options.GenerateEquatable &&
                    !sourceType.IsEquatable)
                {
                    GenerateCode(
                        context,
                        new EquatableGenerator(sourceType),
                        $"{fullName}_Equatable.cs");
                }

                if (options.GenerateGenericComparable &&
                    !sourceType.IsGenericComparable)
                {
                    GenerateCode(
                        context,
                        new GenericComparableGenerator(sourceType),
                        $"{fullName}_GenericComparable.cs");
                }

                if (options.GenerateNonGenericComparable &&
                    !sourceType.IsNonGenericComparable)
                {
                    GenerateCode(
                        context,
                        new NonGenericComparableGenerator(sourceType),
                        $"{fullName}_NonGenericComparable.cs");
                }

                if (options.OverrideObjectMethods)
                {
                    if (!sourceType.OverridesObjectEquals)
                    {
                        GenerateCode(
                            context,
                            new ObjectEqualsGenerator(sourceType),
                            $"{fullName}_ObjectEquals.cs");
                    }

                    if (!sourceType.OverridesObjectGetHashCode)
                    {
                        GenerateCode(
                            context,
                            new ObjectGetHashCodeGenerator(sourceType),
                            $"{fullName}_ObjectGetHashCode.cs");
                    }
                }

                // Operators are skipped entirely when the type already defines one on the
                // nullable form of its own type.
                if (options.GenerateOperators &&
                    !sourceType.DefinedNullableParameterOperator)
                {
                    GenerateCode(
                        context,
                        new EqualityOperatorsGenerator(sourceType),
                        $"{fullName}_EqualityOperators.cs");

                    if (sourceType.HasComparableAttribute)
                    {
                        GenerateCode(
                            context,
                            new ComparisonOperatorsGenerator(sourceType),
                            $"{fullName}_ComparisonOperators.cs");
                    }
                }
            }
        }
Example #3
0
        /// <summary>
        /// Records, for <paramref name="typeSymbol"/>, the results of the
        /// <see cref="KnownTypes"/> equatable / comparable checks, whether it overrides
        /// <c>object.Equals(object)</c> and <c>object.GetHashCode()</c>, whether it already
        /// defines an equality or comparison operator on its own (nullable and, for value
        /// types, non-nullable) type, and whether it declares a virtual
        /// <c>EqualityContract</c> property.
        /// </summary>
        /// <param name="typeSymbol">The type to describe.</param>
        /// <param name="knownTypes">Source of the well-known symbols and interface checks.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="typeSymbol"/> or <paramref name="knownTypes"/> is <see langword="null"/>.
        /// </exception>
        public BaseTypeInfo(
            INamedTypeSymbol typeSymbol,
            KnownTypes knownTypes)
        {
            if (typeSymbol is null)
            {
                throw new ArgumentNullException(nameof(typeSymbol));
            }

            if (knownTypes is null)
            {
                throw new ArgumentNullException(nameof(knownTypes));
            }

            var symbolComparer = SymbolEqualityComparer.Default;

            this.TypeSymbol = typeSymbol;
            this.FullName   = typeSymbol.GetFullName();

            this.IsEquatable            = knownTypes.IsEquatable(typeSymbol);
            this.IsGenericComparable    = knownTypes.IsGenericComparable(typeSymbol);
            this.IsNonGenericComparable = knownTypes.IsNonGenericComparable(typeSymbol);
            this.IsStructuralEquatable  = knownTypes.IsStructuralEquatable(typeSymbol);
            this.IsStructuralComparable = knownTypes.IsStructuralComparable(typeSymbol);

            var objectType = knownTypes.Object;

            // object.Equals(object): the single-parameter overload whose parameter is object.
            var equalsMethod = objectType
                .GetMembers(nameof(object.Equals))
                .OfType<IMethodSymbol>()
                .Single(method =>
                    method.Parameters.Length == 1 &&
                    symbolComparer.Equals(method.Parameters[0].Type, objectType));

            this.OverridesObjectEquals =
                typeSymbol.GetOverrideSymbol(equalsMethod!, symbolComparer) is not null;

            // object.GetHashCode(): the parameterless overload.
            var getHashCodeMethod = objectType
                .GetMembers(nameof(object.GetHashCode))
                .OfType<IMethodSymbol>()
                .Single(method => method.Parameters.Length == 0);

            this.OverridesObjectGetHashCode =
                typeSymbol.GetOverrideSymbol(getHashCodeMethod!, symbolComparer) is not null;

            var relevantOperatorNames = new[]
            {
                "op_Equality",
                "op_Inequality",
                "op_LessThan",
                "op_GreaterThan",
                "op_LessThanOrEqual",
                "op_GreaterThanOrEqual"
            };

            // Binary user-defined operators restricted to the equality / comparison set.
            var comparisonOperators = typeSymbol
                .GetMembers()
                .OfType<IMethodSymbol>()
                .Where(method =>
                    method.MethodKind == MethodKind.UserDefinedOperator &&
                    method.Parameters.Length == 2 &&
                    relevantOperatorNames.Contains(method.Name, StringComparer.Ordinal));

            bool isValueType = typeSymbol.IsValueType;

            // For value types the "nullable" form is whatever MakeNullable returns;
            // otherwise it is the type itself.
            INamedTypeSymbol nullableType = isValueType
                ? knownTypes.MakeNullable(typeSymbol)
                : typeSymbol;

            foreach (var op in comparisonOperators)
            {
                var left  = op.Parameters[0].Type;
                var right = op.Parameters[1].Type;

                if (!this.DefinedNullableParameterOperator)
                {
                    this.DefinedNullableParameterOperator =
                        symbolComparer.Equals(left, nullableType) &&
                        symbolComparer.Equals(right, nullableType);
                }

                if (isValueType && !this.DefinedNonNullableParameterOperator)
                {
                    this.DefinedNonNullableParameterOperator =
                        symbolComparer.Equals(left, typeSymbol) &&
                        symbolComparer.Equals(right, typeSymbol);
                }
            }

            this.HasEqualityContract = typeSymbol
                .GetMembers("EqualityContract")
                .OfType<IPropertySymbol>()
                .Any(property =>
                    property.IsVirtual &&
                    symbolComparer.Equals(property.Type, knownTypes.Type));
        }