Merge commit '0f10e6d9218ce9d00a34a66572c0686dce1e45ea' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-29 11:12:04 +00:00
parent 2593ecf5b5
commit f9767142cf
16 changed files with 198 additions and 55 deletions

View File

@@ -7,9 +7,26 @@
#include <cstdint>
#include <type_traits>
// Implementation toggles for the int4 -> {bf16, fp8, bf8} conversions in this file.
// 1 selects the constexpr lookup-table implementation, 0 selects the
// bit-manipulation / assembly implementation.
// The fp8 and bf8 lookup tables are disabled because of the compiler issue
// "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported".
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF16 1
#define CONSTEXPR_LOOKUP_TABLE_FOR_FP8 0
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF8 0
namespace ck_tile {
namespace element_wise {
// Helper for make_lookup_table: expands the index pack and evaluates func once
// per index, in index order, storing the results in a std::array<T, N>.
//
// T    - element type of the resulting table
// N    - number of entries
// func - callable invoked as func(index) for each index in the sequence
template <typename T, std::size_t N, typename F, std::size_t... Indices>
constexpr std::array<T, N> make_lookup_table_impl(F&& func, std::index_sequence<Indices...>)
{
    return std::array<T, N>{func(Indices)...};
}
// Builds a std::array<T, N> whose entry i is func(i) for i in [0, N).
// Usable in constant expressions when func is.
template <typename T, std::size_t N, typename F>
constexpr std::array<T, N> make_lookup_table(F&& func)
{
    using index_list = std::make_index_sequence<N>;
    return make_lookup_table_impl<T, N>(std::forward<F>(func), index_list{});
}
/**
* @brief Fast int4x4 to fp16x8_t data type conversion based on paper
* "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"
@@ -121,6 +138,8 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
*/
// Converts four 4-bit fields of q into a bf16x4_t.
// Two implementations are selected by CONSTEXPR_LOOKUP_TABLE_FOR_BF16:
// a bit-manipulation path (macro == 0) and a constexpr lookup-table path (macro != 0).
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF16
// This approach fails validation in GEMM tests.
// Spread the four lowest nibbles of q so that each one occupies the low half
// of its own byte (bit offsets 0, 8, 16 and 24).
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
static constexpr uint32_t fp32_base = 0x4B000000;
@@ -146,8 +165,19 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
return res;
#else
// Lookup table for bf16_t values corresponding to int4 values -8 to 7
// (entry i holds the bf16 encoding of i - 8, built at compile time).
constexpr auto bf16_lookup_table = make_lookup_table<bf16_t, 16>(
[](int i) { return bit_cast<bf16_t>(float_to_bf16_rtn_raw(i - 8)); });
// Output elements 0..3 are taken from the nibbles at bit offsets 0, 16, 4 and 20
// of q, in that order.
// NOTE(review): presumably this matches the interleaved packing of pk_int4x4_t - confirm.
return bf16x4_t{bf16_lookup_table[(q >> 0) & 0xf],
bf16_lookup_table[(q >> 16) & 0xf],
bf16_lookup_table[(q >> 4) & 0xf],
bf16_lookup_table[(q >> 20) & 0xf]};
#endif
}
#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
/**
* @brief This function converts 8 packed 4-bit integers into 8 fp8 values.
*
@@ -209,6 +239,21 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
#else
// Converts four 4-bit fields of q into an fp8x4_t through a constexpr lookup table.
// Output elements 0..3 come from the nibbles at bit offsets 0, 16, 4 and 20 of q.
CK_TILE_DEVICE fp8x4_t i4_to_fp8x4(int q)
{
    // This implementation is only usable once the following compiler issue is resolved:
    // "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"

    // Entry i is the fp8_t value for the int4 value i - 8 (i.e. -8 to 7).
    constexpr auto table = make_lookup_table<fp8_t, 16>(
        [](int i) { return impl::cast_to_f8<float, fp8_t, true, false>(i - 8, 0); });

    const int nibble_at_0  = q & 0xf;
    const int nibble_at_4  = (q >> 4) & 0xf;
    const int nibble_at_16 = (q >> 16) & 0xf;
    const int nibble_at_20 = (q >> 20) & 0xf;

    return fp8x4_t{
        table[nibble_at_0], table[nibble_at_16], table[nibble_at_4], table[nibble_at_20]};
}
#endif
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
{
@@ -224,6 +269,7 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
return res;
}
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
/**
* @brief This function converts 8 packed 4-bit integers into 8 bf8 values.
*
@@ -285,6 +331,21 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
return bit_cast<bf8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
#else
// Converts four 4-bit fields of q into a bf8x4_t through a constexpr lookup table.
// Output elements 0..3 come from the nibbles at bit offsets 0, 16, 4 and 20 of q.
CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
{
    // This implementation is only usable once the following compiler issue is resolved:
    // "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"

    // Entry i is the bf8_t value for the int4 value i - 8 (i.e. -8 to 7).
    constexpr auto table = make_lookup_table<bf8_t, 16>(
        [](int i) { return impl::cast_to_f8<float, bf8_t, true, false>(i - 8, 0); });

    const int nibble_at_0  = q & 0xf;
    const int nibble_at_4  = (q >> 4) & 0xf;
    const int nibble_at_16 = (q >> 16) & 0xf;
    const int nibble_at_20 = (q >> 20) & 0xf;

    return bf8x4_t{
        table[nibble_at_0], table[nibble_at_16], table[nibble_at_4], table[nibble_at_20]};
}
#endif
struct PassThroughPack8
{
@@ -300,17 +361,27 @@ struct PassThroughPack8
// Converts the 4-bit fields packed in x into bf16 values: y.lo and y.hi each
// receive four elements produced by i4_to_bhalf4.
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
{
    y.lo = i4_to_bhalf4(bit_cast<int>(x));
    // The lookup-table path of i4_to_bhalf4 reads the nibbles at bit offsets
    // 0, 4, 16 and 20, so shifting by 8 exposes the remaining four nibbles
    // (offsets 8, 12, 24 and 28). A stale `>> 16` assignment to y.hi that was
    // immediately overwritten by this one has been removed.
    y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8);
}
// Converts the 4-bit fields packed in x into fp8 values stored in y.
// CONSTEXPR_LOOKUP_TABLE_FOR_FP8 selects between the lookup-table helper
// (two four-element halves) and the assembly helper (all eight at once).
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
{
#if CONSTEXPR_LOOKUP_TABLE_FOR_FP8
    const int packed = bit_cast<int>(x);
    y.lo             = i4_to_fp8x4(packed);
    y.hi             = i4_to_fp8x4(packed >> 8);
#else
    y = amd_assembly_i4_to_fp8x8(bit_cast<uint32_t>(x));
#endif
}
// Converts the 4-bit fields packed in x into bf8 values stored in y.
// CONSTEXPR_LOOKUP_TABLE_FOR_BF8 selects between the lookup-table helper
// (two four-element halves) and the assembly helper (all eight at once).
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
{
#if CONSTEXPR_LOOKUP_TABLE_FOR_BF8
    const int packed = bit_cast<int>(x);
    y.lo             = i4_to_bf8x4(packed);
    y.hi             = i4_to_bf8x4(packed >> 8);
#else
    y = amd_assembly_i4_to_bf8x8(bit_cast<uint32_t>(x));
#endif
}
// Compile-time flag, always true for this functor.
// NOTE(review): presumably queried by callers to detect functors that process
// eight packed elements per call - confirm against its users.
constexpr const static bool is_pack8_invocable = true;
};