/// <summary>
/// Initializes a new <see cref="CompilationContext"/> for <paramref name="compilation"/>.
/// It builds the <c>[GeneratedCode]</c> attribute text once and resolves the generator attribute symbols.
/// </summary>
/// <param name="compilation">The compilation the generator runs against.</param>
/// <exception cref="ArgumentNullException"><paramref name="compilation"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">
/// <c>GeneratorIgnoreAttribute</c>, <c>GeneratorBindingsAttribute</c> or <c>GeneratorMappingAttribute</c>
/// cannot be resolved by metadata name in <paramref name="compilation"/>.
/// </exception>
public CompilationContext(Compilation compilation)
{
    Compilation = compilation ?? throw new ArgumentNullException(nameof(compilation));

    // the [GeneratedCode] text is built from the generator assembly's name and informational version
    var generatorAssembly = GetType().Assembly;
    var generatorAssemblyName = generatorAssembly.GetName().Name;
    var generatorAssemblyVersion = AttributeExtensions.GetCustomAttribute<System.Reflection.AssemblyInformationalVersionAttribute>(generatorAssembly)?.InformationalVersion
        ?? string.Empty;
    GeneratedCodeAttribute = $"[GeneratedCode(\"{generatorAssemblyName}\", \"{generatorAssemblyVersion}\")]";

    GeneratorIgnoreAttribute = GetRequiredTypeSymbol(typeof(GeneratorIgnoreAttribute));
    GeneratorBindingsAttribute = GetRequiredTypeSymbol(typeof(GeneratorBindingsAttribute));
    GeneratorMappingAttribute = GetRequiredTypeSymbol(typeof(GeneratorMappingAttribute));

    // Looks up the symbol for an attribute type by its metadata name.
    // It throws when the symbol is not in the compilation.
    // The message format matches the previous per-attribute messages, e.g. "GeneratorIgnoreAttribute symbol not found!".
    INamedTypeSymbol GetRequiredTypeSymbol(Type attributeType)
        => compilation.GetTypeByMetadataName(attributeType.FullName)
            ?? throw new InvalidOperationException($"{attributeType.Name} symbol not found!");
}
/// <summary>
/// Generates the source for the overloads based on the defined extension methods.
/// </summary>
/// <remarks>
/// The source module is scanned for public static types that are not marked to be ignored.
/// Within those, every nested type qualifies if it is not static, is not an interface, and implements an interface
/// named <c>IValueEnumerable</c> or <c>IAsyncValueEnumerable</c>.
/// Such a type receives the collected extension methods of its interfaces, except those it already implements.
/// Each method is emitted in one of two ways:
/// <list type="bullet">
/// <item>as an instance method in a partial declaration of the type;</item>
/// <item>as an extension method in the containing static class, when its name collides with a property of the type.</item>
/// </list>
/// </remarks>
/// <param name="context">The generator execution context. It provides the compilation and receives the generated sources.</param>
/// <param name="collectedExtensionMethods">
/// A dictionary containing the defined extension methods.
/// It is keyed by the metadata name of the original definition of the extended interface.
/// </param>
/// <param name="generatedPath">The path where to serialize the generated code for debugging; <see langword="null"/> to skip writing files.</param>
void GenerateSource(GeneratorExecutionContext context, ImmutableDictionary<string, List<IMethodSymbol>> collectedExtensionMethods, string? generatedPath)
{
    // The [GeneratedCode] attribute text depends only on the generator assembly.
    // It is built once here and used on all generated methods.
    var generatorAssembly = GetType().Assembly;
    var generatorAssemblyName = generatorAssembly.GetName().Name;
    var generatorAssemblyVersion = AttributeExtensions.GetCustomAttribute<System.Reflection.AssemblyInformationalVersionAttribute>(generatorAssembly)?.InformationalVersion
        ?? string.Empty;
    var generatedCodeAttribute = $"[GeneratedCode(\"{generatorAssemblyName}\", \"{generatorAssemblyVersion}\")]";

    // go through all public static types
    // the enumerables are defined inside of these
    var containerClasses = context.Compilation.SourceModule.GlobalNamespace
        .GetAllTypes()
        .Where(type => type.IsStatic && type.IsReferenceType && type.IsPublic() && !type.ShouldIgnore(context.Compilation));

    foreach (var containerClass in containerClasses)
    {
        // get the inner instance types
        // these can be enumerables
        var candidateTypes = containerClass.GetTypeMembers()
            .OfType<INamedTypeSymbol>()
            .Where(type => !(type.IsStatic || type.IsInterface()));

        foreach (var extendedType in candidateTypes)
        {
            // check if it's a value enumerable and keep a reference to the implemented interface
            var valueEnumerableInterface = extendedType.GetAllInterfaces()
                .FirstOrDefault(@interface => @interface.Name == "IValueEnumerable" || @interface.Name == "IAsyncValueEnumerable");
            if (valueEnumerableInterface is null)
            {
                continue;
            }

            // get the types of the enumerable and enumerator from the generic parameters declaration
            var enumerableType = extendedType;
            var enumeratorType = valueEnumerableInterface.TypeArguments[1];

            // get the type mappings from the GeneratorMappingsAttribute, if found.
            var genericsMapping = ImmutableArray.CreateRange(extendedType.GetGenericMappings(context.Compilation));

            // get the name and parameter list of all the instance methods declared in this type
            var implementedInstanceMethods = extendedType.GetMembers()
                .OfType<IMethodSymbol>()
                .Select(method => Tuple.Create(
                    method.Name,
                    ImmutableArray.CreateRange(method.Parameters
                        .Select(parameter => parameter.Type.ToDisplayString()))));

            // Get the extension methods for this type that are declared in the outer static type.
            // The extended type's display string is computed once instead of once per candidate method.
            var extendedTypeDisplayString = extendedType.ToDisplayString();
            var implementedExtensionMethods = containerClass.GetMembers()
                .OfType<IMethodSymbol>()
                .Where(method => method.IsExtensionMethod && method.Parameters[0].Type.ToDisplayString() == extendedTypeDisplayString)
                .Select(method => Tuple.Create(
                    method.Name,
                    ImmutableArray.CreateRange(method.Parameters
                        .Skip(1)
                        .Select(parameter => parameter.Type.ToDisplayString()))));

            // join the two lists together as these are the implemented methods for this type
            // the generated methods will be added to this list
            var implementedMethods = implementedInstanceMethods.Concat(implementedExtensionMethods).ToList();

            // These are the property names of this type.
            // A generated method with the same name cannot be an instance method, so it becomes an extension method.
            // Ordinal comparison matches the previous string == check.
            var propertyNames = new HashSet<string>(
                extendedType.GetMembers().OfType<IPropertySymbol>().Select(property => property.Name),
                StringComparer.Ordinal);

            var instanceMethods = new List<IMethodSymbol>();
            var extensionMethods = new List<IMethodSymbol>();

            // go through all the implemented interfaces so that
            // the overloads are generated based on the extension methods defined for these
            foreach (var implementedInterfaceType in extendedType.AllInterfaces)
            {
                // get the extension methods collected for this interface
                var key = implementedInterfaceType.OriginalDefinition.MetadataName;
                if (!collectedExtensionMethods.TryGetValue(key, out var implementedTypeMethods))
                {
                    continue;
                }

                // check which ones should be generated
                // the method can be already defined by a more performant custom implementation
                foreach (var implementedTypeMethod in implementedTypeMethods)
                {
                    var methodName = implementedTypeMethod.Name;
                    var methodParameters = ImmutableArray.CreateRange(implementedTypeMethod.Parameters
                        .Skip(1)
                        .Select(parameter => parameter.Type.ToDisplayString(genericsMapping)));

                    // skip if already implemented (or already scheduled for generation)
                    if (implementedMethods.Any(method => method.Item1 == methodName && method.Item2.SequenceEqual(methodParameters)))
                    {
                        continue;
                    }

                    if (propertyNames.Contains(methodName))
                    {
                        // collision with a property: this method will be generated as an extension method
                        extensionMethods.Add(implementedTypeMethod);
                    }
                    else
                    {
                        // this method will be generated as an instance method
                        instanceMethods.Add(implementedTypeMethod);
                    }

                    // add to the implemented methods collection
                    implementedMethods.Add(Tuple.Create(methodName, methodParameters));
                }
            }

            // nothing to generate for this type
            if (instanceMethods.Count == 0 && extensionMethods.Count == 0)
            {
                continue;
            }

            // generate the code for the instance methods and extension methods
            using var builder = new CodeBuilder();
            builder.AppendLine("using System;");
            builder.AppendLine("using System.CodeDom.Compiler;");
            builder.AppendLine("using System.Diagnostics;");
            builder.AppendLine("using System.Runtime.CompilerServices;");
            builder.AppendLine();

            using (builder.AppendBlock($"namespace NetFabric.Hyperlinq"))
            {
                // the generator extends the types by adding partial types
                // both the outer and the inner types have to be declared as partial
                using (builder.AppendBlock($"public static partial class {containerClass.Name}"))
                {
                    // generate the instance methods in the inner type
                    if (instanceMethods.Count != 0)
                    {
                        var extendedTypeGenericParameters = string.Empty;
                        if (extendedType.IsGenericType)
                        {
                            // type parameter list followed by a "where" clause for each constrained parameter
                            var parametersDefinition = new StringBuilder();
                            _ = parametersDefinition.Append($"<{extendedType.TypeParameters.Select(parameter => parameter.ToDisplayString()).ToCommaSeparated()}>");
                            foreach (var typeParameter in extendedType.TypeParameters.Where(typeParameter => typeParameter.ConstraintTypes.Length != 0))
                            {
                                _ = parametersDefinition.Append($" where {typeParameter.Name} : {typeParameter.AsConstraintsStrings().ToCommaSeparated()}");
                            }

                            extendedTypeGenericParameters = parametersDefinition.ToString();
                        }

                        var entity = extendedType.IsValueType
                            ? "readonly partial struct" // it's a value type
                            : "partial class"; // it's a reference type

                        using (builder.AppendBlock($"public {entity} {extendedType.Name}{extendedTypeGenericParameters}"))
                        {
                            foreach (var instanceMethod in instanceMethods)
                            {
                                GenerateInstanceMethod(builder, extendedType, instanceMethod, enumerableType, enumeratorType, generatedCodeAttribute, genericsMapping);
                            }
                        }
                    }

                    builder.AppendLine();

                    // generate the extension methods in the outer type
                    foreach (var extensionMethod in extensionMethods)
                    {
                        GenerateExtensionMethod(builder, extendedType, extensionMethod, enumerableType, enumeratorType, generatedCodeAttribute, genericsMapping);
                    }
                }
            }

            var hintName = $"{containerClass.OriginalDefinition.MetadataName}.{extendedType.OriginalDefinition.MetadataName}.cs";
            var source = builder.ToString();
            if (generatedPath is object)
            {
                // debugging aid: the dumped file name keeps the metadata name unchanged
                File.WriteAllText(Path.Combine(generatedPath, hintName), source);
            }

            // generic metadata names carry a backtick arity suffix (e.g. "Foo`1")
            // it is replaced with '.' in the hint name passed to AddSource
            hintName = hintName.Replace('`', '.');
            context.AddSource(hintName, SourceText.From(source, Encoding.UTF8));
        }
    }
}