/// <summary>
/// Builds the struct definitions for the custom types discovered while processing a shader.
/// </summary>
/// <param name="diagnostics">The builder receiving the produced <see cref="Diagnostic"/> instances.</param>
/// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
/// <param name="types">The sequence of discovered custom types.</param>
/// <returns>The name and definition of each custom type to add to the shader source.</returns>
private static ImmutableArray<(string Name, string Definition)> GetDeclaredTypes(
    ImmutableArray<Diagnostic>.Builder diagnostics,
    INamedTypeSymbol structDeclarationSymbol,
    IEnumerable<INamedTypeSymbol> types)
{
    ImmutableArray<(string, string)>.Builder declarations = ImmutableArray.CreateBuilder<(string, string)>();

    // Split the discovered types into the ones to declare and the ones to report as invalid
    var customTypes = HlslKnownTypes.GetCustomTypes(types, out IReadOnlyCollection<INamedTypeSymbol> invalidTypes);

    foreach (var customType in customTypes)
    {
        var hlslStructName = customType.GetFullMetadataName().ToHlslIdentifierName();
        var declaration = StructDeclaration(hlslStructName);

        // Only instance fields become members of the generated struct
        foreach (var field in customType.GetMembers().OfType<IFieldSymbol>().Where(member => !member.IsStatic))
        {
            var fieldType = (INamedTypeSymbol)field.Type;
            string fieldTypeMetadataName = fieldType.GetFullMetadataName();

            // Use the mapped type name when one exists, otherwise derive an HLSL identifier from the metadata name
            string hlslTypeName = HlslKnownTypes.TryGetMappedName(fieldTypeMetadataName, out string? knownTypeName)
                ? knownTypeName!
                : fieldTypeMetadataName.ToHlslIdentifierName();

            // Use the mapped field name when one exists, otherwise keep the original field name
            string hlslFieldName = HlslKnownKeywords.TryGetMappedName(field.Name, out string? remappedName)
                ? remappedName!
                : field.Name;

            var variable = VariableDeclaration(IdentifierName(hlslTypeName))
                .AddVariables(VariableDeclarator(Identifier(hlslFieldName)));

            declaration = declaration.AddMembers(FieldDeclaration(variable));
        }

        // The trailing ; is attached after normalization, so it lands right after the closing bracket
        string definition = declaration
            .NormalizeWhitespace(eol: "\n")
            .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
            .ToFullString();

        declarations.Add((hlslStructName, definition));
    }

    // Emit one diagnostic per invalid type, located on the shader type itself
    foreach (INamedTypeSymbol invalidType in invalidTypes)
    {
        diagnostics.Add(InvalidDiscoveredType, structDeclarationSymbol, structDeclarationSymbol, invalidType);
    }

    return declarations.ToImmutable();
}
/// <summary>
/// Collects the static, non-constant fields declared in a shader type, along with their mapped names.
/// </summary>
/// <param name="diagnostics">The builder receiving the produced <see cref="Diagnostic"/> instances.</param>
/// <param name="semanticModel">The <see cref="SemanticModelProvider"/> instance for the type to process.</param>
/// <param name="structDeclaration">The <see cref="StructDeclarationSyntax"/> instance for the current type.</param>
/// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
/// <param name="discoveredTypes">The collection of currently discovered types.</param>
/// <param name="constantDefinitions">The collection of discovered constant definitions.</param>
/// <returns>
/// The name, type declaration and (possibly <see langword="null"/>) assignment
/// for each static field in <paramref name="structDeclarationSymbol"/>.
/// </returns>
private static ImmutableArray<(string Name, string TypeDeclaration, string? Assignment)> GetStaticFields(
    ImmutableArray<Diagnostic>.Builder diagnostics,
    SemanticModelProvider semanticModel,
    StructDeclarationSyntax structDeclaration,
    INamedTypeSymbol structDeclarationSymbol,
    ICollection<INamedTypeSymbol> discoveredTypes,
    IDictionary<IFieldSymbol, string> constantDefinitions)
{
    ImmutableArray<(string, string, string?)>.Builder staticFields = ImmutableArray.CreateBuilder<(string, string, string?)>();

    foreach (var declaration in structDeclaration.Members.OfType<FieldDeclarationSyntax>())
    {
        foreach (var declarator in declaration.Declaration.Variables)
        {
            var fieldSymbol = (IFieldSymbol)semanticModel.For(declarator).GetDeclaredSymbol(declarator)!;

            // Instance fields and compile-time constants are skipped by this method
            bool isStaticNonConstant = fieldSymbol.IsStatic && !fieldSymbol.IsConst;

            if (!isStaticNonConstant)
            {
                continue;
            }

            if (fieldSymbol.Type is INamedTypeSymbol typeSymbol &&
                HlslKnownTypes.IsKnownHlslType(typeSymbol.GetFullMetadataName()))
            {
                // The mapped name is null when there is no mapping, in which case the field name is used as is
                _ = HlslKnownKeywords.TryGetMappedName(fieldSymbol.Name, out string? mappedName);

                // Readonly fields are declared as "static const", all others as just "static"
                string hlslTypeName = HlslKnownTypes.GetMappedName(typeSymbol);
                string typeDeclaration = fieldSymbol.IsReadOnly
                    ? $"static const {hlslTypeName}"
                    : $"static {hlslTypeName}";

                var rewriter = new StaticFieldRewriter(
                    semanticModel,
                    discoveredTypes,
                    constantDefinitions,
                    diagnostics);

                string? assignment = rewriter.Visit(declarator)?.NormalizeWhitespace(eol: "\n").ToFullString();

                staticFields.Add((mappedName ?? fieldSymbol.Name, typeDeclaration, assignment));
            }
            else
            {
                // The field type is either not a named type or not a known HLSL type
                diagnostics.Add(InvalidShaderStaticFieldType, declarator, structDeclarationSymbol, fieldSymbol.Name, fieldSymbol.Type);
            }
        }
    }

    return staticFields.ToImmutable();
}
/// <summary>
/// Loads an additional static method used by the shader, along with the static members and methods it captures
/// </summary>
/// <param name="name">The HLSL name of the new method to load</param>
/// <param name="methodInfo">The <see cref="MethodInfo"/> instance for the method to load</param>
private void LoadStaticMethodSource(string name, MethodInfo methodInfo)
{
    // Decompile the target method
    MethodDecompiler.Instance.GetSyntaxTree(methodInfo, MethodType.Static, out MethodDeclarationSyntax root, out SemanticModel semanticModel);

    // Rewrite the method
    var rewriter = new ShaderSyntaxRewriter(semanticModel, methodInfo.DeclaringType);
    root = (MethodDeclarationSyntax)rewriter.Visit(root);

    // Register the captured static members
    foreach (var capturedMember in rewriter.StaticMembers)
    {
        LoadFieldInfo(capturedMember.Value, capturedMember.Key);
    }

    // Register the captured static methods (each one is loaded recursively)
    foreach (var capturedMethod in rewriter.StaticMethods)
    {
        LoadStaticMethodSource(capturedMethod.Key, capturedMethod.Value);
    }

    // Describe each parameter of the rewritten method, flagging the last one
    var syntaxParameters = root.ParameterList.Parameters;
    int lastIndex = syntaxParameters.Count - 1;
    IReadOnlyList<ParameterInfo> parameters = syntaxParameters
        .Select((node, index) => new ParameterInfo(
            node.Modifiers,
            node.Type.ToFullString(),
            node.Identifier.ToFullString(),
            index == lastIndex))
        .ToArray();

    // Drop the f/F/d/D suffix from digit runs that follow a non-word character,
    // then trim the trailing whitespace and apply the keyword mapping
    string rawBody = root.Body.ToFullString();
    string unsuffixedBody = Regex.Replace(rawBody, @"(?<=\W)(\d+)[fFdD]", match => match.Groups[1].Value);
    string body = HlslKnownKeywords.GetMappedText(unsuffixedBody.TrimEnd('\n', '\r', ' '));

    // Gather the reflection-based description of the original method
    string fullName = $"{methodInfo.DeclaringType.FullName}{Type.Delimiter}{methodInfo.Name}";
    string originalParameters = string.Join(", ", methodInfo.GetParameters().Select(p => $"{p.ParameterType.ToFriendlyString()} {p.Name}"));

    // Get the final function info instance
    _FunctionsList.Add(new FunctionInfo(
        methodInfo.ReturnType,
        fullName,
        originalParameters,
        root.ReturnType.ToString(),
        name,
        parameters,
        body));
}
/// <summary>
/// Collects the instance fields of a shader type, along with their mapped names and HLSL types.
/// </summary>
/// <param name="diagnostics">The builder receiving the produced <see cref="Diagnostic"/> instances.</param>
/// <param name="structDeclarationSymbol">The input <see cref="INamedTypeSymbol"/> instance to process.</param>
/// <param name="types">The collection of currently discovered types, which custom struct types are added to.</param>
/// <returns>The name and HLSL type of each valid instance field in <paramref name="structDeclarationSymbol"/>.</returns>
private static ImmutableArray<(string Name, string HlslType)> GetInstanceFields(
    ImmutableArray<Diagnostic>.Builder diagnostics,
    INamedTypeSymbol structDeclarationSymbol,
    ICollection<INamedTypeSymbol> types)
{
    ImmutableArray<(string, string)>.Builder capturedFields = ImmutableArray.CreateBuilder<(string, string)>();

    var instanceFields = structDeclarationSymbol.GetMembers().OfType<IFieldSymbol>().Where(field => !field.IsStatic);

    foreach (var fieldSymbol in instanceFields)
    {
        // Captured fields must be named type symbols
        if (fieldSymbol.Type is not INamedTypeSymbol typeSymbol)
        {
            diagnostics.Add(InvalidShaderField, fieldSymbol, structDeclarationSymbol, fieldSymbol.Name, fieldSymbol.Type);

            continue;
        }

        string metadataName = typeSymbol.GetFullMetadataName();
        string hlslTypeName = HlslKnownTypes.GetMappedName(typeSymbol);

        // The mapped name is null when there is no mapping, in which case the field name is used as is
        _ = HlslKnownKeywords.TryGetMappedName(fieldSymbol.Name, out string? mappedName);

        // Allowed fields must be unmanaged values
        if (!typeSymbol.IsUnmanagedType)
        {
            diagnostics.Add(InvalidShaderField, fieldSymbol, structDeclarationSymbol, fieldSymbol.Name, typeSymbol);

            continue;
        }

        // Track the type if it's not one of the known HLSL types
        if (!HlslKnownTypes.IsKnownHlslType(metadataName))
        {
            types.Add(typeSymbol);
        }

        capturedFields.Add((mappedName ?? fieldSymbol.Name, hlslTypeName));
    }

    return capturedFields.ToImmutable();
}
/// <summary>
/// Loads the entry method for the current shader being loaded, including its local functions
/// and the static members and methods it captures
/// </summary>
private void LoadMethodSource()
{
    // Drops the f/F/d/D suffix from digit runs that follow a non-word character
    string StripNumericSuffixes(string text)
    {
        return Regex.Replace(text, @"(?<=\W)(\d+)[fFdD]", m => m.Groups[1].Value);
    }

    // Decompile the shader method
    MethodDecompiler.Instance.GetSyntaxTree(Action.Method, MethodType.Closure, out MethodDeclarationSyntax root, out SemanticModel semanticModel);

    // Rewrite the shader method (eg. to fix the type declarations)
    var rewriter = new ShaderSyntaxRewriter(semanticModel, ShaderType);
    root = (MethodDeclarationSyntax)rewriter.Visit(root);

    // Pull the local functions out of the method body, they are tracked separately
    var localFunctions = root.DescendantNodes().OfType<LocalFunctionStatementSyntax>().ToArray();
    root = root.RemoveNodes(localFunctions, SyntaxRemoveOptions.KeepNoTrivia);

    foreach (var localFunction in localFunctions)
    {
        string trimmedSource = localFunction.ToFullString().RemoveLeftPadding().Trim(' ', '\r', '\n');
        string mappedSource = HlslKnownKeywords.GetMappedText(StripNumericSuffixes(trimmedSource));

        _LocalFunctionsList.Add(new LocalFunctionInfo(mappedSource));
    }

    // Register the captured static members
    foreach (var capturedMember in rewriter.StaticMembers)
    {
        LoadFieldInfo(capturedMember.Value, capturedMember.Key);
    }

    // Register the captured static methods
    foreach (var capturedMethod in rewriter.StaticMethods)
    {
        LoadStaticMethodSource(capturedMethod.Key, capturedMethod.Value);
    }

    // The thread ids identifier is the name of the first parameter of the shader method
    ThreadsIdsVariableName = root.ParameterList.Parameters.First().Identifier.Text;
    MethodBody = root.Body.ToFullString();

    // Additional preprocessing: strip literal suffixes, trim the end, then apply the keyword mapping
    MethodBody = StripNumericSuffixes(MethodBody);
    MethodBody = MethodBody.TrimEnd('\n', '\r', ' ');
    MethodBody = HlslKnownKeywords.GetMappedText(MethodBody);
}
/// <summary>
/// Loads a specified <see cref="ReadableMember"/> and adds it to the shader model
/// </summary>
/// <param name="memberInfo">The target <see cref="ReadableMember"/> to load</param>
/// <param name="name">The optional explicit name to use for the field</param>
/// <param name="parents">The list of parent fields to reach the current <see cref="ReadableMember"/> from a given <see cref="Action{T}"/></param>
/// <remarks>
/// Members whose type matches none of the handled categories (buffers, known scalar or vector
/// types, captured scopes, supported delegates) are ignored without any error being raised.
/// </remarks>
private void LoadFieldInfo(ReadableMember memberInfo, string? name = null, IReadOnlyList<ReadableMember>? parents = null)
{
    Type fieldType = memberInfo.MemberType;
    string fieldName = HlslKnownKeywords.GetMappedName(name ?? memberInfo.Name);

    if (HlslKnownTypes.IsConstantBufferType(fieldType))
    {
        // Root parameter for a constant buffer
        DescriptorRanges.Add(new DescriptorRange1(DescriptorRangeType.ConstantBufferView, 1, _ConstantBuffersCount));

        // Track the buffer field
        memberInfo.Parents = parents;
        _CapturedMembers.Add(memberInfo);

        // Unlike the other buffer kinds, the mapped name is taken from the first generic type argument
        string typeName = HlslKnownTypes.GetMappedName(fieldType.GenericTypeArguments[0]);
        _BuffersList.Add(new ConstantBufferFieldInfo(fieldType, typeName, fieldName, _ConstantBuffersCount++));
    }
    else if (HlslKnownTypes.IsReadOnlyBufferType(fieldType))
    {
        // Root parameter for a readonly buffer
        DescriptorRanges.Add(new DescriptorRange1(DescriptorRangeType.ShaderResourceView, 1, _ReadOnlyBuffersCount));

        // Track the buffer field
        memberInfo.Parents = parents;
        _CapturedMembers.Add(memberInfo);

        string typeName = HlslKnownTypes.GetMappedName(fieldType);
        _BuffersList.Add(new ReadOnlyBufferFieldInfo(fieldType, typeName, fieldName, _ReadOnlyBuffersCount++));
    }
    else if (HlslKnownTypes.IsReadWriteBufferType(fieldType))
    {
        // Root parameter for a read write buffer
        DescriptorRanges.Add(new DescriptorRange1(DescriptorRangeType.UnorderedAccessView, 1, _ReadWriteBuffersCount));

        // Track the buffer field
        memberInfo.Parents = parents;
        _CapturedMembers.Add(memberInfo);

        string typeName = HlslKnownTypes.GetMappedName(fieldType);
        _BuffersList.Add(new ReadWriteBufferFieldInfo(fieldType, typeName, fieldName, _ReadWriteBuffersCount++));
    }
    else if (HlslKnownTypes.IsKnownScalarType(fieldType) || HlslKnownTypes.IsKnownVectorType(fieldType))
    {
        // Register the captured field
        memberInfo.Parents = parents;
        _CapturedMembers.Add(memberInfo);

        string typeName = HlslKnownTypes.GetMappedName(fieldType);
        _FieldsList.Add(new CapturedFieldInfo(fieldType, typeName, fieldName));
    }
    else if (fieldType.IsClass && fieldName.StartsWith("CS$<>", StringComparison.Ordinal))
    {
        // The prefix is matched ordinally: it is an identifier prefix, not linguistic text,
        // so the result must not depend on the current culture

        // Captured scope, update the parents list (on a copy, so the caller's list is left untouched)
        List<ReadableMember> updatedParents = parents?.ToList() ?? new List<ReadableMember>();
        updatedParents.Add(memberInfo);

        // Recurse on the fields of the nested class, using their own names
        foreach (FieldInfo fieldInfo in fieldType.GetFields())
        {
            LoadFieldInfo(fieldInfo, null, updatedParents);
        }
    }
    else if (fieldType.IsDelegate() &&
             memberInfo.GetValue(Action.Target) is Delegate func &&
             (func.Method.IsStatic || func.Method.DeclaringType.IsStatelessDelegateContainer()) &&
             (HlslKnownTypes.IsKnownScalarType(func.Method.ReturnType) || HlslKnownTypes.IsKnownVectorType(func.Method.ReturnType)) &&
             fieldType.GenericTypeArguments.All(type => HlslKnownTypes.IsKnownScalarType(type) || HlslKnownTypes.IsKnownVectorType(type)))
    {
        // Captured delegates targeting a static method (or a stateless delegate container),
        // with a known return type and only known generic type arguments
        LoadStaticMethodSource(fieldName, func.Method);
    }
}