/// <summary>
/// Rewrites the inline <see langword="out"/> declarations of a method into explicit, zero-initialized
/// declarations placed at the start of the method body, then strips the <see langword="out"/> keywords.
/// </summary>
/// <param name="source">The source code to process.</param>
/// <param name="entryPoint">The text looked up in the leading trivia of each method declaration to find the method to process.</param>
/// <returns>The updated source code.</returns>
/// <exception cref="InvalidOperationException">Thrown if no method declaration has leading trivia containing <paramref name="entryPoint"/>.</exception>
private string RefactorInlineOutDeclarations(string source, string entryPoint)
        {
            // Replace the discards with randomly generated lowercase identifiers, so that
            // each of them can be given an explicit declaration like any other out variable
            source = Regex.Replace(source, @"(?<!\w)out ([\w.]+) _(?!_)", m => $"out {m.Groups[1].Value} {new string(Guid.NewGuid().ToByteArray().Select(b => (char)('a' + b % 26)).ToArray())}");

            // Load the syntax tree and the entry node
            SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(source);
            MethodDeclarationSyntax rootNode = syntaxTree.GetRoot().DescendantNodes().OfType<MethodDeclarationSyntax>().First(node => node.GetLeadingTrivia().ToFullString().Contains(entryPoint));

            // Build the explicit declaration for each inline out declaration, in source order
            string[] declarations = (
                from argument in rootNode.DescendantNodes().OfType<ArgumentSyntax>()
                where argument.RefKindKeyword.IsKind(SyntaxKind.OutKeyword) &&
                      argument.Expression.IsKind(SyntaxKind.DeclarationExpression)
                let match = Regex.Match(argument.Expression.ToFullString(), @"([\w.]+) ([\w_]+)")
                let mappedType = HlslKnownTypes.GetMappedName(match.Groups[1].Value)
                select $"{mappedType} {match.Groups[2].Value} = ({mappedType})0;").ToArray();

            // Insert the explicit declarations at the start of the method. The body is only inspected
            // when there is something to insert, so a method with no statements is left alone rather
            // than failing on the lookup of its first child node.
            if (declarations.Length > 0)
            {
                int start = rootNode.Body.ChildNodes().First().SpanStart;

                // Every declaration is followed by a new line and the indentation for the next statement.
                // A single insertion of the concatenated text is equivalent to inserting the declarations
                // one by one in reverse order at the same position.
                string prologue = string.Concat(declarations.Select(declaration => $"{declaration}{Environment.NewLine}        "));

                source = source.Insert(start, prologue);
            }

            // Remove the out keyword from the source
            source = Regex.Replace(source, @"(?<!\w)out [\w.]+ (?=[\w_]+)", string.Empty); // Inline out declarations
            source = Regex.Replace(source, @"(?<!\w)out ", string.Empty);                  // Leftovers out keywords

            return source;
        }
Beispiel #2
0
        /// <summary>
        /// Collects the static, non constant fields declared in a shader type, with their HLSL declarations.
        /// </summary>
        /// <param name="diagnostics">The collection of produced <see cref="Diagnostic"/> instances.</param>
        /// <param name="semanticModel">The <see cref="SemanticModelProvider"/> instance for the type to process.</param>
        /// <param name="structDeclaration">The <see cref="StructDeclarationSyntax"/> instance for the current type.</param>
        /// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
        /// <param name="discoveredTypes">The collection of currently discovered types.</param>
        /// <param name="constantDefinitions">The collection of discovered constant definitions.</param>
        /// <returns>
        /// The name (or its mapped replacement), the HLSL type declaration and the optional rewritten
        /// declarator text for each static field in <paramref name="structDeclarationSymbol"/>.
        /// </returns>
        private static ImmutableArray<(string Name, string TypeDeclaration, string? Assignment)> GetStaticFields(
            ImmutableArray<Diagnostic>.Builder diagnostics,
            SemanticModelProvider semanticModel,
            StructDeclarationSyntax structDeclaration,
            INamedTypeSymbol structDeclarationSymbol,
            ICollection<INamedTypeSymbol> discoveredTypes,
            IDictionary<IFieldSymbol, string> constantDefinitions)
        {
            var staticFields = ImmutableArray.CreateBuilder<(string, string, string?)>();

            foreach (var declaration in structDeclaration.Members.OfType<FieldDeclarationSyntax>())
            {
                foreach (var declarator in declaration.Declaration.Variables)
                {
                    var symbol = (IFieldSymbol)semanticModel.For(declarator).GetDeclaredSymbol(declarator)!;

                    // Instance fields and compile-time constants are not processed here
                    if (symbol.IsConst || !symbol.IsStatic)
                    {
                        continue;
                    }

                    // Static fields must be of a named type that is a known HLSL type
                    if (symbol.Type is INamedTypeSymbol namedType &&
                        HlslKnownTypes.IsKnownHlslType(namedType.GetFullMetadataName()))
                    {
                        // The mapping stays null when the field name has no replacement
                        _ = HlslKnownKeywords.TryGetMappedName(symbol.Name, out string? mappedName);

                        // Readonly fields are additionally marked as const
                        string typeDeclaration = symbol.IsReadOnly
                            ? $"static const {HlslKnownTypes.GetMappedName(namedType)}"
                            : $"static {HlslKnownTypes.GetMappedName(namedType)}";

                        var rewriter = new StaticFieldRewriter(
                            semanticModel,
                            discoveredTypes,
                            constantDefinitions,
                            diagnostics);

                        string? assignment = rewriter.Visit(declarator)?.NormalizeWhitespace(eol: "\n").ToFullString();

                        staticFields.Add((mappedName ?? symbol.Name, typeDeclaration, assignment));
                    }
                    else
                    {
                        diagnostics.Add(InvalidShaderStaticFieldType, declarator, structDeclarationSymbol, symbol.Name, symbol.Type);
                    }
                }
            }

            return staticFields.ToImmutable();
        }
        /// <summary>
        /// Replaces a type node within a syntax node with the node for its mapped HLSL type name.
        /// </summary>
        /// <typeparam name="TRoot">The type of the syntax node to update.</typeparam>
        /// <param name="node">The syntax node containing <paramref name="type"/>.</param>
        /// <param name="type">The type node to replace.</param>
        /// <returns>
        /// The input <paramref name="node"/> if the mapped name equals the original one,
        /// otherwise a node where <paramref name="type"/> was replaced (trivia is preserved).
        /// </returns>
        public static TRoot ReplaceType<TRoot>(this TRoot node, TypeSyntax type)
            where TRoot : SyntaxNode
        {
            string originalName = type.ToString();
            string mappedName = HlslKnownTypes.GetMappedName(originalName);

            // Only build a new type node when the mapping actually changes the name
            if (mappedName != originalName)
            {
                TypeSyntax mappedType = SyntaxFactory
                    .ParseTypeName(mappedName)
                    .WithLeadingTrivia(type.GetLeadingTrivia())
                    .WithTrailingTrivia(type.GetTrailingTrivia());

                return node.ReplaceNode(type, mappedType);
            }

            return node;
        }
Beispiel #4
0
        /// <summary>
        /// Collects the instance fields of a shader type, paired with their mapped HLSL type names.
        /// </summary>
        /// <param name="diagnostics">The collection of produced <see cref="Diagnostic"/> instances.</param>
        /// <param name="structDeclarationSymbol">The input <see cref="INamedTypeSymbol"/> instance to process.</param>
        /// <param name="types">The collection of currently discovered types.</param>
        /// <returns>
        /// The name (or its mapped replacement) and HLSL type name of each valid
        /// instance field in <paramref name="structDeclarationSymbol"/>.
        /// </returns>
        private static ImmutableArray<(string Name, string HlslType)> GetInstanceFields(
            ImmutableArray<Diagnostic>.Builder diagnostics,
            INamedTypeSymbol structDeclarationSymbol,
            ICollection<INamedTypeSymbol> types)
        {
            var captured = ImmutableArray.CreateBuilder<(string, string)>();

            foreach (var field in structDeclarationSymbol.GetMembers().OfType<IFieldSymbol>())
            {
                // Static fields are not captured
                if (field.IsStatic)
                {
                    continue;
                }

                // Fields whose type is not a named type symbol are reported and skipped
                if (field.Type is not INamedTypeSymbol fieldType)
                {
                    diagnostics.Add(InvalidShaderField, field, structDeclarationSymbol, field.Name, field.Type);

                    continue;
                }

                string metadataName = fieldType.GetFullMetadataName();
                string hlslType = HlslKnownTypes.GetMappedName(fieldType);

                // The mapping stays null when the field name has no replacement
                _ = HlslKnownKeywords.TryGetMappedName(field.Name, out string? mappedName);

                // Fields of a managed type are reported and skipped as well
                if (!fieldType.IsUnmanagedType)
                {
                    diagnostics.Add(InvalidShaderField, field, structDeclarationSymbol, field.Name, fieldType);

                    continue;
                }

                // Types that are not known HLSL types are tracked as discovered types
                if (!HlslKnownTypes.IsKnownHlslType(metadataName))
                {
                    types.Add(fieldType);
                }

                captured.Add((mappedName ?? field.Name, hlslType));
            }

            return captured.ToImmutable();
        }
        /// <inheritdoc/>
        public override SyntaxNode VisitObjectCreationExpression(ObjectCreationExpressionSyntax node)
        {
            var rewritten = (ObjectCreationExpressionSyntax)base.VisitObjectCreationExpression(node)!;

            // Semantic information is always queried on the original node, not on the rewritten one
            var model = SemanticModel.For(node);
            var createdType = model.GetTypeInfo(node).Type;

            // Creating instances of managed types is reported as an error
            if (createdType is { IsUnmanagedType: false } managedType)
            {
                Context.ReportDiagnostic(InvalidObjectCreationExpression, node, managedType);
            }

            rewritten = rewritten.ReplaceAndTrackType(rewritten.Type, node, model, DiscoveredTypes);

            // Parameterless creations use the default HLSL cast syntax, eg. (float4)0
            if (rewritten.ArgumentList!.Arguments.Count == 0)
            {
                return CastExpression(rewritten.Type, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)));
            }

            // Matrix constructors get an explicit cast on each argument to help the overload resolution
            if (createdType is not null &&
                HlslKnownTypes.IsMatrixType(createdType.GetFullMetadataName()))
            {
                var originalArguments = node.ArgumentList!.Arguments;

                for (int index = 0; index < originalArguments.Count; index++)
                {
                    var operation = (IArgumentOperation)model.GetOperation(originalArguments[index])!;
                    var parameterType = (INamedTypeSymbol)operation.Parameter!.Type;

                    // The expression is read from the node being rewritten, as it changes at every iteration
                    var expression = rewritten.ArgumentList!.Arguments[index].Expression;

                    rewritten = rewritten.ReplaceNode(
                        expression,
                        CastExpression(IdentifierName(HlslKnownTypes.GetMappedName(parameterType)), expression));
                }
            }

            // Constructors with arguments become invocations of the mapped type name
            return InvocationExpression(rewritten.Type, rewritten.ArgumentList!);
        }
        /// <inheritdoc/>
        public override SyntaxNode VisitImplicitObjectCreationExpression(ImplicitObjectCreationExpressionSyntax node)
        {
            var rewritten = (ImplicitObjectCreationExpressionSyntax)base.VisitImplicitObjectCreationExpression(node)!;

            // Semantic information is always queried on the original node, not on the rewritten one
            var model = SemanticModel.For(node);
            var createdType = model.GetTypeInfo(node).Type;

            // Creating instances of managed types is reported as an error
            if (createdType is { IsUnmanagedType: false } managedType)
            {
                Context.ReportDiagnostic(InvalidObjectCreationExpression, node, managedType);
            }

            // The syntax has no type name, so the explicit type is built starting from an empty identifier
            TypeSyntax explicitType = IdentifierName("").ReplaceAndTrackType(node, model, DiscoveredTypes);

            // Parameterless creations use the default HLSL cast syntax, like explicit object creations
            if (rewritten.ArgumentList!.Arguments.Count == 0)
            {
                return CastExpression(explicitType, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)));
            }

            // Matrix constructors get an explicit cast on each argument, like explicit object creations
            if (createdType is not null &&
                HlslKnownTypes.IsMatrixType(createdType.GetFullMetadataName()))
            {
                var originalArguments = node.ArgumentList.Arguments;

                for (int index = 0; index < originalArguments.Count; index++)
                {
                    var operation = (IArgumentOperation)model.GetOperation(originalArguments[index])!;
                    var parameterType = (INamedTypeSymbol)operation.Parameter!.Type;

                    // The expression is read from the node being rewritten, as it changes at every iteration
                    var expression = rewritten.ArgumentList.Arguments[index].Expression;

                    rewritten = rewritten.ReplaceNode(
                        expression,
                        CastExpression(IdentifierName(HlslKnownTypes.GetMappedName(parameterType)), expression));
                }
            }

            // Constructors with arguments become invocations of the explicit type name
            return InvocationExpression(explicitType, rewritten.ArgumentList);
        }
        /// <summary>
        /// Loads a specified <see cref="ReadableMember"/> and adds it to the shader model
        /// </summary>
        /// <remarks>
        /// Buffers, scalars and vectors are tracked as captured members, compiler generated closure
        /// classes are explored recursively, and eligible static delegates have their method source loaded.
        /// Members that match none of these cases are ignored.
        /// </remarks>
        /// <param name="memberInfo">The target <see cref="ReadableMember"/> to load</param>
        /// <param name="name">The optional explicit name to use for the field</param>
        /// <param name="parents">The list of parent fields to reach the current <see cref="ReadableMember"/> from a given <see cref="Action{T}"/></param>
        private void LoadFieldInfo(ReadableMember memberInfo, string? name = null, IReadOnlyList<ReadableMember>? parents = null)
        {
            Type fieldType = memberInfo.MemberType;
            string fieldName = HlslKnownKeywords.GetMappedName(name ?? memberInfo.Name);

            // Stores the path to the member and registers it among the captured members
            void TrackCapturedMember()
            {
                memberInfo.Parents = parents;
                _CapturedMembers.Add(memberInfo);
            }

            if (HlslKnownTypes.IsConstantBufferType(fieldType))
            {
                // Root parameter for a constant buffer
                DescriptorRanges.Add(new DescriptorRange1(DescriptorRangeType.ConstantBufferView, 1, _ConstantBuffersCount));

                TrackCapturedMember();

                // Constant buffers are declared with the mapped name of their element type
                string typeName = HlslKnownTypes.GetMappedName(fieldType.GenericTypeArguments[0]);
                _BuffersList.Add(new ConstantBufferFieldInfo(fieldType, typeName, fieldName, _ConstantBuffersCount++));
            }
            else if (HlslKnownTypes.IsReadOnlyBufferType(fieldType))
            {
                // Root parameter for a readonly buffer
                DescriptorRanges.Add(new DescriptorRange1(DescriptorRangeType.ShaderResourceView, 1, _ReadOnlyBuffersCount));

                TrackCapturedMember();

                string typeName = HlslKnownTypes.GetMappedName(fieldType);
                _BuffersList.Add(new ReadOnlyBufferFieldInfo(fieldType, typeName, fieldName, _ReadOnlyBuffersCount++));
            }
            else if (HlslKnownTypes.IsReadWriteBufferType(fieldType))
            {
                // Root parameter for a read write buffer
                DescriptorRanges.Add(new DescriptorRange1(DescriptorRangeType.UnorderedAccessView, 1, _ReadWriteBuffersCount));

                TrackCapturedMember();

                string typeName = HlslKnownTypes.GetMappedName(fieldType);
                _BuffersList.Add(new ReadWriteBufferFieldInfo(fieldType, typeName, fieldName, _ReadWriteBuffersCount++));
            }
            else if (HlslKnownTypes.IsKnownScalarType(fieldType) || HlslKnownTypes.IsKnownVectorType(fieldType))
            {
                // Register the captured field
                TrackCapturedMember();

                string typeName = HlslKnownTypes.GetMappedName(fieldType);
                _FieldsList.Add(new CapturedFieldInfo(fieldType, typeName, fieldName));
            }
            else if (fieldType.IsClass && fieldName.StartsWith("CS$<>", StringComparison.Ordinal))
            {
                // Captured scope, update the parents list. The name prefix is an identifier
                // and not linguistic text, so it's matched with an ordinal comparison.
                List<ReadableMember> updatedParents = parents?.ToList() ?? new List<ReadableMember>();
                updatedParents.Add(memberInfo);

                // Recurse on the fields of the new compiler generated class
                foreach (FieldInfo fieldInfo in fieldType.GetFields())
                {
                    LoadFieldInfo(fieldInfo, null, updatedParents);
                }
            }
            else if (fieldType.IsDelegate() &&
                     memberInfo.GetValue(Action.Target) is Delegate func &&
                     (func.Method.IsStatic || func.Method.DeclaringType.IsStatelessDelegateContainer()) &&
                     (HlslKnownTypes.IsKnownScalarType(func.Method.ReturnType) || HlslKnownTypes.IsKnownVectorType(func.Method.ReturnType)) &&
                     fieldType.GenericTypeArguments.All(type => HlslKnownTypes.IsKnownScalarType(type) ||
                                                                HlslKnownTypes.IsKnownVectorType(type)))
            {
                // Captured static delegates with a return type
                LoadStaticMethodSource(fieldName, func.Method);
            }
        }