Example #1
0
 /// <summary>
 /// Adds <paramref name="node"/> to this node's <c>Inputs</c>, records this node in
 /// <c>node.Outputs</c>, and then runs <c>AssertLoopFree</c> on the result.
 /// </summary>
 /// <param name="node">The node to connect as an input; must not be null.</param>
 /// <exception cref="System.ArgumentNullException"><paramref name="node"/> is null.</exception>
 protected void AddInput(HlslTreeNode node)
 {
     // Validate before mutating anything. Otherwise a null argument would be appended
     // to Inputs first, and only then fail on node.Outputs. That would leave a null input
     // entry behind the exception.
     if (node is null)
     {
         throw new System.ArgumentNullException(nameof(node));
     }

     Inputs.Add(node);
     node.Outputs.Add(this);
     AssertLoopFree();
 }
        /// <summary>
        /// Tries to match a four-component dot product written as a three-component
        /// dot product plus one extra product term. Either addend order is accepted:
        ///   dot3(abc, xyz) + d*w
        ///   d*w + dot3(abc, xyz)
        /// </summary>
        /// <param name="node">Candidate root; only an <c>AddOperation</c> can match.</param>
        /// <param name="allowMatrixColumn">Forwarded to the dot3 matcher and component-grouping checks.</param>
        /// <returns>A context pairing (a, b, c, d) with (x, y, z, w), or null when the pattern does not match.</returns>
        private DotProductContext TryGetDot4ProductGroup(HlslTreeNode node, bool allowMatrixColumn)
        {
            var addition = node as AddOperation;
            if (addition is null)
            {
                return null;
            }

            // Locate the dot3 part; the remaining addend must be the single d*w product.
            HlslTreeNode productTerm;
            DotProductContext dot3 = TryGetDot3ProductGroup(addition.Addend1, allowMatrixColumn);
            if (dot3 is null)
            {
                dot3 = TryGetDot3ProductGroup(addition.Addend2, allowMatrixColumn);
                if (dot3 is null)
                {
                    return null;
                }
                productTerm = addition.Addend1;
            }
            else
            {
                productTerm = addition.Addend2;
            }

            if (!(productTerm is MultiplyOperation product))
            {
                return null;
            }

            var left = dot3.Value1;
            var right = dot3.Value2;

            HlslTreeNode a = left[0], b = left[1], c = left[2];
            HlslTreeNode x = right[0], y = right[1], z = right[2];
            HlslTreeNode d = product.Factor1;
            HlslTreeNode w = product.Factor2;

            // d must be groupable with c (as decided by CanGroupComponents).
            if (!CanGroupComponents(c, d, allowMatrixColumn))
            {
                return null;
            }

            // If c and d share a matrix column or row, the other operand is not checked.
            // Otherwise w must also be groupable with z.
            bool sharesMatrix = allowMatrixColumn && SharesMatrixColumnOrRow(c, d);
            if (!sharesMatrix && !CanGroupComponents(z, w, allowMatrixColumn))
            {
                return null;
            }

            return new DotProductContext(new[] { a, b, c, d }, new[] { x, y, z, w });
        }
 /// <summary>
 /// Asks the node grouper whether <paramref name="a"/> and <paramref name="b"/>
 /// can be grouped as components, forwarding <paramref name="allowMatrixColumn"/>.
 /// </summary>
 private bool CanGroupComponents(HlslTreeNode a, HlslTreeNode b, bool allowMatrixColumn)
     => _nodeGrouper.CanGroupComponents(a, b, allowMatrixColumn);
 /// <summary>
 /// Returns true only when both nodes are <c>RegisterInputNode</c>s and the node
 /// grouper reports that they share a matrix column or row.
 /// </summary>
 private bool SharesMatrixColumnOrRow(HlslTreeNode a, HlslTreeNode b)
 {
     // Any operand that is not a register input is treated as "not shared".
     if (!(a is RegisterInputNode first) || !(b is RegisterInputNode second))
     {
         return false;
     }

     return _nodeGrouper.SharesMatrixColumnOrRow(first, second);
 }
        /// <summary>
        /// Tries to recognize <paramref name="components"/> as a matrix-vector multiplication.
        /// Each of the first <c>dimension</c> components must be a dot product of a matrix row
        /// with the same (equivalent) vector operand. The rows must form a matrix constant
        /// that <c>TryGetMatrixDeclaration</c> recognizes.
        /// </summary>
        /// <param name="components">
        /// Candidate result components. The first one determines the dot-product dimension.
        /// Components beyond that dimension are not examined.
        /// </param>
        /// <returns>
        /// The recognized multiplication, or null when <paramref name="components"/> is null or
        /// empty, or when the pattern does not match.
        /// </returns>
        public MatrixMultiplicationContext TryGetMultiplicationGroup(IList<HlslTreeNode> components)
        {
            const bool allowMatrix = true;

            // A Try-style method reports "no match" for missing input instead of throwing
            // when components[0] is read below.
            if (components == null || components.Count == 0)
            {
                return null;
            }

            var firstDotProductNode = _nodeGrouper.DotProductGrouper.TryGetDotProductGroup(components[0], allowMatrix);
            if (firstDotProductNode == null)
            {
                return null;
            }

            // One component is required per matrix row.
            int dimension = firstDotProductNode.Dimension;
            if (components.Count < dimension)
            {
                return null;
            }

            HlslTreeNode[] firstMatrixRow = TryGetMatrixRow(firstDotProductNode);
            if (firstMatrixRow == null)
            {
                return null;
            }

            // The operand that is not the matrix row is the vector. This is a reference
            // comparison; it presumably relies on TryGetMatrixRow returning one of the
            // operand arrays itself (TODO confirm).
            HlslTreeNode[] vector = firstDotProductNode.Value1 == firstMatrixRow
                ? firstDotProductNode.Value2
                : firstDotProductNode.Value1;

            var matrixRows = new HlslTreeNode[dimension][];
            matrixRows[0] = firstMatrixRow;

            for (int i = 1; i < dimension; i++)
            {
                var dotProductNode = _nodeGrouper.DotProductGrouper.TryGetDotProductGroup(components[i], dimension, allowMatrix);
                if (dotProductNode == null)
                {
                    return null;
                }

                HlslTreeNode[] matrixRow = TryGetMatrixRow(dotProductNode);
                if (matrixRow == null)
                {
                    return null;
                }
                matrixRows[i] = matrixRow;

                // Every row must be multiplied by an equivalent vector.
                HlslTreeNode[] nextVector = dotProductNode.Value1 == matrixRow
                    ? dotProductNode.Value2
                    : dotProductNode.Value1;
                if (!NodeGrouper.IsVectorEquivalent(vector, nextVector))
                {
                    return null;
                }
            }

            ConstantDeclaration matrix = TryGetMatrixDeclaration(matrixRows);
            if (matrix == null)
            {
                return null;
            }

            // True when every element of the first matrix row has ComponentIndex 0.
            // Cast<> throws if an element is not a RegisterInputNode; presumably
            // TryGetMatrixRow guarantees it (verify).
            bool matrixByVector = firstMatrixRow
                .Cast<RegisterInputNode>()
                .All(row => row.ComponentIndex == 0);

            SwizzleVector(vector, firstMatrixRow, matrixByVector);

            return new MatrixMultiplicationContext(vector, matrix, matrixByVector);
        }