예제 #1
0
 /// <summary>
 /// Constructs a new math intrinsic attribute.
 /// </summary>
 /// <param name="kind">The value assigned to <c>IntrinsicKind</c>.</param>
 /// <param name="flags">The value assigned to <c>IntrinsicFlags</c>.</param>
 public MathIntrinsicAttribute(MathIntrinsicKind kind, ArithmeticFlags flags)
 {
     this.IntrinsicKind = kind;
     this.IntrinsicFlags = flags;
 }
예제 #2
0
 /// <summary>
 /// Constructs a new math intrinsic attribute.
 /// </summary>
 /// <param name="intrinsicKind">
 /// The value assigned to <c>IntrinsicKind</c>.
 /// </param>
 /// <param name="intrinsicFlags">
 /// The value assigned to <c>IntrinsicFlags</c>.
 /// </param>
 public MathIntrinsicAttribute(
     MathIntrinsicKind intrinsicKind,
     ArithmeticFlags intrinsicFlags)
 {
     this.IntrinsicKind = intrinsicKind;
     this.IntrinsicFlags = intrinsicFlags;
 }
예제 #3
0
        /// <summary>
        /// Builds a unary arithmetic node for the given operand.
        /// </summary>
        /// <remarks>
        /// When <c>UseConstantPropagation</c> is enabled, primitive operands are
        /// folded, a double <c>Not</c> collapses to its inner operand, a negated
        /// comparison is replaced by the inverted comparison, <c>Neg</c> on an
        /// <c>Int1</c> operand becomes <c>Not</c>, and <c>Abs</c> with the
        /// unsigned flag returns the operand itself.
        /// </remarks>
        /// <param name="node">The operand.</param>
        /// <param name="kind">The operation kind.</param>
        /// <param name="flags">Operation flags.</param>
        /// <returns>A node that represents the arithmetic operation.</returns>
        public ValueReference CreateArithmetic(
            Value node,
            UnaryArithmeticKind kind,
            ArithmeticFlags flags)
        {
            Debug.Assert(node != null, "Invalid node");

            if (UseConstantPropagation)
            {
                // Constant operands are folded right away
                if (node is PrimitiveValue constant)
                {
                    return UnaryArithmeticFoldConstants(constant, kind);
                }

                if (kind == UnaryArithmeticKind.Not)
                {
                    // not(not(x)) => x
                    if (node is UnaryArithmeticValue innerNot &&
                        innerNot.Kind == UnaryArithmeticKind.Not)
                    {
                        return innerNot.Value;
                    }

                    // not(compare) => inverted compare
                    if (node is CompareValue comparison)
                    {
                        var invertedKind = CompareValue.Invert(comparison.Kind);
                        return CreateCompare(
                            comparison.Left,
                            comparison.Right,
                            invertedKind,
                            comparison.Flags);
                    }
                }
                else if (kind == UnaryArithmeticKind.Neg)
                {
                    // Negating an Int1 operand is expressed as a not operation
                    if (node.BasicValueType == BasicValueType.Int1)
                    {
                        return CreateArithmetic(node, UnaryArithmeticKind.Not);
                    }
                }
                else if (kind == UnaryArithmeticKind.Abs)
                {
                    // abs is the identity when the unsigned flag is set
                    if ((flags & ArithmeticFlags.Unsigned) == ArithmeticFlags.Unsigned)
                    {
                        return node;
                    }
                }
            }

            var operation = new UnaryArithmeticValue(
                Context,
                BasicBlock,
                node,
                kind,
                flags);
            return Append(operation);
        }
예제 #4
0
        /// <summary>
        /// Initializes the shared state of an arithmetic value.
        /// </summary>
        /// <param name="basicBlock">
        /// The parent basic block (forwarded to the base constructor).
        /// </param>
        /// <param name="operands">The operands the value is sealed with.</param>
        /// <param name="flags">The operation flags stored in <c>Flags</c>.</param>
        /// <param name="initialType">
        /// The initial node type (forwarded to the base constructor).
        /// </param>
        internal ArithmeticValue(
            BasicBlock basicBlock,
            ImmutableArray<ValueReference> operands,
            ArithmeticFlags flags,
            TypeNode initialType)
            : base(basicBlock, initialType)
        {
            // Store the flags before sealing the node with its operands
            this.Flags = flags;
            Seal(operands);
        }
예제 #5
0
 /// <summary>
 /// Constructs a new unary arithmetic value with a single operand.
 /// </summary>
 /// <param name="context">
 /// The parent IR context (passed to <c>ComputeType</c>).
 /// </param>
 /// <param name="basicBlock">The parent basic block.</param>
 /// <param name="value">The only operand of the operation.</param>
 /// <param name="kind">The operation kind stored in <c>Kind</c>.</param>
 /// <param name="flags">The operation flags.</param>
 internal UnaryArithmeticValue(
     IRContext context,
     BasicBlock basicBlock,
     ValueReference value,
     UnaryArithmeticKind kind,
     ArithmeticFlags flags)
     : base(
         basicBlock,
         // Wrap the single operand into the operand array
         ImmutableArray.Create(value),
         flags,
         // The node type is obtained from the context, operand and kind
         ComputeType(context, value, kind))
 {
     this.Kind = kind;
 }
예제 #6
0
        /// <summary>
        /// Constructs a new ternary arithmetic value from three operands.
        /// </summary>
        /// <remarks>
        /// The node type is computed from <paramref name="first"/>; debug builds
        /// assert that all three operands have the same type.
        /// </remarks>
        /// <param name="basicBlock">The parent basic block.</param>
        /// <param name="first">The first operand.</param>
        /// <param name="second">The second operand.</param>
        /// <param name="third">The third operand.</param>
        /// <param name="kind">The operation kind stored in <c>Kind</c>.</param>
        /// <param name="flags">The operation flags.</param>
        internal TernaryArithmeticValue(
            BasicBlock basicBlock,
            ValueReference first,
            ValueReference second,
            ValueReference third,
            TernaryArithmeticKind kind,
            ArithmeticFlags flags)
            : base(
                basicBlock,
                ImmutableArray.Create(first, second, third),
                flags,
                ComputeType(first))
        {
            // All operands have to agree on a single type
            Debug.Assert(
                (first.Type == second.Type) && (second.Type == third.Type),
                "Invalid types");

            this.Kind = kind;
        }
예제 #7
0
        /// <summary>
        /// Constructs a new binary arithmetic value from two operands.
        /// </summary>
        /// <remarks>
        /// The node type is computed from <paramref name="left"/>. Debug builds
        /// assert that both operands share one type, or that the operation is a
        /// shift (<c>Shl</c>/<c>Shr</c>) whose right operand is an <c>Int32</c>.
        /// </remarks>
        /// <param name="basicBlock">The parent basic block.</param>
        /// <param name="left">The left operand.</param>
        /// <param name="right">The right operand.</param>
        /// <param name="kind">The operation kind stored in <c>Kind</c>.</param>
        /// <param name="flags">The operation flags.</param>
        internal BinaryArithmeticValue(
            BasicBlock basicBlock,
            ValueReference left,
            ValueReference right,
            BinaryArithmeticKind kind,
            ArithmeticFlags flags)
            : base(
                basicBlock,
                ImmutableArray.Create(left, right),
                flags,
                ComputeType(left))
        {
            // Either matching operand types, or a shift with an Int32 amount
            Debug.Assert(
                left.Type == right.Type ||
                ((kind == BinaryArithmeticKind.Shl ||
                  kind == BinaryArithmeticKind.Shr) &&
                 right.BasicValueType == BasicValueType.Int32),
                "Invalid types");

            this.Kind = kind;
        }
예제 #8
0
        /// <summary>
        /// Creates a ternary arithmetic operation.
        /// </summary>
        /// <remarks>
        /// When <c>UseConstantPropagation</c> is enabled and the first two
        /// operands are primitive values, the left binary part of the operation
        /// is folded and the result is combined with <paramref name="third"/>
        /// using the right binary kind of the ternary operation.
        /// </remarks>
        /// <param name="location">The current location.</param>
        /// <param name="first">The first operand.</param>
        /// <param name="second">The second operand.</param>
        /// <param name="third">The third operand.</param>
        /// <param name="kind">The operation kind.</param>
        /// <param name="flags">Operation flags.</param>
        /// <returns>A node that represents the arithmetic operation.</returns>
        public ValueReference CreateArithmetic(
            Location location,
            Value first,
            Value second,
            Value third,
            TernaryArithmeticKind kind,
            ArithmeticFlags flags)
        {
            // Check for constants
            if (UseConstantPropagation &&
                first is PrimitiveValue firstValue &&
                second is PrimitiveValue secondValue)
            {
                var value = BinaryArithmeticFoldConstants(
                    location,
                    firstValue,
                    secondValue,
                    TernaryArithmeticValue.GetLeftBinaryKind(kind),
                    flags);

                // Try to fold right hand side as well. The flags are forwarded
                // so that the decomposed operation carries the same flags as the
                // ternary node that would have been created otherwise.
                var rightOperation = TernaryArithmeticValue.GetRightBinaryKind(kind);
                return CreateArithmetic(
                    location,
                    value,
                    third,
                    rightOperation,
                    flags);
            }

            return Append(new TernaryArithmeticValue(
                GetInitializer(location),
                first,
                second,
                third,
                kind,
                flags));
        }
예제 #9
0
        /// <summary>
        /// Creates a ternary arithmetic operation.
        /// </summary>
        /// <remarks>
        /// When <c>UseConstantPropagation</c> is enabled and the first two
        /// operands are primitive values, the left binary part of the operation
        /// is folded and the result is combined with <paramref name="third"/>
        /// using the right binary kind of the ternary operation.
        /// </remarks>
        /// <param name="first">The first operand.</param>
        /// <param name="second">The second operand.</param>
        /// <param name="third">The third operand.</param>
        /// <param name="kind">The operation kind.</param>
        /// <param name="flags">Operation flags.</param>
        /// <returns>A node that represents the arithmetic operation.</returns>
        public ValueReference CreateArithmetic(
            Value first,
            Value second,
            Value third,
            TernaryArithmeticKind kind,
            ArithmeticFlags flags)
        {
            Debug.Assert(first != null, "Invalid first node");
            Debug.Assert(second != null, "Invalid second node");
            Debug.Assert(third != null, "Invalid third node");

            // Check for constants
            if (UseConstantPropagation &&
                first is PrimitiveValue firstValue &&
                second is PrimitiveValue secondValue)
            {
                var value = BinaryArithmeticFoldConstants(
                    firstValue,
                    secondValue,
                    TernaryArithmeticValue.GetLeftBinaryKind(kind),
                    flags);

                // Try to fold right hand side as well. The flags are forwarded
                // so that the decomposed operation carries the same flags as the
                // ternary node that would have been created otherwise.
                var rightOperation = TernaryArithmeticValue.GetRightBinaryKind(kind);
                return CreateArithmetic(value, third, rightOperation, flags);
            }

            return Append(new TernaryArithmeticValue(
                BasicBlock,
                first,
                second,
                third,
                kind,
                flags));
        }
예제 #10
0
파일: Arithmetic.cs 프로젝트: m4rs-mt/ILGPU
        /// <summary>
        /// Builds a unary arithmetic node for the given operand.
        /// </summary>
        /// <remarks>
        /// Primitive operands are folded into a constant. Otherwise the result
        /// of <c>UnaryArithmeticSimplify</c> is returned when it is not null; a
        /// new <c>UnaryArithmeticValue</c> is appended only as a last resort.
        /// </remarks>
        /// <param name="location">The current location.</param>
        /// <param name="node">The operand.</param>
        /// <param name="kind">The operation kind.</param>
        /// <param name="flags">Operation flags.</param>
        /// <returns>A node that represents the arithmetic operation.</returns>
        public ValueReference CreateArithmetic(
            Location location,
            Value node,
            UnaryArithmeticKind kind,
            ArithmeticFlags flags)
        {
            // Constant operand: fold it
            if (node is PrimitiveValue constant)
            {
                return UnaryArithmeticFoldConstants(location, constant, kind);
            }

            // Non-constant operand: prefer a simplified form if there is one
            if (UnaryArithmeticSimplify(location, node, kind, flags)
                is Value simplified)
            {
                return simplified;
            }

            var operation = new UnaryArithmeticValue(
                GetInitializer(location),
                node,
                kind,
                flags);
            return Append(operation);
        }
예제 #11
0
파일: Arithmetic.cs 프로젝트: m4rs-mt/ILGPU
        /// <summary>
        /// Builds a binary arithmetic node for the given operands.
        /// </summary>
        /// <remarks>
        /// The operands are verified first. Afterwards the method tries, in this
        /// order: folding two primitive operands, simplifying against a
        /// primitive right operand (directly and through a nested binary
        /// operation of the same kind), swapping a primitive left operand of a
        /// commutative operation to the right, and simplifying against a
        /// primitive left operand (directly and nested). If nothing applies, a
        /// new <c>BinaryArithmeticValue</c> is appended.
        /// </remarks>
        /// <param name="location">The current location.</param>
        /// <param name="left">The left operand.</param>
        /// <param name="right">The right operand.</param>
        /// <param name="kind">The operation kind.</param>
        /// <param name="flags">Operation flags.</param>
        /// <returns>A node that represents the arithmetic operation.</returns>
        public ValueReference CreateArithmetic(
            Location location,
            Value left,
            Value right,
            BinaryArithmeticKind kind,
            ArithmeticFlags flags)
        {
            VerifyBinaryArithmeticOperands(location, left, right, kind);

            if (right is PrimitiveValue rightConstant)
            {
                // Both operands are constants: fold them
                if (left is PrimitiveValue foldableLeft)
                {
                    return BinaryArithmeticFoldConstants(
                        location,
                        foldableLeft,
                        rightConstant,
                        kind,
                        flags);
                }

                // Simplify against the constant on the right
                Value rhsResult = BinaryArithmeticSimplify_RHS(
                    location,
                    left,
                    rightConstant,
                    kind,
                    flags);
                if (rhsResult != null)
                {
                    return rhsResult;
                }

                // Simplify (x op c1) op c2 for nested operations of the same kind
                if (left is BinaryArithmeticValue nestedLeft &&
                    nestedLeft.Kind == kind &&
                    nestedLeft.Right.Resolve() is PrimitiveValue innerRightConstant)
                {
                    Value nestedRhsResult = BinaryArithmeticSimplify_RHS(
                        location,
                        nestedLeft,
                        innerRightConstant,
                        rightConstant,
                        kind,
                        flags);
                    if (nestedRhsResult != null)
                    {
                        return nestedRhsResult;
                    }
                }
            }

            if (left is PrimitiveValue leftConstant)
            {
                // Commutative operations keep their constants on the right
                if (kind.IsCommutative())
                {
                    return CreateArithmetic(location, right, left, kind, flags);
                }

                // Simplify against the constant on the left
                Value lhsResult = BinaryArithmeticSimplify_LHS(
                    location,
                    leftConstant,
                    right,
                    kind,
                    flags);
                if (lhsResult != null)
                {
                    return lhsResult;
                }

                // Simplify c2 op (c1 op x) for nested operations of the same kind
                if (right is BinaryArithmeticValue nestedRight &&
                    nestedRight.Kind == kind &&
                    nestedRight.Left.Resolve() is PrimitiveValue innerLeftConstant)
                {
                    Value nestedLhsResult = BinaryArithmeticSimplify_LHS(
                        location,
                        nestedRight,
                        innerLeftConstant,
                        leftConstant,
                        kind,
                        flags);
                    if (nestedLhsResult != null)
                    {
                        return nestedLhsResult;
                    }
                }
            }

            var operation = new BinaryArithmeticValue(
                GetInitializer(location),
                left,
                right,
                kind,
                flags);
            return Append(operation);
        }
예제 #12
0
        /// <summary>
        /// Builds a binary arithmetic node for the given operands.
        /// </summary>
        /// <remarks>
        /// When <c>UseConstantPropagation</c> is enabled and the left operand is
        /// a primitive value, two primitive operands are folded and a division
        /// whose left operand is the float constant one is turned into an
        /// <c>RcpF</c> operation on the right operand.
        /// </remarks>
        /// <param name="left">The left operand.</param>
        /// <param name="right">The right operand.</param>
        /// <param name="kind">The operation kind.</param>
        /// <param name="flags">Operation flags.</param>
        /// <returns>A node that represents the arithmetic operation.</returns>
        /// <exception cref="NotSupportedException">
        /// <c>And</c>, <c>Or</c> or <c>Xor</c> is applied to a float left
        /// operand, or <c>Atan2F</c> or <c>PowF</c> is applied to a non-float
        /// left operand.
        /// </exception>
        public ValueReference CreateArithmetic(
            Value left,
            Value right,
            BinaryArithmeticKind kind,
            ArithmeticFlags flags)
        {
            Debug.Assert(left != null, "Invalid left node");
            Debug.Assert(right != null, "Invalid right node");

            if (UseConstantPropagation && left is PrimitiveValue leftConstant)
            {
                // Two constants can be folded directly
                if (right is PrimitiveValue rightConstant)
                {
                    return BinaryArithmeticFoldConstants(
                        leftConstant,
                        rightConstant,
                        kind,
                        flags);
                }

                // 1.0 / x is expressed as a reciprocal operation
                if (kind == BinaryArithmeticKind.Div)
                {
                    var leftType = left.BasicValueType;
                    var isFloatOne =
                        (leftType == BasicValueType.Float32 &&
                         leftConstant.Float32Value == 1.0f) ||
                        (leftType == BasicValueType.Float64 &&
                         leftConstant.Float64Value == 1.0);
                    if (isFloatOne)
                    {
                        return CreateArithmetic(right, UnaryArithmeticKind.RcpF);
                    }
                }
            }

            // Bitwise operations reject float operands, whereas the float-only
            // operations reject everything else
            var isBitwise =
                kind == BinaryArithmeticKind.And ||
                kind == BinaryArithmeticKind.Or ||
                kind == BinaryArithmeticKind.Xor;
            var isFloatOnly =
                kind == BinaryArithmeticKind.Atan2F ||
                kind == BinaryArithmeticKind.PowF;
            if ((isBitwise && left.BasicValueType.IsFloat()) ||
                (isFloatOnly && !left.BasicValueType.IsFloat()))
            {
                throw new NotSupportedException(string.Format(
                    ErrorMessages.NotSupportedArithmeticArgumentType,
                    left.BasicValueType));
            }

            var operation = new BinaryArithmeticValue(
                BasicBlock,
                left,
                right,
                kind,
                flags);
            return Append(operation);
        }
예제 #13
0
        /// <summary>
        /// Builds a unary arithmetic node for the given operand.
        /// </summary>
        /// <remarks>
        /// When <c>UseConstantPropagation</c> is enabled, primitive operands are
        /// folded, a double <c>Not</c> collapses to its inner operand, a negated
        /// comparison is replaced by the inverted comparison, <c>Neg</c> on an
        /// <c>Int1</c> operand becomes <c>Not</c>, and <c>Abs</c> with the
        /// unsigned flag returns the operand itself.
        /// </remarks>
        /// <param name="location">The current location.</param>
        /// <param name="node">The operand.</param>
        /// <param name="kind">The operation kind.</param>
        /// <param name="flags">Operation flags.</param>
        /// <returns>A node that represents the arithmetic operation.</returns>
        public ValueReference CreateArithmetic(
            Location location,
            Value node,
            UnaryArithmeticKind kind,
            ArithmeticFlags flags)
        {
            if (UseConstantPropagation)
            {
                // Constant operands are folded right away
                if (node is PrimitiveValue constant)
                {
                    return UnaryArithmeticFoldConstants(location, constant, kind);
                }

                if (kind == UnaryArithmeticKind.Not)
                {
                    // not(not(x)) => x
                    if (node is UnaryArithmeticValue innerNot &&
                        innerNot.Kind == UnaryArithmeticKind.Not)
                    {
                        return innerNot.Value;
                    }

                    // not(compare) => inverted compare
                    if (node is CompareValue comparison)
                    {
                        // Inverting a float comparison also toggles between the
                        // ordered and the unordered variant
                        var invertedFlags = comparison.Flags;
                        var comparesFloats =
                            comparison.Left.BasicValueType.IsFloat() &&
                            comparison.Right.BasicValueType.IsFloat();
                        if (comparesFloats)
                        {
                            invertedFlags ^= CompareFlags.UnsignedOrUnordered;
                        }

                        return CreateCompare(
                            location,
                            comparison.Left,
                            comparison.Right,
                            CompareValue.Invert(comparison.Kind),
                            invertedFlags);
                    }
                }
                else if (kind == UnaryArithmeticKind.Neg)
                {
                    // Negating an Int1 operand is expressed as a not operation
                    if (node.BasicValueType == BasicValueType.Int1)
                    {
                        return CreateArithmetic(
                            location,
                            node,
                            UnaryArithmeticKind.Not);
                    }
                }
                else if (kind == UnaryArithmeticKind.Abs)
                {
                    // abs is the identity when the unsigned flag is set
                    if ((flags & ArithmeticFlags.Unsigned) == ArithmeticFlags.Unsigned)
                    {
                        return node;
                    }
                }
            }

            var operation = new UnaryArithmeticValue(
                GetInitializer(location),
                node,
                kind,
                flags);
            return Append(operation);
        }
예제 #14
0
        /// <summary>
        /// Creates a binary arithmetic operation.
        /// </summary>
        /// <remarks>
        /// When <c>UseConstantPropagation</c> is enabled and the left operand is
        /// a primitive value, two primitive operands are folded and a division
        /// whose left operand is the float constant one is turned into an
        /// <c>RcpF</c> operation. Independently of that setting, integer
        /// multiplications by a positive power of two become <c>Shl</c> and
        /// unsigned integer divisions by a positive power of two become
        /// <c>Shr</c>.
        /// </remarks>
        /// <param name="location">The current location.</param>
        /// <param name="left">The left operand.</param>
        /// <param name="right">The right operand.</param>
        /// <param name="kind">The operation kind.</param>
        /// <param name="flags">Operation flags.</param>
        /// <returns>A node that represents the arithmetic operation.</returns>
        public ValueReference CreateArithmetic(
            Location location,
            Value left,
            Value right,
            BinaryArithmeticKind kind,
            ArithmeticFlags flags)
        {
            // TODO: add additional partial arithmetic simplifications in a generic way
            if (UseConstantPropagation && left is PrimitiveValue leftValue)
            {
                // Check for constants
                if (right is PrimitiveValue rightConstant)
                {
                    return BinaryArithmeticFoldConstants(
                        location,
                        leftValue,
                        rightConstant,
                        kind,
                        flags);
                }

                if (kind == BinaryArithmeticKind.Div)
                {
                    switch (left.BasicValueType)
                    {
                        case BasicValueType.Float32:
                            if (leftValue.Float32Value == 1.0f)
                            {
                                return CreateArithmetic(
                                    location,
                                    right,
                                    UnaryArithmeticKind.RcpF);
                            }
                            break;

                        case BasicValueType.Float64:
                            if (leftValue.Float64Value == 1.0)
                            {
                                return CreateArithmetic(
                                    location,
                                    right,
                                    UnaryArithmeticKind.RcpF);
                            }
                            break;

                        default:
                            break;
                    }
                }
            }

            // TODO: remove the following hard-coded rules
            // Only positive power-of-two constants are rewritten into shifts:
            // a shift cannot express the sign change caused by a negative
            // multiplier or divisor.
            if (right is PrimitiveValue rightValue &&
                left.BasicValueType.IsInt() &&
                rightValue.RawValue > 0 &&
                Utilities.IsPowerOf2(rightValue.RawValue))
            {
                var isUnsigned = (flags & ArithmeticFlags.Unsigned) ==
                    ArithmeticFlags.Unsigned;

                // x * 2^k == x << k holds for signed and unsigned operands.
                var useShiftLeft = kind == BinaryArithmeticKind.Mul;

                // x / 2^k == x >> k holds for unsigned operands only: a signed
                // division truncates toward zero, whereas a right shift of a
                // negative value rounds toward negative infinity.
                var useShiftRight = kind == BinaryArithmeticKind.Div && isUnsigned;

                if (useShiftLeft || useShiftRight)
                {
                    // Count the shift amount with integer arithmetic instead of
                    // truncating a floating-point logarithm
                    var shift = 0;
                    for (var remaining = rightValue.RawValue;
                        remaining > 1;
                        remaining >>= 1)
                    {
                        ++shift;
                    }

                    // Forward the flags so that an unsigned operation stays
                    // unsigned after the rewrite
                    return CreateArithmetic(
                        location,
                        left,
                        CreatePrimitiveValue(location, shift),
                        useShiftLeft
                            ? BinaryArithmeticKind.Shl
                            : BinaryArithmeticKind.Shr,
                        flags);
                }
            }

            switch (kind)
            {
                case BinaryArithmeticKind.And:
                case BinaryArithmeticKind.Or:
                case BinaryArithmeticKind.Xor:
                    // Bitwise operations are rejected for float operands
                    if (left.BasicValueType.IsFloat())
                    {
                        throw location.GetNotSupportedException(
                            ErrorMessages.NotSupportedArithmeticArgumentType,
                            left.BasicValueType);
                    }

                    break;

                case BinaryArithmeticKind.Atan2F:
                case BinaryArithmeticKind.PowF:
                    // Float-only operations are rejected for other operands
                    if (!left.BasicValueType.IsFloat())
                    {
                        throw location.GetNotSupportedException(
                            ErrorMessages.NotSupportedArithmeticArgumentType,
                            left.BasicValueType);
                    }

                    break;
            }

            return Append(new BinaryArithmeticValue(
                GetInitializer(location),
                left,
                right,
                kind,
                flags));
        }
예제 #15
0
        /// <summary>
        /// Builds a unary arithmetic node for the given operand.
        /// </summary>
        /// <remarks>
        /// When <c>UseConstantPropagation</c> is enabled, primitive operands are
        /// folded, a double <c>Not</c> collapses to its inner operand, a negated
        /// comparison is replaced by the inverted comparison, a <c>Not</c> on a
        /// binary operation accepted by <c>TryInvertLogical</c> is pushed down
        /// to both operands, <c>Neg</c> on an <c>Int1</c> operand becomes
        /// <c>Not</c>, and <c>Abs</c> with the unsigned flag returns the operand
        /// itself.
        /// </remarks>
        /// <param name="location">The current location.</param>
        /// <param name="node">The operand.</param>
        /// <param name="kind">The operation kind.</param>
        /// <param name="flags">Operation flags.</param>
        /// <returns>A node that represents the arithmetic operation.</returns>
        public ValueReference CreateArithmetic(
            Location location,
            Value node,
            UnaryArithmeticKind kind,
            ArithmeticFlags flags)
        {
            if (UseConstantPropagation)
            {
                // Constant operands are folded right away
                if (node is PrimitiveValue constant)
                {
                    return UnaryArithmeticFoldConstants(location, constant, kind);
                }

                if (kind == UnaryArithmeticKind.Not)
                {
                    // not(not(x)) => x
                    if (node is UnaryArithmeticValue innerNot &&
                        innerNot.Kind == UnaryArithmeticKind.Not)
                    {
                        return innerNot.Value;
                    }

                    // not(compare) => inverted compare
                    if (node is CompareValue comparison)
                    {
                        // Inverting a float comparison also toggles between the
                        // ordered and the unordered variant
                        var invertedFlags = comparison.Flags;
                        var comparesFloats =
                            comparison.Left.BasicValueType.IsFloat() &&
                            comparison.Right.BasicValueType.IsFloat();
                        if (comparesFloats)
                        {
                            invertedFlags ^= CompareFlags.UnsignedOrUnordered;
                        }

                        return CreateCompare(
                            location,
                            comparison.Left,
                            comparison.Right,
                            CompareValue.Invert(comparison.Kind),
                            invertedFlags);
                    }

                    // not(a op b) => not(a) op' not(b) for invertible operations
                    if (node is BinaryArithmeticValue binary &&
                        BinaryArithmeticValue.TryInvertLogical(
                            binary.Kind,
                            out var invertedKind))
                    {
                        var negatedLeft = CreateArithmetic(
                            location,
                            binary.Left,
                            UnaryArithmeticKind.Not);
                        var negatedRight = CreateArithmetic(
                            location,
                            binary.Right,
                            UnaryArithmeticKind.Not);
                        return CreateArithmetic(
                            binary.Location,
                            negatedLeft,
                            negatedRight,
                            invertedKind,
                            binary.Flags);
                    }
                }
                else if (kind == UnaryArithmeticKind.Neg)
                {
                    // Negating an Int1 operand is expressed as a not operation
                    if (node.BasicValueType == BasicValueType.Int1)
                    {
                        return CreateArithmetic(
                            location,
                            node,
                            UnaryArithmeticKind.Not);
                    }
                }
                else if (kind == UnaryArithmeticKind.Abs)
                {
                    // abs is the identity when the unsigned flag is set
                    if ((flags & ArithmeticFlags.Unsigned) == ArithmeticFlags.Unsigned)
                    {
                        return node;
                    }
                }
            }

            var operation = new UnaryArithmeticValue(
                GetInitializer(location),
                node,
                kind,
                flags);
            return Append(operation);
        }