[CK_TILE] Fixing Type Conversions in PassThroughPack8 (#2769)

* Change the return type of run_gemm_combinations in the basic tests

* Change the return type of run_gemm_combinations in the universal tests

* Add universal GEMM tests for bf16 x pk_i4 and fp16 x pk_i4

* Add universal GEMM test for fp8 x pk_i4

* Add basic GEMM tests for bf16 x pk_i4, fp16 x pk_i4 and fp8 x pk_i4.

* Add missing GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>

* Add missing GemmTypeConfig<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t>

* No need for utility in test_ck_tile_elementwise_1d

* Fix conversion from pk_int4x4_t to bf16x8_t in PassThroughPack8

* Avoid union-based type punning in float_to_bf16_truc_raw to make it constexpr compliant

* For consistency also make float_to_bf16_truc_nan_raw constexpr compliant by removing the union

* Use a static_cast to bfloat16_t only when CK_TILE_USE_LLVM_BUILTIN_BF16 is enforced

* Convert from float to bf16 during compilation rather than using magic values

* Fix conversion from pk_int4x4_t to fp8x8_t in PassThroughPack8

* Comment out the basic test for fp16 x pk_i4 as it does not pass

* Add missing GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>

* Fix conversion from pk_int4x4_t to bf8x8_t in PassThroughPack8

* Add basic and universal GEMM tests for bf8 x pk_i4

* Switch back to amd_assembly_i4_to_fp8x8 in PassThroughPack8 as it works now

* Switch back to amd_assembly_i4_to_bf8x8 in PassThroughPack8 as it works now

* Remove the inefficient fallbacks for fp8 and bf8 in elementwise/unary_element_wise_operation.hpp

* Use explicit macros for enabling and disabling the the constexpr lookup based converters

* Fix two failing tests

* Avoid union-based type punning in float_to_bf16_rtn_raw to make it constexpr compliant

* Use float_to_bf16_rtn_raw instead of float_to_bf16 to create the bf16 lookup table for use in conversions from pk_int4 to bf16

* On ROCm 7.0.1 we need an explicit cast to from uint16_t to bf16_t
This commit is contained in:
SamiAario-AMD
2025-09-29 13:34:47 +03:00
committed by GitHub
parent e8842e3c1f
commit 0f10e6d921
16 changed files with 198 additions and 55 deletions

View File

@@ -117,12 +117,8 @@ using bf16_raw_t = uint16_t;
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000)
uint32_t bits = bit_cast<uint32_t>(f);
if(~bits & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
@@ -140,9 +136,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
else if(bits & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
@@ -152,9 +148,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
bits |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
return uint16_t(bits >> 16);
}
CK_TILE_HOST
@@ -225,24 +221,16 @@ uint16_t float_to_bf16_rta_asm(float f)
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
uint32_t bits = bit_cast<uint32_t>(f);
return static_cast<uint16_t>(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff));
}
// Fast truncate instead of rounding, RTZ
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16);
uint32_t bits = bit_cast<uint32_t>(f);
return static_cast<uint16_t>(bits >> 16);
}
template <bf16_rounding_mode rounding>
@@ -287,7 +275,7 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
#if defined(__gfx950__)
#if CK_TILE_USE_LLVM_BUILTIN_BF16
return static_cast<bfloat16_t>(f);
#else
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));

View File

@@ -7,9 +7,26 @@
#include <cstdint>
#include <type_traits>
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF16 1
#define CONSTEXPR_LOOKUP_TABLE_FOR_FP8 0
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF8 0
namespace ck_tile {
namespace element_wise {
// Generalized constexpr lookup table generator
template <typename T, std::size_t N, typename F, std::size_t... Is>
constexpr std::array<T, N> make_lookup_table_impl(F&& func, std::index_sequence<Is...>)
{
return {func(Is)...};
}
template <typename T, std::size_t N, typename F>
constexpr std::array<T, N> make_lookup_table(F&& func)
{
return make_lookup_table_impl<T, N>(std::forward<F>(func), std::make_index_sequence<N>{});
}
/**
* @brief Fast int4x4 to fp16x8_t data type conversion based on paper
* "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"
@@ -121,6 +138,8 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
*/
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF16
// This approach fails validation in GEMM tests.
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
static constexpr uint32_t fp32_base = 0x4B000000;
@@ -146,8 +165,19 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
return res;
#else
// Lookup table for bf16_t values corresponding to int4 values -8 to 7
constexpr auto bf16_lookup_table = make_lookup_table<bf16_t, 16>(
[](int i) { return bit_cast<bf16_t>(float_to_bf16_rtn_raw(i - 8)); });
return bf16x4_t{bf16_lookup_table[(q >> 0) & 0xf],
bf16_lookup_table[(q >> 16) & 0xf],
bf16_lookup_table[(q >> 4) & 0xf],
bf16_lookup_table[(q >> 20) & 0xf]};
#endif
}
#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
/**
* @brief This function converts 8 packed 4-bit integers into 8 fp8 values.
*
@@ -209,6 +239,21 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
#else
CK_TILE_DEVICE fp8x4_t i4_to_fp8x4(int q)
{
// The approach below can be used once this compiler issue is resolved:
// "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
// Lookup table for fp8_t values corresponding to int4 values -8 to 7
constexpr auto fp8_lookup_table = make_lookup_table<fp8_t, 16>(
[](int i) { return impl::cast_to_f8<float, fp8_t, true, false>(i - 8, 0); });
return fp8x4_t{fp8_lookup_table[(q >> 0) & 0xf],
fp8_lookup_table[(q >> 16) & 0xf],
fp8_lookup_table[(q >> 4) & 0xf],
fp8_lookup_table[(q >> 20) & 0xf]};
}
#endif
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
{
@@ -224,6 +269,7 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
return res;
}
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
/**
* @brief This function converts 8 packed 4-bit integers into 8 bf8 values.
*
@@ -285,6 +331,21 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
return bit_cast<bf8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
#else
CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
{
// The approach below can be used once this compiler issue is resolved:
// "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
// Lookup table for bf8_t values corresponding to int4 values -8 to 7
constexpr auto bf8_lookup_table = make_lookup_table<bf8_t, 16>(
[](int i) { return impl::cast_to_f8<float, bf8_t, true, false>(i - 8, 0); });
return bf8x4_t{bf8_lookup_table[(q >> 0) & 0xf],
bf8_lookup_table[(q >> 16) & 0xf],
bf8_lookup_table[(q >> 4) & 0xf],
bf8_lookup_table[(q >> 20) & 0xf]};
}
#endif
struct PassThroughPack8
{
@@ -300,17 +361,27 @@ struct PassThroughPack8
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
{
y.lo = i4_to_bhalf4(bit_cast<int>(x));
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8);
}
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
y = amd_assembly_i4_to_fp8x8(bit_cast<uint32_t>(x));
#else
y.lo = i4_to_fp8x4(bit_cast<int>(x));
y.hi = i4_to_fp8x4(bit_cast<int>(x) >> 8);
#endif
}
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
y = amd_assembly_i4_to_bf8x8(bit_cast<uint32_t>(x));
#else
y.lo = i4_to_bf8x4(bit_cast<int>(x));
y.hi = i4_to_bf8x4(bit_cast<int>(x) >> 8);
#endif
}
constexpr const static bool is_pack8_invocable = true;
};