예제 #1
0
        /// <summary>
        /// Builds the HLSL struct declarations for the custom types discovered while processing a shader.
        /// </summary>
        /// <param name="diagnostics">The collection of produced <see cref="Diagnostic"/> instances.</param>
        /// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
        /// <param name="types">The sequence of discovered custom types.</param>
        /// <returns>
        /// One (name, definition) pair per custom type accepted by <c>HlslKnownTypes.GetCustomTypes</c>, where the
        /// definition is the whitespace-normalized struct source followed by a trailing semicolon.
        /// </returns>
        /// <remarks>
        /// Static fields are skipped. Every type rejected by <c>HlslKnownTypes.GetCustomTypes</c> is reported as an
        /// <c>InvalidDiscoveredType</c> diagnostic instead of being emitted.
        /// </remarks>
        private static ImmutableArray<(string Name, string Definition)> GetDeclaredTypes(
            ImmutableArray<Diagnostic>.Builder diagnostics,
            INamedTypeSymbol structDeclarationSymbol,
            IEnumerable<INamedTypeSymbol> types)
        {
            ImmutableArray<(string, string)>.Builder declarations = ImmutableArray.CreateBuilder<(string, string)>();

            foreach (var customType in HlslKnownTypes.GetCustomTypes(types, out IReadOnlyCollection<INamedTypeSymbol> invalidTypes))
            {
                string structName = customType.GetFullMetadataName().ToHlslIdentifierName();
                var structSyntax = StructDeclaration(structName);

                foreach (IFieldSymbol field in customType.GetMembers().OfType<IFieldSymbol>().Where(member => !member.IsStatic))
                {
                    INamedTypeSymbol fieldType = (INamedTypeSymbol)field.Type;
                    string fieldTypeMetadataName = fieldType.GetFullMetadataName();

                    // Known HLSL types use their mapped name, custom types use a sanitized metadata name
                    string hlslTypeName = HlslKnownTypes.TryGetMappedName(fieldTypeMetadataName, out string? knownTypeName)
                        ? knownTypeName!
                        : fieldTypeMetadataName.ToHlslIdentifierName();

                    // Field names clashing with HLSL keywords are replaced by their mapped identifier
                    string hlslFieldName = HlslKnownKeywords.TryGetMappedName(field.Name, out string? keywordMapping)
                        ? keywordMapping!
                        : field.Name;

                    var variable = VariableDeclaration(IdentifierName(hlslTypeName))
                        .AddVariables(VariableDeclarator(Identifier(hlslFieldName)));

                    structSyntax = structSyntax.AddMembers(FieldDeclaration(variable));
                }

                // Normalize first, then attach the ';' so that it lands right after the closing bracket
                string definition = structSyntax
                    .NormalizeWhitespace(eol: "\n")
                    .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
                    .ToFullString();

                declarations.Add((structName, definition));
            }

            foreach (INamedTypeSymbol invalidType in invalidTypes)
            {
                diagnostics.Add(InvalidDiscoveredType, structDeclarationSymbol, structDeclarationSymbol, invalidType);
            }

            return declarations.ToImmutable();
        }
예제 #2
0
        /// <summary>
        /// Collects the non-constant static fields declared in a shader type, together with their HLSL declarations.
        /// </summary>
        /// <param name="diagnostics">The collection of produced <see cref="Diagnostic"/> instances.</param>
        /// <param name="semanticModel">The <see cref="SemanticModelProvider"/> instance for the type to process.</param>
        /// <param name="structDeclaration">The <see cref="StructDeclarationSyntax"/> instance for the current type.</param>
        /// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
        /// <param name="discoveredTypes">The collection of currently discovered types.</param>
        /// <param name="constantDefinitions">The collection of discovered constant definitions.</param>
        /// <returns>
        /// For each static, non-<c>const</c> field of <paramref name="structDeclarationSymbol"/> with a known HLSL type:
        /// its (possibly keyword-mapped) name, its <c>static</c>/<c>static const</c> type declaration, and the rewritten
        /// declarator text produced by a <c>StaticFieldRewriter</c> (or <see langword="null"/> if the rewriter returned none).
        /// </returns>
        /// <remarks>
        /// Static fields whose type is not a known HLSL type produce an <c>InvalidShaderStaticFieldType</c> diagnostic
        /// and are skipped.
        /// </remarks>
        private static ImmutableArray<(string Name, string TypeDeclaration, string? Assignment)> GetStaticFields(
            ImmutableArray<Diagnostic>.Builder diagnostics,
            SemanticModelProvider semanticModel,
            StructDeclarationSyntax structDeclaration,
            INamedTypeSymbol structDeclarationSymbol,
            ICollection<INamedTypeSymbol> discoveredTypes,
            IDictionary<IFieldSymbol, string> constantDefinitions)
        {
            ImmutableArray<(string, string, string?)>.Builder staticFields = ImmutableArray.CreateBuilder<(string, string, string?)>();

            // Flatten "static float a, b;" style declarations into individual declarators, in source order
            var declarators = structDeclaration.Members
                .OfType<FieldDeclarationSyntax>()
                .SelectMany(field => field.Declaration.Variables);

            foreach (var declarator in declarators)
            {
                IFieldSymbol fieldSymbol = (IFieldSymbol)semanticModel.For(declarator).GetDeclaredSymbol(declarator)!;

                // Only static, non-const fields are handled here
                if (fieldSymbol.IsConst || !fieldSymbol.IsStatic)
                {
                    continue;
                }

                // Static fields must be of a primitive, vector or matrix type
                if (fieldSymbol.Type is not INamedTypeSymbol typeSymbol ||
                    !HlslKnownTypes.IsKnownHlslType(typeSymbol.GetFullMetadataName()))
                {
                    diagnostics.Add(InvalidShaderStaticFieldType, declarator, structDeclarationSymbol, fieldSymbol.Name, fieldSymbol.Type);

                    continue;
                }

                _ = HlslKnownKeywords.TryGetMappedName(fieldSymbol.Name, out string? mapping);

                string hlslType = HlslKnownTypes.GetMappedName(typeSymbol);
                string typeDeclaration = fieldSymbol.IsReadOnly
                    ? $"static const {hlslType}"
                    : $"static {hlslType}";

                StaticFieldRewriter rewriter = new(semanticModel, discoveredTypes, constantDefinitions, diagnostics);

                string? assignment = rewriter.Visit(declarator)?.NormalizeWhitespace(eol: "\n").ToFullString();

                staticFields.Add((mapping ?? fieldSymbol.Name, typeDeclaration, assignment));
            }

            return staticFields.ToImmutable();
        }
예제 #3
0
        /// <summary>
        /// Decompiles a static method used by the shader, converts it to HLSL source and registers it as an additional function.
        /// </summary>
        /// <param name="name">The HLSL name of the new method to load</param>
        /// <param name="methodInfo">The <see cref="MethodInfo"/> instance for the method to load</param>
        /// <remarks>
        /// Static members and static methods captured by the rewriter are loaded (recursively, for methods) before the
        /// resulting function is appended to the functions list.
        /// </remarks>
        private void LoadStaticMethodSource(string name, MethodInfo methodInfo)
        {
            MethodDecompiler.Instance.GetSyntaxTree(methodInfo, MethodType.Static, out MethodDeclarationSyntax decompiled, out SemanticModel semanticModel);

            ShaderSyntaxRewriter rewriter = new ShaderSyntaxRewriter(semanticModel, methodInfo.DeclaringType);
            MethodDeclarationSyntax method = (MethodDeclarationSyntax)rewriter.Visit(decompiled);

            foreach (var staticMember in rewriter.StaticMembers)
            {
                LoadFieldInfo(staticMember.Value, staticMember.Key);
            }

            foreach (var staticMethod in rewriter.StaticMethods)
            {
                LoadStaticMethodSource(staticMethod.Key, staticMethod.Value);
            }

            // Describe each declared parameter, flagging the last one in the list
            var declaredParameters = method.ParameterList.Parameters;
            int lastIndex = declaredParameters.Count - 1;

            IReadOnlyList<ParameterInfo> parameters = declaredParameters
                .Select((parameter, index) => new ParameterInfo(
                    parameter.Modifiers,
                    parameter.Type.ToFullString(),
                    parameter.Identifier.ToFullString(),
                    index == lastIndex))
                .ToArray();

            // Drop the f/F/d/D suffix of numeric literals (eg. 1.5f -> 1.5), trim trailing whitespace, remap keywords
            string unsuffixed = Regex.Replace(method.Body.ToFullString(), @"(?<=\W)(\d+)[fFdD]", m => m.Groups[1].Value);
            string body = HlslKnownKeywords.GetMappedText(unsuffixed.TrimEnd('\n', '\r', ' '));

            string fullName = $"{methodInfo.DeclaringType.FullName}{Type.Delimiter}{methodInfo.Name}";
            string signature = string.Join(", ", methodInfo.GetParameters().Select(p => $"{p.ParameterType.ToFriendlyString()} {p.Name}"));

            _FunctionsList.Add(new FunctionInfo(
                methodInfo.ReturnType,
                fullName,
                signature,
                method.ReturnType.ToString(),
                name,
                parameters,
                body));
        }
예제 #4
0
        /// <summary>
        /// Collects the instance fields captured by a shader type, paired with their HLSL type names.
        /// </summary>
        /// <param name="diagnostics">The collection of produced <see cref="Diagnostic"/> instances.</param>
        /// <param name="structDeclarationSymbol">The input <see cref="INamedTypeSymbol"/> instance to process.</param>
        /// <param name="types">The collection of currently discovered types; custom struct field types are added to it.</param>
        /// <returns>A sequence of captured fields in <paramref name="structDeclarationSymbol"/>.</returns>
        /// <remarks>
        /// Fields whose type is not a named type, or is not unmanaged, produce an <c>InvalidShaderField</c> diagnostic
        /// and are left out of the result.
        /// </remarks>
        private static ImmutableArray<(string Name, string HlslType)> GetInstanceFields(
            ImmutableArray<Diagnostic>.Builder diagnostics,
            INamedTypeSymbol structDeclarationSymbol,
            ICollection<INamedTypeSymbol> types)
        {
            ImmutableArray<(string, string)>.Builder capturedFields = ImmutableArray.CreateBuilder<(string, string)>();

            IEnumerable<IFieldSymbol> instanceFields = structDeclarationSymbol
                .GetMembers()
                .OfType<IFieldSymbol>()
                .Where(field => !field.IsStatic);

            foreach (IFieldSymbol fieldSymbol in instanceFields)
            {
                // Captured fields must be named type symbols
                if (fieldSymbol.Type is not INamedTypeSymbol typeSymbol)
                {
                    diagnostics.Add(InvalidShaderField, fieldSymbol, structDeclarationSymbol, fieldSymbol.Name, fieldSymbol.Type);

                    continue;
                }

                string metadataName = typeSymbol.GetFullMetadataName();
                string typeName = HlslKnownTypes.GetMappedName(typeSymbol);

                _ = HlslKnownKeywords.TryGetMappedName(fieldSymbol.Name, out string? mapping);

                // Managed field types cannot be copied to the GPU
                if (!typeSymbol.IsUnmanagedType)
                {
                    diagnostics.Add(InvalidShaderField, fieldSymbol, structDeclarationSymbol, fieldSymbol.Name, typeSymbol);

                    continue;
                }

                // Custom structs need their own HLSL declaration later on
                if (!HlslKnownTypes.IsKnownHlslType(metadataName))
                {
                    types.Add(typeSymbol);
                }

                capturedFields.Add((mapping ?? fieldSymbol.Name, typeName));
            }

            return capturedFields.ToImmutable();
        }
예제 #5
0
        /// <summary>
        /// Loads the entry method for the current shader being loaded
        /// </summary>
        /// <remarks>
        /// Decompiles <c>Action.Method</c>, rewrites it, moves its local functions out into standalone functions,
        /// loads the static members and static methods captured by the rewriter, and finally stores the name of the
        /// first parameter (the thread ids) and the processed method body.
        /// </remarks>
        private void LoadMethodSource()
        {
            // Decompile the shader method
            MethodDecompiler.Instance.GetSyntaxTree(Action.Method, MethodType.Closure, out MethodDeclarationSyntax root, out SemanticModel semanticModel);

            // Rewrite the shader method (eg. to fix the type declarations)
            ShaderSyntaxRewriter syntaxRewriter = new ShaderSyntaxRewriter(semanticModel, ShaderType);

            root = (MethodDeclarationSyntax)syntaxRewriter.Visit(root);

            // Extract the implicit local functions. DescendantNodes also returns local functions nested inside
            // other local functions: each extracted function has its own nested local functions stripped from its
            // text, otherwise those would be emitted twice (standalone and inside their parent). Ordering by span
            // end keeps sibling functions in declaration order and places nested functions before their encloser.
            LocalFunctionStatementSyntax[] locals = root.DescendantNodes().OfType<LocalFunctionStatementSyntax>().ToArray();

            // RemoveNodes only returns null when the root itself is removed, which cannot happen here
            root = root.RemoveNodes(locals, SyntaxRemoveOptions.KeepNoTrivia)!;

            foreach (LocalFunctionStatementSyntax local in locals.OrderBy(l => l.Span.End))
            {
                // With no nested local functions this returns the node unchanged
                LocalFunctionStatementSyntax[] nestedLocals = local.DescendantNodes().OfType<LocalFunctionStatementSyntax>().ToArray();
                LocalFunctionStatementSyntax flatLocal = local.RemoveNodes(nestedLocals, SyntaxRemoveOptions.KeepNoTrivia)!;

                string alignedLocal = flatLocal.ToFullString().RemoveLeftPadding().Trim(' ', '\r', '\n');
                alignedLocal = Regex.Replace(alignedLocal, @"(?<=\W)(\d+)[fFdD]", m => m.Groups[1].Value);
                alignedLocal = HlslKnownKeywords.GetMappedText(alignedLocal);

                _LocalFunctionsList.Add(new LocalFunctionInfo(alignedLocal));
            }

            // Register the captured static members
            foreach (var member in syntaxRewriter.StaticMembers)
            {
                LoadFieldInfo(member.Value, member.Key);
            }

            // Register the captured static methods
            foreach (var method in syntaxRewriter.StaticMethods)
            {
                LoadStaticMethodSource(method.Key, method.Value);
            }

            // Get the thread ids identifier name and shader method body
            ThreadsIdsVariableName = root.ParameterList.Parameters.First().Identifier.Text;
            MethodBody             = root.Body.ToFullString();

            // Additional preprocessing: drop numeric literal suffixes (eg. 1.5f -> 1.5), trim, remap keywords
            MethodBody = Regex.Replace(MethodBody, @"(?<=\W)(\d+)[fFdD]", m => m.Groups[1].Value);
            MethodBody = MethodBody.TrimEnd('\n', '\r', ' ');
            MethodBody = HlslKnownKeywords.GetMappedText(MethodBody);
        }
예제 #6
0
        /// <summary>
        /// Loads a specified <see cref="ReadableMember"/> and adds it to the shader model
        /// </summary>
        /// <param name="memberInfo">The target <see cref="ReadableMember"/> to load</param>
        /// <param name="name">The optional explicit name to use for the field</param>
        /// <param name="parents">The list of parent fields to reach the current <see cref="ReadableMember"/> from a given <see cref="Action{T}"/></param>
        /// <remarks>
        /// Constant, readonly and read write buffers each get a descriptor range and a buffer entry; scalar and vector
        /// values become captured fields; compiler generated closure scopes (class fields named <c>CS$&lt;&gt;...</c>)
        /// are recursed into; delegates to static (or stateless) methods whose signature only uses scalar/vector types
        /// are loaded as additional static methods. Members of any other type are ignored.
        /// </remarks>
        private void LoadFieldInfo(ReadableMember memberInfo, string? name = null, IReadOnlyList<ReadableMember>? parents = null)
        {
            Type   fieldType = memberInfo.MemberType;
            string fieldName = HlslKnownKeywords.GetMappedName(name ?? memberInfo.Name);

            // Constant buffer
            if (HlslKnownTypes.IsConstantBufferType(fieldType))
            {
                DescriptorRanges.Add(new DescriptorRange1(DescriptorRangeType.ConstantBufferView, 1, _ConstantBuffersCount));

                // Track the buffer field
                memberInfo.Parents = parents;
                _CapturedMembers.Add(memberInfo);

                string typeName = HlslKnownTypes.GetMappedName(fieldType.GenericTypeArguments[0]);
                _BuffersList.Add(new ConstantBufferFieldInfo(fieldType, typeName, fieldName, _ConstantBuffersCount++));
            }
            else if (HlslKnownTypes.IsReadOnlyBufferType(fieldType))
            {
                // Root parameter for a readonly buffer
                DescriptorRanges.Add(new DescriptorRange1(DescriptorRangeType.ShaderResourceView, 1, _ReadOnlyBuffersCount));

                // Track the buffer field
                memberInfo.Parents = parents;
                _CapturedMembers.Add(memberInfo);

                string typeName = HlslKnownTypes.GetMappedName(fieldType);
                _BuffersList.Add(new ReadOnlyBufferFieldInfo(fieldType, typeName, fieldName, _ReadOnlyBuffersCount++));
            }
            else if (HlslKnownTypes.IsReadWriteBufferType(fieldType))
            {
                // Root parameter for a read write buffer
                DescriptorRanges.Add(new DescriptorRange1(DescriptorRangeType.UnorderedAccessView, 1, _ReadWriteBuffersCount));

                // Track the buffer field
                memberInfo.Parents = parents;
                _CapturedMembers.Add(memberInfo);

                string typeName = HlslKnownTypes.GetMappedName(fieldType);
                _BuffersList.Add(new ReadWriteBufferFieldInfo(fieldType, typeName, fieldName, _ReadWriteBuffersCount++));
            }
            else if (HlslKnownTypes.IsKnownScalarType(fieldType) || HlslKnownTypes.IsKnownVectorType(fieldType))
            {
                // Register the captured field
                memberInfo.Parents = parents;
                _CapturedMembers.Add(memberInfo);

                string typeName = HlslKnownTypes.GetMappedName(fieldType);
                _FieldsList.Add(new CapturedFieldInfo(fieldType, typeName, fieldName));
            }
            else if (fieldType.IsClass && fieldName.StartsWith("CS$<>", StringComparison.Ordinal))
            {
                // Captured scope, update the parents list
                List<ReadableMember> updatedParents = parents?.ToList() ?? new List<ReadableMember>();
                updatedParents.Add(memberInfo);

                // Recurse on the new compiler generated class
                foreach (FieldInfo fieldInfo in fieldType.GetFields())
                {
                    LoadFieldInfo(fieldInfo, null, updatedParents);
                }
            }
            else if (fieldType.IsDelegate() &&
                     TryGetDeclaringInstance(out object? declaringInstance) &&
                     memberInfo.GetValue(declaringInstance) is Delegate func &&
                     (func.Method.IsStatic || func.Method.DeclaringType.IsStatelessDelegateContainer()) &&
                     (HlslKnownTypes.IsKnownScalarType(func.Method.ReturnType) || HlslKnownTypes.IsKnownVectorType(func.Method.ReturnType)) &&
                     fieldType.GenericTypeArguments.All(type => HlslKnownTypes.IsKnownScalarType(type) ||
                                                                HlslKnownTypes.IsKnownVectorType(type)))
            {
                // Captured static delegates with a return type
                LoadStaticMethodSource(fieldName, func.Method);
            }

            // Resolves the object that declares memberInfo: the Action target itself for top level members, or the
            // nested captured scope reached by reading each parent field in turn (reading a nested member directly
            // from the Action target would fail, as it is declared on a different type).
            // Returns false if one of the intermediate captured scopes is null.
            bool TryGetDeclaringInstance(out object? instance)
            {
                instance = Action.Target;

                if (parents is null)
                {
                    return true;
                }

                foreach (ReadableMember parent in parents)
                {
                    instance = parent.GetValue(instance);

                    if (instance is null)
                    {
                        return false;
                    }
                }

                return true;
            }
        }