Merge commit '0f10e6d9218ce9d00a34a66572c0686dce1e45ea' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-29 11:12:04 +00:00
parent 2593ecf5b5
commit f9767142cf
16 changed files with 198 additions and 55 deletions

View File

@@ -117,12 +117,8 @@ using bf16_raw_t = uint16_t;
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000)
uint32_t bits = bit_cast<uint32_t>(f);
if(~bits & 0x7f800000)
{
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
@@ -140,9 +136,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
}
else if(u.int32 & 0xffff)
else if(bits & 0xffff)
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
@@ -152,9 +148,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
bits |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
return uint16_t(bits >> 16);
}
CK_TILE_HOST
@@ -225,24 +221,16 @@ uint16_t float_to_bf16_rta_asm(float f)
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
    // Truncate an fp32 to a raw bf16 bit pattern (round-toward-zero) while
    // keeping NaN inputs NaN. bf16 is the top 16 bits of an fp32, so plain
    // truncation can drop a NaN payload that lives only in the low mantissa
    // bits, which would turn the NaN into Inf.
    //
    // NOTE(review): the merge left the pre-refactor union-based body in front
    // of the new bit_cast one, so the bit_cast code was unreachable (and the
    // early return used the old computation). Only the bit_cast version is
    // kept here.
    uint32_t bits = bit_cast<uint32_t>(f);
    // (~bits & 0x7f800000) == 0 -> exponent bits all ones (Inf or NaN).
    // (bits & 0xffff) != 0      -> low mantissa bits would be lost by >> 16;
    //                              Inf has a zero mantissa, so this is NaN only.
    // In that case OR in 1 so the truncated result stays NaN.
    return static_cast<uint16_t>(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff));
}
// Fast truncate instead of rounding, RTZ
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
    // Raw round-toward-zero: bf16 shares fp32's sign, exponent, and top
    // mantissa bits, so truncation is just keeping the high 16 bits.
    // Unlike float_to_bf16_truc_nan_raw, a NaN whose payload lives only in
    // the low 16 mantissa bits truncates to Inf here.
    //
    // NOTE(review): removed the unreachable pre-merge union-based body that
    // preceded (and shadowed) this bit_cast implementation.
    uint32_t bits = bit_cast<uint32_t>(f);
    return static_cast<uint16_t>(bits >> 16);
}
template <bf16_rounding_mode rounding>
@@ -287,7 +275,7 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
#if defined(__gfx950__)
#if CK_TILE_USE_LLVM_BUILTIN_BF16
return static_cast<bfloat16_t>(f);
#else
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));

View File

@@ -7,9 +7,26 @@
#include <cstdint>
#include <type_traits>
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF16 1
#define CONSTEXPR_LOOKUP_TABLE_FOR_FP8 0
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF8 0
namespace ck_tile {
namespace element_wise {
// Generalized constexpr lookup table generator
// Expansion helper for make_lookup_table: materializes one table entry per
// index in the given index_sequence by invoking the generator on each index.
template <typename T, std::size_t N, typename F, std::size_t... Is>
constexpr std::array<T, N> make_lookup_table_impl(F&& generator, std::index_sequence<Is...>)
{
    return std::array<T, N>{generator(Is)...};
}
// Builds a constexpr std::array<T, N> whose i-th entry is generator(i).
// The generator is forwarded to the index-sequence expansion helper above.
template <typename T, std::size_t N, typename F>
constexpr std::array<T, N> make_lookup_table(F&& generator)
{
    using indices = std::make_index_sequence<N>;
    return make_lookup_table_impl<T, N>(std::forward<F>(generator), indices{});
}
/**
* @brief Fast int4x4 to fp16x8_t data type conversion based on paper
* "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"
@@ -121,6 +138,8 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
*/
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF16
// This approach fails validation in GEMM tests.
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
static constexpr uint32_t fp32_base = 0x4B000000;
@@ -146,8 +165,19 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
return res;
#else
// Lookup table for bf16_t values corresponding to int4 values -8 to 7
constexpr auto bf16_lookup_table = make_lookup_table<bf16_t, 16>(
[](int i) { return bit_cast<bf16_t>(float_to_bf16_rtn_raw(i - 8)); });
return bf16x4_t{bf16_lookup_table[(q >> 0) & 0xf],
bf16_lookup_table[(q >> 16) & 0xf],
bf16_lookup_table[(q >> 4) & 0xf],
bf16_lookup_table[(q >> 20) & 0xf]};
#endif
}
#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
/**
* @brief This function converts 8 packed 4-bit integers into 8 fp8 values.
*
@@ -209,6 +239,21 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
#else
CK_TILE_DEVICE fp8x4_t i4_to_fp8x4(int q)
{
    // The approach below can be used once this compiler issue is resolved:
    // "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
    // Compile-time table of the 16 fp8_t encodings for int4 inputs: entry i
    // holds the conversion of (i - 8), covering the value range -8..7.
    constexpr auto table = make_lookup_table<fp8_t, 16>(
        [](int idx) { return impl::cast_to_f8<float, fp8_t, true, false>(idx - 8, 0); });
    // Nibble extraction order (bits 0-3, 16-19, 4-7, 20-23) matches the
    // packed int4 layout used by the other i4_to_* converters in this file.
    const int n0 = (q >> 0) & 0xf;
    const int n1 = (q >> 16) & 0xf;
    const int n2 = (q >> 4) & 0xf;
    const int n3 = (q >> 20) & 0xf;
    return fp8x4_t{table[n0], table[n1], table[n2], table[n3]};
}
#endif
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
{
@@ -224,6 +269,7 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
return res;
}
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
/**
* @brief This function converts 8 packed 4-bit integers into 8 bf8 values.
*
@@ -285,6 +331,21 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
return bit_cast<bf8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
#else
CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
{
    // The approach below can be used once this compiler issue is resolved:
    // "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
    // Compile-time table of the 16 bf8_t encodings for int4 inputs: entry i
    // holds the conversion of (i - 8), covering the value range -8..7.
    constexpr auto table = make_lookup_table<bf8_t, 16>(
        [](int idx) { return impl::cast_to_f8<float, bf8_t, true, false>(idx - 8, 0); });
    // Nibble extraction order (bits 0-3, 16-19, 4-7, 20-23) matches the
    // packed int4 layout used by the other i4_to_* converters in this file.
    const int n0 = (q >> 0) & 0xf;
    const int n1 = (q >> 16) & 0xf;
    const int n2 = (q >> 4) & 0xf;
    const int n3 = (q >> 20) & 0xf;
    return bf8x4_t{table[n0], table[n1], table[n2], table[n3]};
}
#endif
struct PassThroughPack8
{
@@ -300,17 +361,27 @@ struct PassThroughPack8
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
{
    // Convert 8 packed int4 values to 8 bf16 values: the low half comes from
    // the packed word as-is, the high half from the word shifted right by 8
    // (i4_to_bhalf4 itself picks nibbles at bit offsets 0/16/4/20).
    //
    // NOTE(review): removed the stale pre-merge `>> 16` store to y.hi that
    // was immediately overwritten by the `>> 8` version — a dead store left
    // over from the merge.
    y.lo = i4_to_bhalf4(bit_cast<int>(x));
    y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8);
}
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
    // Bulk path: convert all 8 nibbles with the single packed helper.
    y = amd_assembly_i4_to_fp8x8(bit_cast<uint32_t>(x));
#else
    // Table path: convert the two 4-value halves independently; the high
    // half reads the packed word shifted right by 8.
    const int packed = bit_cast<int>(x);
    y.lo = i4_to_fp8x4(packed);
    y.hi = i4_to_fp8x4(packed >> 8);
#endif
}
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
{
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
    // Bulk path: convert all 8 nibbles with the single packed helper.
    y = amd_assembly_i4_to_bf8x8(bit_cast<uint32_t>(x));
#else
    // Table path: convert the two 4-value halves independently; the high
    // half reads the packed word shifted right by 8.
    const int packed = bit_cast<int>(x);
    y.lo = i4_to_bf8x4(packed);
    y.hi = i4_to_bf8x4(packed >> 8);
#endif
}
constexpr const static bool is_pack8_invocable = true;
};