[CK_TILE] Fixing Type Conversions in PassThroughPack8 (#2769)

* Change the return type of run_gemm_combinations in the basic tests

* Change the return type of run_gemm_combinations in the universal tests

* Add universal GEMM tests for bf16 x pk_i4 and fp16 x pk_i4

* Add universal GEMM test for fp8 x pk_i4

* Add basic GEMM tests for bf16 x pk_i4, fp16 x pk_i4 and fp8 x pk_i4.

* Add missing GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>

* Add missing GemmTypeConfig<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t>

* No need for utility in test_ck_tile_elementwise_1d

* Fix conversion from pk_int4x4_t to bf16x8_t in PassThroughPack8

* Avoid union-based type punning in float_to_bf16_truc_raw to make it constexpr compliant

* For consistency also make float_to_bf16_truc_nan_raw constexpr compliant by removing the union

* Use a static_cast to bfloat16_t only when CK_TILE_USE_LLVM_BUILTIN_BF16 is enforced

* Convert from float to bf16 during compilation rather than using magic values

* Fix conversion from pk_int4x4_t to fp8x8_t in PassThroughPack8

* Comment out the basic test for fp16 x pk_i4 as it does not pass

* Add missing GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>

* Fix conversion from pk_int4x4_t to bf8x8_t in PassThroughPack8

* Add basic and universal GEMM tests for bf8 x pk_i4

* Switch back to amd_assembly_i4_to_fp8x8 in PassThroughPack8 as it works now

* Switch back to amd_assembly_i4_to_bf8x8 in PassThroughPack8 as it works now

* Remove the inefficient fallbacks for fp8 and bf8 in elementwise/unary_element_wise_operation.hpp

* Use explicit macros for enabling and disabling the the constexpr lookup based converters

* Fix two failing tests

* Avoid union-based type punning in float_to_bf16_rtn_raw to make it constexpr compliant

* Use float_to_bf16_rtn_raw instead of float_to_bf16 to create the bf16 lookup table for use in conversions from pk_int4 to bf16

* On ROCm 7.0.1 we need an explicit cast to from uint16_t to bf16_t
This commit is contained in:
SamiAario-AMD
2025-09-29 13:34:47 +03:00
committed by GitHub
parent e8842e3c1f
commit 0f10e6d921
16 changed files with 198 additions and 55 deletions

View File

@@ -7,9 +7,26 @@
#include <cstdint>
#include <type_traits>
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF16 1
#define CONSTEXPR_LOOKUP_TABLE_FOR_FP8 0
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF8 0
namespace ck_tile {
namespace element_wise {
// Generalized constexpr lookup table generator
template <typename T, std::size_t N, typename F, std::size_t... Is>
constexpr std::array<T, N> make_lookup_table_impl(F&& func, std::index_sequence<Is...>)
{
return {func(Is)...};
}
template <typename T, std::size_t N, typename F>
constexpr std::array<T, N> make_lookup_table(F&& func)
{
return make_lookup_table_impl<T, N>(std::forward<F>(func), std::make_index_sequence<N>{});
}
/**
* @brief Fast int4x4 to fp16x8_t data type conversion based on paper
* "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"
@@ -121,6 +138,8 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
*/
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF16
// This approach fails validation in GEMM tests.
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
static constexpr uint32_t fp32_base = 0x4B000000;
@@ -146,8 +165,19 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
return res;
#else
// Lookup table for bf16_t values corresponding to int4 values -8 to 7
constexpr auto bf16_lookup_table = make_lookup_table<bf16_t, 16>(
[](int i) { return bit_cast<bf16_t>(float_to_bf16_rtn_raw(i - 8)); });
return bf16x4_t{bf16_lookup_table[(q >> 0) & 0xf],
bf16_lookup_table[(q >> 16) & 0xf],
bf16_lookup_table[(q >> 4) & 0xf],
bf16_lookup_table[(q >> 20) & 0xf]};
#endif
}
#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
/**
* @brief This function converts 8 packed 4-bit integers into 8 fp8 values.
*
@@ -209,6 +239,21 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
#else
CK_TILE_DEVICE fp8x4_t i4_to_fp8x4(int q)
{
// The approach below can be used once this compiler issue is resolved:
// "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
// Lookup table for fp8_t values corresponding to int4 values -8 to 7
constexpr auto fp8_lookup_table = make_lookup_table<fp8_t, 16>(
[](int i) { return impl::cast_to_f8<float, fp8_t, true, false>(i - 8, 0); });
return fp8x4_t{fp8_lookup_table[(q >> 0) & 0xf],
fp8_lookup_table[(q >> 16) & 0xf],
fp8_lookup_table[(q >> 4) & 0xf],
fp8_lookup_table[(q >> 20) & 0xf]};
}
#endif
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
{
@@ -224,6 +269,7 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
return res;
}
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
/**
* @brief This function converts 8 packed 4-bit integers into 8 bf8 values.
*
@@ -285,6 +331,21 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
return bit_cast<bf8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
#else
CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
{
// The approach below can be used once this compiler issue is resolved:
// "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
// Lookup table for bf8_t values corresponding to int4 values -8 to 7
constexpr auto bf8_lookup_table = make_lookup_table<bf8_t, 16>(
[](int i) { return impl::cast_to_f8<float, bf8_t, true, false>(i - 8, 0); });
return bf8x4_t{bf8_lookup_table[(q >> 0) & 0xf],
bf8_lookup_table[(q >> 16) & 0xf],
bf8_lookup_table[(q >> 4) & 0xf],
bf8_lookup_table[(q >> 20) & 0xf]};
}
#endif
struct PassThroughPack8
{
@@ -300,17 +361,27 @@ struct PassThroughPack8
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
{
y.lo = i4_to_bhalf4(bit_cast<int>(x));
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8);
}
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
y = amd_assembly_i4_to_fp8x8(bit_cast<uint32_t>(x));
#else
y.lo = i4_to_fp8x4(bit_cast<int>(x));
y.hi = i4_to_fp8x4(bit_cast<int>(x) >> 8);
#endif
}
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
y = amd_assembly_i4_to_bf8x8(bit_cast<uint32_t>(x));
#else
y.lo = i4_to_bf8x4(bit_cast<int>(x));
y.hi = i4_to_bf8x4(bit_cast<int>(x) >> 8);
#endif
}
constexpr const static bool is_pack8_invocable = true;
};