/// <summary>
/// Creates a visitor that emits into <paramref name="file"/> and registers one
/// handler per non-public instance method that carries a <c>NodeTypeAttribute</c>.
/// </summary>
/// <param name="file">The SPIR-V file this visitor stores for later use.</param>
public ShanqExpressionVisitor(SpirvFile file)
{
    const BindingFlags PrivateInstance = BindingFlags.NonPublic | BindingFlags.Instance;

    foreach (MethodInfo candidate in this.GetType().GetMethods(PrivateInstance))
    {
        var nodeType = candidate.GetCustomAttribute<NodeTypeAttribute>();

        // Methods without the attribute are not expression handlers.
        if (nodeType == null)
        {
            continue;
        }

        // Each handler is invoked reflectively with the expression as its only argument;
        // the boxed return value is unboxed to a ResultId.
        MethodInfo handler = candidate;
        this.expressionVisitors.Add(
            nodeType.NodeType,
            node => (ResultId)handler.Invoke(this, new object[] { node }));
    }

    this.file = file;
}
/// <summary>
/// Translates <paramref name="queryModel"/> into a SPIR-V module with a single "main"
/// entry point and writes the resulting statements to a <c>BinarySink</c> over the
/// executor's output stream.
/// </summary>
/// <typeparam name="T">
/// Result type of the query. Its public fields marked with <c>LocationAttribute</c>
/// become individual output variables; fields marked with <c>BuiltInAttribute</c> are
/// gathered into one output block structure.
/// </typeparam>
/// <param name="queryModel">
/// The query to compile. Every from-clause expression must be a constant holding an
/// <c>IShanqQueryable</c>.
/// </param>
/// <returns>Always an empty sequence; the emitted module is the side effect.</returns>
/// <exception cref="NotImplementedException">
/// The select clause's selector is neither a constant nor a member-init expression.
/// </exception>
public IEnumerable<T> ExecuteCollection<T>(QueryModel queryModel)
{
    var file = new SpirvFile();

    // Module preamble.
    file.AddHeaderStatement(Op.OpCapability, Capability.Shader);
    file.AddHeaderStatement(file.GetNextResultId(), Op.OpExtInstImport, "GLSL.std.450");
    file.AddHeaderStatement(Op.OpMemoryModel, AddressingModel.Logical, MemoryModel.GLSL450);

    var expressionVisitor = new ShanqExpressionVisitor(file);

    // Entry point function header: void main().
    ResultId voidId = expressionVisitor.Visit(Expression.Constant(typeof(void)));
    ResultId actionId = expressionVisitor.Visit(Expression.Constant(typeof(Action)));
    ResultId entryPointerFunctionId = file.GetNextResultId();
    ResultId entryPointerLabelId = file.GetNextResultId();

    file.AddFunctionStatement(entryPointerFunctionId, Op.OpFunction, voidId, FunctionControl.None, actionId);
    file.AddFunctionStatement(entryPointerLabelId, Op.OpLabel);

    // Field -> variable id for location-decorated inputs and outputs.
    var fieldMapping = new Dictionary<FieldInfo, ResultId>();

    // Field -> (block pointer type id, block variable id, member index) for built-in outputs.
    var builtinList = new Dictionary<FieldInfo, Tuple<ResultId, ResultId, int>>();

    bool hasBuiltInOutput = false;

    var resultType = typeof(T);
    var resultFields = resultType.GetFields();

    // One output variable per location-decorated result field.
    foreach (var field in resultFields)
    {
        if (field.GetCustomAttribute<LocationAttribute>() != null)
        {
            var pointerType = typeof(OutputPointer<>).MakeGenericType(field.FieldType);
            ResultId outputPointerId = expressionVisitor.Visit(Expression.Constant(pointerType));
            ResultId outputVariableId = file.GetNextResultId();

            file.AddGlobalStatement(outputVariableId, Op.OpVariable, outputPointerId, StorageClass.Output);

            fieldMapping.Add(field, outputVariableId);
        }

        hasBuiltInOutput |= field.GetCustomAttribute<BuiltInAttribute>() != null;
    }

    // Sort the query sources into stage inputs and uniform bindings.
    var fromClauses = new FromClauseBase[] { queryModel.MainFromClause }
        .Concat(queryModel.BodyClauses.OfType<AdditionalFromClause>());

    var inputTypes = new List<Type>();
    var bindingTypes = new List<Type>();

    foreach (var clause in fromClauses)
    {
        var queryable = (IShanqQueryable)((ConstantExpression)clause.FromExpression).Value;

        switch (queryable.Origin)
        {
            case QueryableOrigin.Input:
                inputTypes.Add(clause.ItemType);
                break;
            case QueryableOrigin.Binding:
                bindingTypes.Add(clause.ItemType);
                break;
        }
    }

    // One input variable per location-decorated field of every input type.
    foreach (var field in inputTypes.SelectMany(type => type.GetFields()))
    {
        if (field.GetCustomAttribute<LocationAttribute>() != null)
        {
            var pointerType = typeof(InputPointer<>).MakeGenericType(field.FieldType);
            ResultId inputPointerId = expressionVisitor.Visit(Expression.Constant(pointerType));
            ResultId inputVariableId = file.GetNextResultId();

            file.AddGlobalStatement(inputVariableId, Op.OpVariable, inputPointerId, StorageClass.Input);

            fieldMapping.Add(field, inputVariableId);
            expressionVisitor.AddInputMapping(field, inputVariableId);
        }
    }

    // One uniform block per binding type.
    foreach (var type in bindingTypes)
    {
        ResultId structureTypeId = expressionVisitor.Visit(Expression.Constant(type));
        var pointerType = typeof(InputPointer<>).MakeGenericType(type);
        ResultId uniformPointerId = expressionVisitor.Visit(Expression.Constant(pointerType));
        ResultId uniformVariableId = file.GetNextResultId();

        file.AddGlobalStatement(uniformVariableId, Op.OpVariable, uniformPointerId, StorageClass.Uniform);
        file.AddAnnotationStatement(Op.OpDecorate, structureTypeId, Decoration.Block);

        // NOTE(review): every binding type is given descriptor set 0 / binding 0, so
        // several binding sources would share one slot - confirm this is intended.
        file.AddAnnotationStatement(Op.OpDecorate, uniformVariableId, Decoration.DescriptorSet, 0);
        file.AddAnnotationStatement(Op.OpDecorate, uniformVariableId, Decoration.Binding, 0);

        int fieldIndex = 0;

        foreach (var field in type.GetFields())
        {
            expressionVisitor.AddBinding(field, Tuple.Create(uniformVariableId, fieldIndex));

            // NOTE(review): member offsets are only emitted here for matrix members;
            // verify whether non-matrix members get theirs elsewhere.
            if (ShanqExpressionVisitor.IsMatrixType(field.FieldType))
            {
                //HACK Should adapt to different matrix formats
                file.AddAnnotationStatement(Op.OpMemberDecorate, structureTypeId, fieldIndex, Decoration.ColMajor);
                file.AddAnnotationStatement(Op.OpMemberDecorate, structureTypeId, fieldIndex, Decoration.Offset, Marshal.OffsetOf(type, field.Name).ToInt32());
                file.AddAnnotationStatement(Op.OpMemberDecorate, structureTypeId, fieldIndex, Decoration.MatrixStride, 16);
            }

            fieldIndex++;
        }
    }

    var entryPointParameters = fieldMapping.Select(x => x.Value).Distinct().ToList();

    if (hasBuiltInOutput)
    {
        // Materialized once: the list is counted, projected and iterated below, and each
        // enumeration of the deferred query would repeat the reflection lookups.
        var builtInFields = resultFields
            .Select(x => new { Field = x, BuiltIn = x.GetCustomAttribute<BuiltInAttribute>()?.BuiltIn })
            .Where(x => x.BuiltIn != null)
            .ToList();

        // All built-in outputs live in a single block structure, modelled as a tuple type.
        var structureType = GetTupleType(builtInFields.Count)
            .MakeGenericType(builtInFields.Select(x => x.Field.FieldType).ToArray());
        ResultId structureTypeId = expressionVisitor.Visit(Expression.Constant(structureType));
        var structurePointerType = typeof(OutputPointer<>).MakeGenericType(structureType);
        ResultId structurePointerId = expressionVisitor.Visit(Expression.Constant(structurePointerType));
        ResultId outputVariableId = file.GetNextResultId();

        file.AddGlobalStatement(outputVariableId, Op.OpVariable, structurePointerId, StorageClass.Output);
        file.AddAnnotationStatement(Op.OpDecorate, structureTypeId, Decoration.Block);

        for (int memberIndex = 0; memberIndex < builtInFields.Count; memberIndex++)
        {
            var builtInField = builtInFields[memberIndex];

            file.AddAnnotationStatement(Op.OpMemberDecorate, structureTypeId, memberIndex, Decoration.BuiltIn, builtInField.BuiltIn.Value);

            builtinList.Add(builtInField.Field, Tuple.Create(structurePointerId, outputVariableId, memberIndex));
        }

        entryPointParameters.Add(outputVariableId);
    }

    file.AddHeaderStatement(Op.OpEntryPoint, new object[] { this.model, entryPointerFunctionId, "main" }.Concat(entryPointParameters.Cast<object>()).ToArray());

    if (this.model == ExecutionModel.Fragment)
    {
        file.AddHeaderStatement(Op.OpExecutionMode, entryPointerFunctionId, ExecutionMode.OriginUpperLeft);
    }

    // Location decorations for every mapped input/output variable.
    foreach (var mapping in fieldMapping)
    {
        var attribute = mapping.Key.GetCustomAttribute<LocationAttribute>();

        if (attribute != null)
        {
            file.AddAnnotationStatement(Op.OpDecorate, mapping.Value, Decoration.Location, attribute.LocationIndex);
        }
    }

    // Function body: evaluate each result field and store it into its output.
    var selector = queryModel.SelectClause.Selector;

    switch (selector.NodeType)
    {
        case ExpressionType.Constant:
            object constantValue = ((ConstantExpression)selector).Value;

            foreach (var field in resultFields)
            {
                var fieldValue = field.GetValue(constantValue);
                ResultId valueId = expressionVisitor.Visit(Expression.Constant(fieldValue, field.FieldType));

                // Routed through the shared helper so built-in and undecorated fields are
                // handled like in the member-init case instead of failing the lookup.
                EmitFieldStore(file, expressionVisitor, fieldMapping, builtinList, field, valueId);
            }

            break;
        case ExpressionType.MemberInit:
            var initExpression = (MemberInitExpression)selector;

            foreach (var binding in initExpression.Bindings)
            {
                var fieldValue = ((MemberAssignment)binding).Expression;
                ResultId valueId = expressionVisitor.Visit(fieldValue);
                var field = (FieldInfo)binding.Member;

                EmitFieldStore(file, expressionVisitor, fieldMapping, builtinList, field, valueId);
            }

            break;
        default:
            throw new NotImplementedException();
    }

    file.AddFunctionStatement(Op.OpReturn);
    file.AddFunctionStatement(Op.OpFunctionEnd);

    // The id bound handed to the sink is one past the largest result id in use.
    int bound = file.Entries.Select(x => x.ResultId)
        .Where(x => x.HasValue)
        .Max(x => x.Value.Id) + 1;

    var sink = new BinarySink(this.outputStream, bound);

    foreach (var entry in file.Entries)
    {
        sink.AddStatement(entry.ResultId, entry.Statement);
    }

    return Enumerable.Empty<T>();
}

/// <summary>
/// Emits the store of an already-evaluated value into the output backing <paramref name="field"/>.
/// </summary>
/// <param name="file">The file receiving the function statements.</param>
/// <param name="expressionVisitor">Visitor used to obtain ids for the member index constant and pointer type.</param>
/// <param name="fieldMapping">Location-decorated fields and their variable ids.</param>
/// <param name="builtinList">Built-in fields with their block pointer type id, block variable id and member index.</param>
/// <param name="field">The result field being written.</param>
/// <param name="valueId">Id of the value to store.</param>
/// <remarks>Fields found in neither map are skipped without emitting anything.</remarks>
private static void EmitFieldStore(
    SpirvFile file,
    ShanqExpressionVisitor expressionVisitor,
    Dictionary<FieldInfo, ResultId> fieldMapping,
    Dictionary<FieldInfo, Tuple<ResultId, ResultId, int>> builtinList,
    FieldInfo field,
    ResultId valueId)
{
    ResultId variableId;

    if (fieldMapping.TryGetValue(field, out variableId))
    {
        file.AddFunctionStatement(Op.OpStore, variableId, valueId);
        return;
    }

    Tuple<ResultId, ResultId, int> builtIn;

    if (!builtinList.TryGetValue(field, out builtIn))
    {
        return;
    }

    // Built-in outputs are members of the block variable, so address the member first.
    ResultId constantIndex = expressionVisitor.Visit(Expression.Constant(builtIn.Item3));
    ResultId fieldId = file.GetNextResultId();
    var fieldPointerType = typeof(OutputPointer<>).MakeGenericType(field.FieldType);
    ResultId fieldPointerTypeId = expressionVisitor.Visit(Expression.Constant(fieldPointerType));

    file.AddFunctionStatement(fieldId, Op.OpAccessChain, fieldPointerTypeId, builtIn.Item2, constantIndex);
    file.AddFunctionStatement(Op.OpStore, fieldId, valueId);
}