public ShanqExpressionVisitor(SpirvFile file)
        {
            // Build the dispatch table: every non-public instance method of the runtime type that
            // carries [NodeType] becomes the handler for the expression node type it declares.
            // Handlers are invoked reflectively, so an exception thrown inside one surfaces wrapped
            // by MethodInfo.Invoke. A node type claimed by two methods makes Add throw.
            const BindingFlags candidateFlags = BindingFlags.NonPublic | BindingFlags.Instance;

            foreach (MethodInfo candidate in this.GetType().GetMethods(candidateFlags))
            {
                NodeTypeAttribute nodeTypeTag = candidate.GetCustomAttribute<NodeTypeAttribute>();

                if (nodeTypeTag == null)
                {
                    continue;
                }

                this.expressionVisitors.Add(nodeTypeTag.NodeType,
                                            expression => (ResultId)candidate.Invoke(this, new object[] { expression }));
            }

            this.file = file;
        }
Esempio n. 2
0
        /// <summary>
        /// Compiles <paramref name="queryModel"/> into a single SPIR-V shader module. The module's
        /// "main" entry point stores the query's selected value into the shader's output variables.
        /// Every generated statement is then fed into a <c>BinarySink</c> over <c>this.outputStream</c>.
        /// </summary>
        /// <typeparam name="T">
        /// The shader output type. Each field tagged <c>[Location]</c> becomes its own Output variable.
        /// Fields tagged <c>[BuiltIn]</c> become members of one Output block variable.
        /// </typeparam>
        /// <param name="queryModel">
        /// The parsed query. Each from-clause must wrap an <c>IShanqQueryable</c> constant, and that
        /// queryable's origin decides whether the item type is treated as shader input or as a
        /// uniform binding. The select clause must be a constant or a member-init expression.
        /// </param>
        /// <returns>Always an empty sequence; the compiled module is the only product.</returns>
        /// <exception cref="NotImplementedException">
        /// The select clause is neither a constant nor a member-init expression.
        /// </exception>
        public IEnumerable <T> ExecuteCollection <T>(QueryModel queryModel)
        {
            var file = new SpirvFile();

            file.AddHeaderStatement(Op.OpCapability, Capability.Shader);
            file.AddHeaderStatement(file.GetNextResultId(), Op.OpExtInstImport, "GLSL.std.450");
            file.AddHeaderStatement(Op.OpMemoryModel, AddressingModel.Logical, MemoryModel.GLSL450);

            var expressionVisitor = new ShanqExpressionVisitor(file);

            ResultId voidId   = expressionVisitor.Visit(Expression.Constant(typeof(void)));
            ResultId actionId = expressionVisitor.Visit(Expression.Constant(typeof(Action)));

            ResultId entryPointerFunctionId = file.GetNextResultId();
            ResultId entryPointerLabelId    = file.GetNextResultId();

            file.AddFunctionStatement(entryPointerFunctionId, Op.OpFunction, voidId, FunctionControl.None, actionId);
            file.AddFunctionStatement(entryPointerLabelId, Op.OpLabel);

            // [Location] fields (outputs of T and fields of input types) -> their OpVariable id.
            var fieldMapping = new Dictionary <FieldInfo, ResultId>();

            // [BuiltIn] fields of T -> (block pointer type id, block variable id, member index).
            var builtinList = new Dictionary <FieldInfo, Tuple <ResultId, ResultId, int> >();

            bool hasBuiltInOutput = false;

            var resultType = typeof(T);

            foreach (var field in resultType.GetFields())
            {
                if (field.GetCustomAttribute <LocationAttribute>() != null)
                {
                    var      pointerType      = typeof(OutputPointer <>).MakeGenericType(field.FieldType);
                    ResultId outputPointerId  = expressionVisitor.Visit(Expression.Constant(pointerType));
                    ResultId outputVariableId = file.GetNextResultId();

                    file.AddGlobalStatement(outputVariableId, Op.OpVariable, outputPointerId, StorageClass.Output);

                    fieldMapping.Add(field, outputVariableId);
                }

                hasBuiltInOutput |= field.GetCustomAttribute <BuiltInAttribute>() != null;
            }

            // Sort the item types of the main and additional from-clauses by where their data comes from.
            var fromClauses = new FromClauseBase[] { queryModel.MainFromClause }
            .Concat(queryModel.BodyClauses.OfType <AdditionalFromClause>());

            var inputTypes   = new List <Type>();
            var bindingTypes = new List <Type>();

            foreach (var clause in fromClauses)
            {
                var queryable = (IShanqQueryable)((ConstantExpression)clause.FromExpression).Value;

                switch (queryable.Origin)
                {
                case QueryableOrigin.Input:
                    inputTypes.Add(clause.ItemType);
                    break;

                case QueryableOrigin.Binding:
                    bindingTypes.Add(clause.ItemType);
                    break;
                }
            }

            foreach (var field in inputTypes.SelectMany(type => type.GetFields()))
            {
                if (field.GetCustomAttribute <LocationAttribute>() != null)
                {
                    var      pointerType     = typeof(InputPointer <>).MakeGenericType(field.FieldType);
                    ResultId inputPointerId  = expressionVisitor.Visit(Expression.Constant(pointerType));
                    ResultId inputVariableId = file.GetNextResultId();

                    file.AddGlobalStatement(inputVariableId, Op.OpVariable, inputPointerId, StorageClass.Input);

                    fieldMapping.Add(field, inputVariableId);

                    expressionVisitor.AddInputMapping(field, inputVariableId);
                }
            }

            foreach (var type in bindingTypes)
            {
                ResultId structureTypeId   = expressionVisitor.Visit(Expression.Constant(type));

                // NOTE(review): the pointer type is built from InputPointer<>, but the variable below is
                // declared with StorageClass.Uniform. SPIR-V requires an OpVariable's storage class to
                // match its pointer type's. Confirm how the visitor lowers InputPointer<>.
                var      pointerType       = typeof(InputPointer <>).MakeGenericType(type);
                ResultId uniformPointerId  = expressionVisitor.Visit(Expression.Constant(pointerType));
                ResultId uniformVariableId = file.GetNextResultId();

                file.AddGlobalStatement(uniformVariableId, Op.OpVariable, uniformPointerId, StorageClass.Uniform);
                file.AddAnnotationStatement(Op.OpDecorate, structureTypeId, Decoration.Block);

                // NOTE(review): every binding type gets descriptor set 0 / binding 0, so a query with
                // more than one binding clause aliases them all onto one descriptor. Confirm this is intended.
                file.AddAnnotationStatement(Op.OpDecorate, uniformVariableId, Decoration.DescriptorSet, 0);
                file.AddAnnotationStatement(Op.OpDecorate, uniformVariableId, Decoration.Binding, 0);

                int fieldIndex = 0;

                foreach (var field in type.GetFields())
                {
                    expressionVisitor.AddBinding(field, Tuple.Create(uniformVariableId, fieldIndex));

                    // Every member of a Block-decorated struct needs an explicit Offset, not only the
                    // matrix members. The offset comes from the marshalled CLR layout of the binding type.
                    file.AddAnnotationStatement(Op.OpMemberDecorate, structureTypeId, fieldIndex, Decoration.Offset, Marshal.OffsetOf(type, field.Name).ToInt32());

                    if (ShanqExpressionVisitor.IsMatrixType(field.FieldType))
                    {
                        //HACK Should adapt to different matrix formats
                        file.AddAnnotationStatement(Op.OpMemberDecorate, structureTypeId, fieldIndex, Decoration.ColMajor);
                        file.AddAnnotationStatement(Op.OpMemberDecorate, structureTypeId, fieldIndex, Decoration.MatrixStride, 16);
                    }

                    fieldIndex++;
                }
            }

            var entryPointParameters = fieldMapping.Select(x => x.Value).Distinct().ToList();

            if (hasBuiltInOutput)
            {
                // Materialized once, because the list is counted, projected and iterated below.
                var builtInFields = resultType.GetFields()
                                    .Select(x => new { Field = x, BuiltIn = x.GetCustomAttribute <BuiltInAttribute>()?.BuiltIn })
                                    .Where(x => x.BuiltIn != null)
                                    .ToList();

                var      structureType   = GetTupleType(builtInFields.Count).MakeGenericType(builtInFields.Select(x => x.Field.FieldType).ToArray());
                ResultId structureTypeId = expressionVisitor.Visit(Expression.Constant(structureType));

                var      structurePointerType = typeof(OutputPointer <>).MakeGenericType(structureType);
                ResultId structurePointerId   = expressionVisitor.Visit(Expression.Constant(structurePointerType));
                ResultId outputVariableId     = file.GetNextResultId();

                file.AddGlobalStatement(outputVariableId, Op.OpVariable, structurePointerId, StorageClass.Output);

                file.AddAnnotationStatement(Op.OpDecorate, structureTypeId, Decoration.Block);

                for (int memberIndex = 0; memberIndex < builtInFields.Count; memberIndex++)
                {
                    var builtInField = builtInFields[memberIndex];

                    file.AddAnnotationStatement(Op.OpMemberDecorate, structureTypeId, memberIndex, Decoration.BuiltIn, builtInField.BuiltIn.Value);
                    builtinList.Add(builtInField.Field, Tuple.Create(structurePointerId, outputVariableId, memberIndex));
                }

                entryPointParameters.Add(outputVariableId);
            }

            // The entry point lists the entry function followed by every interface variable it uses.
            file.AddHeaderStatement(Op.OpEntryPoint, new object[] { this.model, entryPointerFunctionId, "main" }.Concat(entryPointParameters.Cast <object>()).ToArray());
            if (this.model == ExecutionModel.Fragment)
            {
                file.AddHeaderStatement(Op.OpExecutionMode, entryPointerFunctionId, ExecutionMode.OriginUpperLeft);
            }

            foreach (var mapping in fieldMapping)
            {
                var attribute = mapping.Key.GetCustomAttribute <LocationAttribute>();

                if (attribute != null)
                {
                    file.AddAnnotationStatement(Op.OpDecorate, mapping.Value, Decoration.Location, attribute.LocationIndex);
                }
            }

            var selector = queryModel.SelectClause.Selector;

            switch (selector.NodeType)
            {
            case ExpressionType.Constant:
                object selectedValue = ((ConstantExpression)selector).Value;

                // Store every field of the constant through the same path as member-init selectors:
                // [BuiltIn] fields go through the built-in block and unmapped fields are skipped,
                // instead of failing with KeyNotFoundException.
                foreach (var field in resultType.GetFields())
                {
                    var fieldValue = field.GetValue(selectedValue);

                    ResultId valueId = expressionVisitor.Visit(Expression.Constant(fieldValue, field.FieldType));

                    EmitOutputFieldStore(file, expressionVisitor, fieldMapping, builtinList, field, valueId);
                }
                break;

            case ExpressionType.MemberInit:
                var initExpression = (MemberInitExpression)selector;

                foreach (var binding in initExpression.Bindings)
                {
                    ResultId valueId = expressionVisitor.Visit(((MemberAssignment)binding).Expression);

                    EmitOutputFieldStore(file, expressionVisitor, fieldMapping, builtinList, (FieldInfo)binding.Member, valueId);
                }
                break;

            default:
                throw new NotImplementedException();
            }

            file.AddFunctionStatement(Op.OpReturn);
            file.AddFunctionStatement(Op.OpFunctionEnd);

            // The module header bound must exceed every result id in use.
            int bound = file.Entries.Select(x => x.ResultId)
                        .Where(x => x.HasValue)
                        .Max(x => x.Value.Id) + 1;

            var sink = new BinarySink(this.outputStream, bound);

            foreach (var entry in file.Entries)
            {
                sink.AddStatement(entry.ResultId, entry.Statement);
            }

            return(Enumerable.Empty <T>());
        }

        /// <summary>
        /// Emits an OpStore of <paramref name="valueId"/> into the output slot bound to <paramref name="field"/>.
        /// A <c>[Location]</c> variable is stored to directly. A member of the built-in output block is
        /// reached through an OpAccessChain first. Fields with neither mapping produce no store.
        /// </summary>
        private static void EmitOutputFieldStore(SpirvFile file,
                                                 ShanqExpressionVisitor expressionVisitor,
                                                 Dictionary <FieldInfo, ResultId> fieldMapping,
                                                 Dictionary <FieldInfo, Tuple <ResultId, ResultId, int> > builtinList,
                                                 FieldInfo field,
                                                 ResultId valueId)
        {
            ResultId variableId;
            Tuple <ResultId, ResultId, int> builtin;

            if (fieldMapping.TryGetValue(field, out variableId))
            {
                file.AddFunctionStatement(Op.OpStore, variableId, valueId);
            }
            else if (builtinList.TryGetValue(field, out builtin))
            {
                ResultId constantIndex = expressionVisitor.Visit(Expression.Constant(builtin.Item3));
                ResultId fieldId       = file.GetNextResultId();

                var      fieldPointerType   = typeof(OutputPointer <>).MakeGenericType(field.FieldType);
                ResultId fieldPointerTypeId = expressionVisitor.Visit(Expression.Constant(fieldPointerType));

                file.AddFunctionStatement(fieldId, Op.OpAccessChain, fieldPointerTypeId, builtin.Item2, constantIndex);
                file.AddFunctionStatement(Op.OpStore, fieldId, valueId);
            }
        }