/// <summary>
/// Tries to recognise <paramref name="node"/> as a three-component dot product that is
/// built by adding one more product term to a two-component dot product.
/// </summary>
/// <remarks>
/// Accepted shapes (the addition may be in either order):
///   #1  dot(ab, xy) + c*z
///   #2  c*z + dot(ab, xy)
/// The first factor of the extra multiplication is always placed in Value1 and the second in
/// Value2; the reversed factor order (z*c) is not tried here.
/// </remarks>
/// <param name="node">Candidate expression root.</param>
/// <param name="allowMatrixColumn">
/// Forwarded to <c>TryGetDot2ProductGroup</c> and <c>CanGroupComponents</c>. When true and the
/// newly joined pair (b, c) shares a matrix column or row, the other operand (x, y, z) is accepted
/// without requiring its components to be groupable.
/// </param>
/// <returns>
/// A context with Value1 = {a, b, c} and Value2 = {x, y, z}, or null when the pattern does not match.
/// </returns>
private DotProductContext TryGetDot3ProductGroup(HlslTreeNode node, bool allowMatrixColumn)
        {
            if (!(node is AddOperation addition))
            {
                return(null);
            }

            // Look for the inner dot(ab, xy) on the left first, then on the right.
            // Whichever addend is not the inner dot product must be the c*z term.
            DotProductContext innerAddition = TryGetDot2ProductGroup(addition.Addend1, allowMatrixColumn);
            HlslTreeNode otherAddend = addition.Addend2;
            if (innerAddition == null)
            {
                innerAddition = TryGetDot2ProductGroup(addition.Addend2, allowMatrixColumn);
                otherAddend = addition.Addend1;
            }

            if (innerAddition == null || !(otherAddend is MultiplyOperation cz))
            {
                return(null);
            }

            HlslTreeNode a = innerAddition.Value1[0];
            HlslTreeNode b = innerAddition.Value1[1];
            HlslTreeNode c = cz.Factor1;
            HlslTreeNode x = innerAddition.Value2[0];
            HlslTreeNode y = innerAddition.Value2[1];
            HlslTreeNode z = cz.Factor2;

            if (CanGroupComponents(b, c, allowMatrixColumn))
            {
                // Decide the matrix exemption on the same pair (b, c) that was just validated.
                // Before, this checked (a, b), which ignored whether the new component c lies
                // in that matrix column/row.
                if (allowMatrixColumn && SharesMatrixColumnOrRow(b, c))
                {
                    // If one of the arguments is a matrix, allow the other argument to be arbitrary.
                    return(new DotProductContext(new[] { a, b, c }, new[] { x, y, z }));
                }
                if (CanGroupComponents(y, z, allowMatrixColumn))
                {
                    return(new DotProductContext(new[] { a, b, c }, new[] { x, y, z }));
                }
            }

            return(null);
        }
        /// <summary>
        /// Picks the operand of a dot-product context whose first element is a register that
        /// resolves to a constant declaration with more than one row. Value1 is checked before Value2.
        /// </summary>
        /// <param name="firstDotProductNode">The dot-product context to inspect.</param>
        /// <returns>The qualifying operand array, or null when neither operand qualifies.</returns>
        private HlslTreeNode[] TryGetMatrixRow(DotProductContext firstDotProductNode)
        {
            // `??` only evaluates the Value2 probe when Value1 did not qualify.
            return OperandIfMultiRow(firstDotProductNode.Value1)
                ?? OperandIfMultiRow(firstDotProductNode.Value2);

            // Returns the operand itself when its leading node is a register backed by a
            // constant with Rows > 1; otherwise null.
            HlslTreeNode[] OperandIfMultiRow(HlslTreeNode[] operand)
            {
                if (!(operand[0] is RegisterInputNode register))
                {
                    return null;
                }

                ConstantDeclaration declaration = _registers.FindConstant(register);
                return (declaration != null && declaration.Rows > 1) ? operand : null;
            }
        }
// Example #3
        /// <summary>
        /// Compiles a list of component nodes into a single HLSL expression string.
        /// </summary>
        /// <remarks>
        /// For a single component it first tries to emit <c>length(...)</c>, then <c>dot(..., ...)</c>.
        /// For several components it tries, in order: a vector constructor (when the components
        /// split into more than one group), a matrix multiplication, then <c>normalize(...)</c>.
        /// If none of these match, it dispatches on the type of the first component
        /// (constant, operation, or component-indexed node).
        /// </remarks>
        /// <param name="components">The component nodes to compile; must be non-null and non-empty.</param>
        /// <param name="promoteToVectorSize">
        /// Forwarded to the constant/operation/component compilers. Presumably a target vector width,
        /// with <c>PromoteToAnyVectorSize</c> meaning "no constraint" — confirm against those helpers.
        /// </param>
        /// <returns>The HLSL source text for the expression.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="components"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="components"/> is empty.</exception>
        /// <exception cref="NotImplementedException">The first component's node type is not handled.</exception>
        public string Compile(List<HlslTreeNode> components, int promoteToVectorSize = PromoteToAnyVectorSize)
        {
            // Fail with a descriptive exception instead of a NullReferenceException on `.Count`.
            if (components == null)
            {
                throw new ArgumentNullException(nameof(components));
            }

            if (components.Count == 0)
            {
                throw new ArgumentOutOfRangeException(nameof(components));
            }

            if (components.Count == 1)
            {
                HlslTreeNode singleComponent = components[0];

                HlslTreeNode[] vector = _nodeGrouper.LengthGrouper.TryGetLengthContext(singleComponent);
                if (vector != null)
                {
                    string value = Compile(vector);
                    return($"length({value})");
                }

                DotProductContext dotProduct = _nodeGrouper.DotProductGrouper.TryGetDotProductGroup(singleComponent);
                if (dotProduct != null)
                {
                    string value1 = Compile(dotProduct.Value1);
                    string value2 = Compile(dotProduct.Value2);
                    return($"dot({value1}, {value2})");
                }
            }
            else
            {
                IList<IList<HlslTreeNode>> componentGroups = _nodeGrouper.GroupComponents(components);
                if (componentGroups.Count > 1)
                {
                    return(CompileVectorConstructor(components, componentGroups));
                }

                var multiplication = _nodeGrouper.MatrixMultiplicationGrouper.TryGetMultiplicationGroup(components);
                if (multiplication != null)
                {
                    return(_matrixMultiplicationCompiler.Compile(multiplication));
                }

                var normalize = _nodeGrouper.NormalizeGrouper.TryGetContext(components);
                if (normalize != null)
                {
                    var vector = Compile(normalize);
                    return($"normalize({vector})");
                }
            }

            // No higher-level construct matched: compile by the kind of the first component.
            var first = components[0];

            if (first is ConstantNode)
            {
                return(CompileConstant(components, promoteToVectorSize));
            }

            if (first is Operation operation)
            {
                return(CompileOperation(operation, components, promoteToVectorSize));
            }

            if (first is IHasComponentIndex)
            {
                return(CompileNodesWithComponents(components, first, promoteToVectorSize));
            }

            throw new NotImplementedException();
        }
        /// <summary>
        /// Attempts to recognise <paramref name="node"/> as a four-component dot product formed by
        /// adding one more product term to a three-component dot product.
        /// </summary>
        /// <remarks>
        /// Accepted shapes (the addition may be in either order):
        ///   dot3(abc, xyz) + d*w
        ///   d*w + dot3(abc, xyz)
        /// The extra multiplication's first factor is appended to Value1 and its second factor to Value2.
        /// </remarks>
        /// <param name="node">Candidate expression root.</param>
        /// <param name="allowMatrixColumn">
        /// Forwarded to <c>TryGetDot3ProductGroup</c> and <c>CanGroupComponents</c>. When true and the
        /// newly joined pair (c, d) shares a matrix column or row, the other operand is accepted as-is.
        /// </param>
        /// <returns>
        /// A context with Value1 = {a, b, c, d} and Value2 = {x, y, z, w}, or null when the pattern does not match.
        /// </returns>
        private DotProductContext TryGetDot4ProductGroup(HlslTreeNode node, bool allowMatrixColumn)
        {
            var addition = node as AddOperation;
            if (addition == null)
            {
                return null;
            }

            // Search for the inner dot3 on the left first; the opposite addend must then be d*w.
            DotProductContext dot3 = TryGetDot3ProductGroup(addition.Addend1, allowMatrixColumn);
            HlslTreeNode tail;
            if (dot3 != null)
            {
                tail = addition.Addend2;
            }
            else
            {
                dot3 = TryGetDot3ProductGroup(addition.Addend2, allowMatrixColumn);
                tail = addition.Addend1;
            }

            var dw = tail as MultiplyOperation;
            if (dot3 == null || dw == null)
            {
                return null;
            }

            HlslTreeNode[] left = { dot3.Value1[0], dot3.Value1[1], dot3.Value1[2], dw.Factor1 };
            HlslTreeNode[] right = { dot3.Value2[0], dot3.Value2[1], dot3.Value2[2], dw.Factor2 };

            if (!CanGroupComponents(left[2], left[3], allowMatrixColumn))
            {
                return null;
            }

            // A matrix-backed (c, d) pair lets the other operand be arbitrary; otherwise (z, w) must group too.
            bool matrixOperand = allowMatrixColumn && SharesMatrixColumnOrRow(left[2], left[3]);
            return matrixOperand || CanGroupComponents(right[2], right[3], allowMatrixColumn)
                ? new DotProductContext(left, right)
                : null;
        }