diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 17c1ac14d..9c4fb395a 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -3766,6 +3766,280 @@ public: }; ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array (packed signed 2-bit integers to FP16, going by the I2 -> FP16 wording of the static_assert in packed_convert) +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_16 = Array; + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using source_type_packed_16 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_16 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch."); + + // Hold output FP16s in reg. 
We need 1 reg for every 2 elements + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as a 32-bit reg; the copy shifted right by 4 exposes the upper two 2-bit values of each source byte to the same masks + uint32_t src_reg = to_reg(source); + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + // f1f0 = {0x00, i3i2i1i0, 0x00, i3i2i1i0} + // f3f2 = {0x00, i5i4i3i2, 0x00, i5i4i3i2} + // f5f4 = {0x00, i7i6i5i4, 0x00, i7i6i5i4} + // f7f6 = {0x00, i9i8i7i6, 0x00, i9i8i7i6} + // f9f8 = {0x00, i11i10i9i8, 0x00, i11i10i9i8} + // f11f10 = {0x00, i13i12i11i10, 0x00, i13i12i11i10} + // f13f12 = {0x00, i15i14i13i12, 0x00, i15i14i13i12} + // f15f14 = {0x00, 0000i15i14, 0x00, 0000i15i14} + // We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC + // might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here. + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> FP16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + asm volatile( + "{ prmt.b32 %0, %1, %2, %3; }\n" + : "=r"(r[ii]) + : "r"(src_reg), "n"(0), "r"(prmt_indices[ii / 2])); + + asm volatile( + "{ prmt.b32 %0, %1, %2, %3; }\n" + : "=r"(r[ii + 1]) + : "r"(src_reg_shifted), "n"(0), "r"(prmt_indices[ii / 2])); + } + + // The below XOR does the following: + // Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing + // 1024 + x + 2, 1024 + 4 * (x + 2) (low half, high half of each reg; the +2 comes from XOR-ing the 2-bit value with 0b10) + // We use lop3 so that we can use 1 instruction for AND and XOR. 
+ // static constexpr uint32_t xor_mask[2] = { 0x64086402, 0x64806420}; + // static constexpr uint32_t and_mask[2] = { 0x000C0003, 0x00C00030}; + static constexpr uint32_t xor_mask = 0x64086402; + static constexpr uint32_t and_mask = 0x000C0003; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{ lop3.b32 %0, %0, %1, %2, %3; }\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // {-258, -1026} + static constexpr uint32_t hfma_bias_rep = 0xDC08E402; + // {1/4, 1} + static constexpr uint32_t hfma_scale_rep = 0x34003C00; + + // Scale and subtract the FP16s to get the original int2 number as FP16. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(hfma_scale_rep), + reinterpret_cast(hfma_bias_rep)); + } + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array (presumably packed unsigned 2-bit integers to FP16, going by the u-prefixed bit labels and the unbiased 1024 + x construction in packed_convert - TODO confirm) +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_16 = Array; + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using source_type_packed_16 = Array; + using 
source_type_packed_8 = Array; + using source_type_packed_4 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + 
static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_16 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch."); + + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg + uint32_t src_reg = to_reg(source); + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + // f1f0 = {0x00, u3u2u1u0, 0x00, u3u2u1u0} + // f3f2 = {0x00, u5u4u3u2, 0x00, u5u4u3u2} + // f5f4 = {0x00, u7u6u5u4, 0x00, u7u6u5u4} + // f7f6 = {0x00, u9u8u7u6, 0x00, u9u8u7u6} + // f9f8 = {0x00, u11u10u9u8, 0x00, u11u10u9u8} + // f11f10 = {0x00, u13u12u11u10, 0x00, u13u12u11u10} + // f13f12 = {0x00, u15u14u13u12, 0x00, u15u14u13u12} + // f15f14 = {0x00, 0000u15u14, 0x00, 0000u15u14} + // We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC + // might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here. NOTE(review): the static_assert message that follows still says I2 although the u-prefixed bit labels above describe unsigned 2-bit inputs - confirm and correct the message. 
+ uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> FP16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + asm volatile( + "{ prmt.b32 %0, %1, %2, %3; }\n" + : "=r"(r[ii]) + : "r"(src_reg), "n"(0), "r"(prmt_indices[ii / 2])); + + asm volatile( + "{ prmt.b32 %0, %1, %2, %3; }\n" + : "=r"(r[ii + 1]) + : "r"(src_reg_shifted), "n"(0), "r"(prmt_indices[ii / 2])); + } + + // The below XOR does the following: + // Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing + // 1024 + x, 1024 + 4 * x + // We use lop3 so that we can use 1 instruction for AND and XOR. + static constexpr uint32_t xor_mask = 0x64006400; + static constexpr uint32_t and_mask = 0x000C0003; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{ lop3.b32 %0, %0, %1, %2, %3; }\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // {-256, -1024} + static constexpr uint32_t hfma_bias_rep = 0xDC00E400; + // {1/4, 1} + static constexpr uint32_t hfma_scale_rep = 0x34003C00; + + // Scale and subtract the FP16s to get the original 2-bit number as FP16 (high half is divided by 4 before its bias is removed). 
+ CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(hfma_scale_rep), + reinterpret_cast(hfma_bias_rep)); + } + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + /// Partial specialization for Array <= Array template struct NumericArrayConverter { @@ -3830,13 +4104,11 @@ private: // We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC // might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here. 
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; - static_assert(RegArray::kElements <= 4, "Too many inputs for F16 -> I4 vector converter"); + static_assert(RegArray::kElements <= 4, "Too many inputs for I4 ->F16 vector converter"); CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < RegArray::kElements; ++ii) { asm volatile( - "{\n" - " prmt.b32 %0, %1, %2, %3;\n" - "}\n" + "{ prmt.b32 %0, %1, %2, %3; }\n" : "=r"(r[ii]) : "r"(src_reg), "n"(0), "r"(prmt_indices[ii])); } @@ -3891,6 +4163,133 @@ private: return reinterpret_cast(r); } + friend class detail::VectorizedConverter; +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array (packed unsigned 4-bit integers to FP16, going by the u4 -> f16 wording of the static_assert in packed_convert) +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return reinterpret_cast(source); + } + + // The core converter uses bit tricks to construct a known FP16 
number, then does a + // subtraction in FP16 for the final result. 
+ template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); + + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg + uint32_t src_reg = to_reg(source); + // Below constructs the following temporary: + // fp16s_01 = {0x00, u4_01, 0x00, u4_01} + // fp16s_23 = {0x00, u4_23, 0x00, u4_23} + // fp16s_45 = {0x00, u4_45, 0x00, u4_45} + // fp16s_67 = {0x00, u4_67, 0x00, u4_67} + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 4, "Too many inputs for u4 -> f16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{ prmt.b32 %0, %1, %2, %3; }\n" + : "=r"(r[ii]) + : "r"(src_reg), "n"(0), "r"(prmt_indices[ii])); + } + + // The below OR does the following: + // Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. 
We will be constructing + // 1024 + x (low half) and 1024 + 16 * x (high half); the hfma2 below rescales the high half and removes the bias + static constexpr uint32_t or_mask = 0x64006400; + static constexpr uint32_t and_mask = 0x00F0000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) | or_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(or_mask), "n"(immLut)); + + // We will issue 2 hfmas that do the following: + // For the high FP16: + // Divide by 16 {packed as a operand} to get: + // 64 + x + // Subtract 64 {packed as c operand} to get x + // For the low FP16: + // we subtract 1024 {packed as c operand} to get x + + static constexpr uint32_t hfma_bias = 0xD400E400; // {-64, -1024} + static constexpr uint32_t hfma_scale = 0x2C003C00; // {1 / 16, 1} + + { + __half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(fp16x2_val, reinterpret_cast(hfma_scale), reinterpret_cast(hfma_bias)); + } + } + return reinterpret_cast(r); + } + friend class detail::VectorizedConverter; public: @@ -4108,6 +4507,260 @@ public: #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array (packed signed 2-bit integers to BF16, going by the I2 -> BF16 wording of the static_assert in packed_convert) +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_16 = Array; + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using source_type_packed_16 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + + using ScalarConverter = 
NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + 
CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_16 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch."); + + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg; the copies shifted by 2, 4 and 6 bits bring each 2-bit value of a byte down to bits 0-1 + uint32_t src_reg = to_reg(source); + uint32_t src_reg_shifted_two = src_reg >> 2; + uint32_t src_reg_shifted_four = src_reg >> 4; + uint32_t src_reg_shifted_six = src_reg >> 6; + + // Modified prmt indices for signed 2-bit values + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + + static_assert(RegArray::kElements <= 8, "Too many inputs for I2 -> BF16 vector converter"); + + // First pass: gather the source bytes holding the 2-bit values; the sign is accounted for by the XOR and bias subtraction below + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + asm volatile( + "{ prmt.b32 %0, %1, %2, %3; }\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted_two), "r"(prmt_indices[ii / 2])); + + asm volatile( + "{ prmt.b32 %0, %1, %2, %3; }\n" + : "=r"(r[ii + 1]) + : "r"(src_reg_shifted_four), "r"(src_reg_shifted_six), "r"(prmt_indices[ii / 2])); + } + + // For signed 2-bit integers: + // 00 -> 0 (0) + // 01 -> 1 (1) + // 10 -> -2 (2 with sign extension) + // 11 -> -1 (3 with sign extension) + //static constexpr uint32_t sign_mask = 0x00020002; // Mask to check sign bit + static constexpr 
uint32_t and_mask = 0x00030003; // Mask for 2 bits + + // Modified for signed range (-2 to 1) + // We'll construct numbers in the form 128 + (x + 2) and then subtract 130 + // 
to get back to our original range + static constexpr uint32_t xor_mask = 0x43024302; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // Bias represents 130 in bfloat16 format + // Subtracting 130 brings us back to our signed range (-2 to 1) + static constexpr uint32_t bias_rep = 0x43024302; // {130, 130} in bfloat16 + const __nv_bfloat162& bias = reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array (packed unsigned 2-bit integers to BF16, going by the U2 -> BF16 wording of the static_assert in packed_convert) +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_16 = Array; + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using source_type_packed_16 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static 
uint32_t to_reg(source_type_packed_8 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + 
CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_16 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 4, 8 or 16 to use private convert dispatch."); + + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg; the copies shifted by 2, 4 and 6 bits bring each 2-bit value of a byte down to bits 0-1 + uint32_t src_reg = to_reg(source); + uint32_t src_reg_shifted_two = src_reg >> 2; + uint32_t src_reg_shifted_four = src_reg >> 4; + uint32_t src_reg_shifted_six = src_reg >> 6; + + // prmt indices for the 2-bit values (inputs are unsigned here, per the U2 -> BF16 static_assert message) + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + + static_assert(RegArray::kElements <= 8, "Too many inputs for U2 -> BF16 vector converter"); + + // First pass: gather the source bytes holding the 2-bit values (unsigned, so only the 128 bias is removed below) + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + asm volatile( + "{ prmt.b32 %0, %1, %2, %3; }\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted_two), "r"(prmt_indices[ii / 2])); + + asm volatile( + "{ prmt.b32 %0, %1, %2, %3; }\n" + : "=r"(r[ii + 1]) + : "r"(src_reg_shifted_four), "r"(src_reg_shifted_six), "r"(prmt_indices[ii / 2])); + } + + static constexpr uint32_t and_mask = 0x00030003; // Mask for 2 bits + static constexpr uint32_t xor_mask = 0x43004300; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{ lop3.b32 %0, %0, %1, %2, %3; }" + : "+r"(r[ii]) + : 
"n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + static constexpr uint32_t bias_rep = xor_mask; // {128, 128} in bfloat16 + const __nv_bfloat162& bias = reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii 
= 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + /// Partial specialization for Array <= Array template struct NumericArrayConverter { @@ -4171,9 +4824,7 @@ private: CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < RegArray::kElements; ++ii) { asm volatile( - "{\n" - " prmt.b32 %0, %1, %2, %3;\n" - "}\n" + "{ prmt.b32 %0, %1, %2, %3; }\n" : "=r"(r[ii]) : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); } @@ -4185,6 +4836,133 @@ private: static constexpr uint32_t and_mask = 0x000F000F; static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{ lop3.b32 %0, %0, %1, %2, %3; }\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 bfmas that do the following: + // high BF16: + // hi_bf16 - 136, lo_bf16 - 136 + + // This is the BF16 {136, 136} represented as an integer. 
+ static constexpr uint32_t bias_rep = 0x43084308; + const __nv_bfloat162& bias = reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array (presumably packed unsigned 4-bit integers to BF16, going by the u4 bit labels and the unbiased 128 + x construction in packed_convert - TODO confirm) +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return reinterpret_cast(source); + } + + // The core converter uses bit tricks to construct a known BF16 number, then does a + // subtraction in BF16 for the final 
result. 
+ template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); + + // Hold output BF16s in reg. We need 1 reg for every 2 elements. NOTE(review): the static_assert message further down reads BF16 -> I4, which looks copied from another converter - confirm the intended wording. + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg + uint32_t src_reg = to_reg(source); + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + // fp16s_01 = {0x00, u4_21, 0x00, u4_10} + // fp16s_23 = {0x00, u4_43, 0x00, u4_32} + // fp16s_45 = {0x00, u4_65, 0x00, u4_54} + // fp16s_67 = {0x000, u4_7, 0x00, u4_76} + static constexpr uint32_t prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + static_assert(RegArray::kElements <= 4, "Too many inputs for BF16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + } + + static constexpr uint32_t xor_mask = 0x43004300; + static constexpr uint32_t and_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + // For each operand, computes: // r[i] = (r[i] & and_mask) ^ xor_mask CUTLASS_PRAGMA_UNROLL @@ -4199,16 +4977,15 @@ private: // We will issue 2 bfmas that do the following: // high BF16: - // hi_bf16 - 136, lo_bf16 - 136 + // hi_bf16 - 128, lo_bf16 - 128 - // This is the BF16 {136, 136} represented as an integer. - static constexpr uint32_t bias_rep = 0x43084308; - const __nv_bfloat162& bias = reinterpret_cast(bias_rep); + // This is the BF16 {128, 128} represented as an integer. 
+ static constexpr uint32_t bias = xor_mask; CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < RegArray::kElements; ++ii) { __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); - bf16x2_val = __hsub2(bf16x2_val, bias); + bf16x2_val = __hsub2(bf16x2_val, reinterpret_cast(bias)); } return reinterpret_cast(r); diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index 44d6cf9f3..a4b9e723b 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -644,11 +644,26 @@ struct GetName { static constexpr char name[] = "UNSUPPORTED"; }; +template <> +struct GetName { + static constexpr char name[] = "int2b_t"; +}; + +template <> +struct GetName { + static constexpr char name[] = "uint2b_t"; +}; + template <> struct GetName { static constexpr char name[] = "int4b_t"; }; +template <> +struct GetName { + static constexpr char name[] = "uint4b_t"; +}; + template <> struct GetName { static constexpr char name[] = "uint8_t"; @@ -709,9 +724,15 @@ using VectorConvertTypes = ::testing::Types< ResultSourcePair, ResultSourcePair, + ResultSourcePair, + ResultSourcePair, + ResultSourcePair, + ResultSourcePair, ResultSourcePair, ResultSourcePair, ResultSourcePair, + ResultSourcePair, + ResultSourcePair, ResultSourcePair >;