/// <summary>
/// Attaches <paramref name="node"/> as an input of this node.
/// </summary>
/// <remarks>
/// The link is recorded on both sides. <paramref name="node"/> is appended to this node's
/// <c>Inputs</c>, and this node is appended to <c>node.Outputs</c>. <c>AssertLoopFree</c> is
/// invoked afterwards. Presumably it verifies that the new link did not introduce a cycle; confirm.
/// </remarks>
/// <param name="node">The node that becomes an input of this node.</param>
protected void AddInput(HlslTreeNode node)
{
    // This node consumes `node`...
    this.Inputs.Add(node);

    // ...and `node` now feeds this node.
    node.Outputs.Add(this);

    AssertLoopFree();
}
/// <summary>
/// Tries to recognize <paramref name="node"/> as a four-component dot product.
/// </summary>
/// <remarks>
/// A 4 by 4 dot product has one of these shapes:
/// <code>
/// #1  dot3(abc, xyz) + dw
/// #2  dw + dot3(abc, xyz)
/// </code>
/// The left addend is tried as the dot3 group first, then the right one. The other addend
/// must be a single <c>MultiplyOperation</c> whose factors supply d and w.
/// </remarks>
/// <param name="node">Candidate root of the dot product.</param>
/// <param name="allowMatrixColumn">
/// Forwarded to the dot3 matcher and to the component-grouping checks. When true, a c/d pair
/// that shares a matrix column or row is accepted without checking z/w.
/// </param>
/// <returns>
/// A context whose <c>Value1</c> is {a, b, c, d} and <c>Value2</c> is {x, y, z, w}, or
/// <c>null</c> when the pattern does not match.
/// </returns>
private DotProductContext TryGetDot4ProductGroup(HlslTreeNode node, bool allowMatrixColumn)
{
    if (!(node is AddOperation addition))
    {
        return null;
    }

    // Find which addend is the dot3 group; the remaining addend must be the d*w product.
    DotProductContext dot3 = TryGetDot3ProductGroup(addition.Addend1, allowMatrixColumn);
    HlslTreeNode trailingTerm = addition.Addend2;
    if (dot3 == null)
    {
        dot3 = TryGetDot3ProductGroup(addition.Addend2, allowMatrixColumn);
        if (dot3 == null)
        {
            return null;
        }

        trailingTerm = addition.Addend1;
    }

    if (!(trailingTerm is MultiplyOperation dw))
    {
        return null;
    }

    HlslTreeNode[] left = { dot3.Value1[0], dot3.Value1[1], dot3.Value1[2], dw.Factor1 };
    HlslTreeNode[] right = { dot3.Value2[0], dot3.Value2[1], dot3.Value2[2], dw.Factor2 };

    // The fourth component (d) must be groupable with the third (c).
    // If one of the arguments is a matrix (c and d share a matrix column or row), the other
    // argument may be arbitrary. Otherwise z and w must be groupable as well.
    bool accepted = CanGroupComponents(left[2], left[3], allowMatrixColumn)
        && ((allowMatrixColumn && SharesMatrixColumnOrRow(left[2], left[3]))
            || CanGroupComponents(right[2], right[3], allowMatrixColumn));

    return accepted ? new DotProductContext(left, right) : null;
}
/// <summary>
/// Asks <c>_nodeGrouper</c> whether <paramref name="a"/> and <paramref name="b"/> can be
/// grouped as components, honoring <paramref name="allowMatrixColumn"/>.
/// </summary>
private bool CanGroupComponents(HlslTreeNode a, HlslTreeNode b, bool allowMatrixColumn) =>
    _nodeGrouper.CanGroupComponents(a, b, allowMatrixColumn);
/// <summary>
/// Reports whether <paramref name="a"/> and <paramref name="b"/> share a matrix column or row.
/// </summary>
/// <remarks>
/// The check is delegated to <c>_nodeGrouper</c> only when both nodes are
/// <c>RegisterInputNode</c>s. Any other combination yields <c>false</c>.
/// </remarks>
private bool SharesMatrixColumnOrRow(HlslTreeNode a, HlslTreeNode b)
{
    if (a is RegisterInputNode first && b is RegisterInputNode second)
    {
        return _nodeGrouper.SharesMatrixColumnOrRow(first, second);
    }

    return false;
}
/// <summary>
/// Tries to recognize <paramref name="components"/> as the per-component dot products of a
/// multiplication between one vector and a declared constant matrix.
/// </summary>
/// <remarks>
/// The steps are:
/// 1. The first component fixes the dot-product dimension N.
/// 2. Each of the first N components must be a dot product of dimension N that contains a
///    matrix row (per <c>TryGetMatrixRow</c>).
/// 3. The non-row operands of all N dot products must be equivalent
///    (<c>NodeGrouper.IsVectorEquivalent</c>).
/// 4. The collected rows must resolve to a <c>ConstantDeclaration</c> via
///    <c>TryGetMatrixDeclaration</c>.
/// Components beyond the first N are not inspected.
/// </remarks>
/// <param name="components">
/// Candidate output components. A <c>null</c> or empty list yields <c>null</c>.
/// </param>
/// <returns>
/// The recognized multiplication, or <c>null</c> when the pattern does not match.
/// </returns>
public MatrixMultiplicationContext TryGetMultiplicationGroup(IList<HlslTreeNode> components)
{
    // A Try* method reports "no match" rather than throwing when there is nothing to
    // inspect. Without this guard, components[0] throws on a null or empty list.
    if (components == null || components.Count == 0)
    {
        return null;
    }

    const bool allowMatrix = true;

    var firstDotProductNode = _nodeGrouper.DotProductGrouper.TryGetDotProductGroup(components[0], allowMatrix);
    if (firstDotProductNode == null)
    {
        return null;
    }

    // One component per matrix row is required.
    int dimension = firstDotProductNode.Dimension;
    if (components.Count < dimension)
    {
        return null;
    }

    HlslTreeNode[] firstMatrixRow = TryGetMatrixRow(firstDotProductNode);
    if (firstMatrixRow == null)
    {
        return null;
    }

    // The vector is whichever operand is not the matrix row (compared by reference).
    HlslTreeNode[] vector = firstDotProductNode.Value1 == firstMatrixRow
        ? firstDotProductNode.Value2
        : firstDotProductNode.Value1;

    var matrixRows = new HlslTreeNode[dimension][];
    matrixRows[0] = firstMatrixRow;
    for (int i = 1; i < dimension; i++)
    {
        var dotProductNode = _nodeGrouper.DotProductGrouper.TryGetDotProductGroup(components[i], dimension, allowMatrix);
        if (dotProductNode == null)
        {
            return null;
        }

        HlslTreeNode[] matrixRow = TryGetMatrixRow(dotProductNode);
        if (matrixRow == null)
        {
            return null;
        }

        matrixRows[i] = matrixRow;

        // Every row must be multiplied by the same vector as the first row.
        HlslTreeNode[] nextVector = dotProductNode.Value1 == matrixRow
            ? dotProductNode.Value2
            : dotProductNode.Value1;
        if (!NodeGrouper.IsVectorEquivalent(vector, nextVector))
        {
            return null;
        }
    }

    ConstantDeclaration matrix = TryGetMatrixDeclaration(matrixRows);
    if (matrix == null)
    {
        return null;
    }

    // matrixByVector is true when every node of the first matrix row has ComponentIndex 0.
    // NOTE(review): Cast<> throws if the row holds a node that is not a RegisterInputNode.
    // Presumably TryGetMatrixDeclaration has already ruled that out; confirm.
    bool matrixByVector = firstMatrixRow
        .Cast<RegisterInputNode>()
        .All(component => component.ComponentIndex == 0);

    SwizzleVector(vector, firstMatrixRow, matrixByVector);
    return new MatrixMultiplicationContext(vector, matrix, matrixByVector);
}