mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] Fixing Type Conversions in PassThroughPack8 (#2769)
* Change the return type of run_gemm_combinations in the basic tests * Change the return type of run_gemm_combinations in the universal tests * Add universal GEMM tests for bf16 x pk_i4 and fp16 x pk_i4 * Add universal GEMM test for fp8 x pk_i4 * Add basic GEMM tests for bf16 x pk_i4, fp16 x pk_i4 and fp8 x pk_i4. * Add missing GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t> * Add missing GemmTypeConfig<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t> * No need for utility in test_ck_tile_elementwise_1d * Fix conversion from pk_int4x4_t to bf16x8_t in PassThroughPack8 * Avoid union-based type punning in float_to_bf16_truc_raw to make it constexpr compliant * For consistency also make float_to_bf16_truc_nan_raw constexpr compliant by removing the union * Use a static_cast to bfloat16_t only when CK_TILE_USE_LLVM_BUILTIN_BF16 is enforced * Convert from float to bf16 during compilation rather than using magic values * Fix conversion from pk_int4x4_t to fp8x8_t in PassThroughPack8 * Comment out the basic test for fp16 x pk_i4 as it does not pass * Add missing GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t> * Fix conversion from pk_int4x4_t to bf8x8_t in PassThroughPack8 * Add basic and universal GEMM tests for bf8 x pk_i4 * Switch back to amd_assembly_i4_to_fp8x8 in PassThroughPack8 as it works now * Switch back to amd_assembly_i4_to_bf8x8 in PassThroughPack8 as it works now * Remove the inefficient fallbacks for fp8 and bf8 in elementwise/unary_element_wise_operation.hpp * Use explicit macros for enabling and disabling the the constexpr lookup based converters * Fix two failing tests * Avoid union-based type punning in float_to_bf16_rtn_raw to make it constexpr compliant * Use float_to_bf16_rtn_raw instead of float_to_bf16 to create the bf16 lookup table for use in conversions from pk_int4 to bf16 * On ROCm 7.0.1 we need an explicit cast to from uint16_t to bf16_t
This commit is contained in:
@@ -117,12 +117,8 @@ using bf16_raw_t = uint16_t;
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_rtn_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {f};
|
||||
if(~u.int32 & 0x7f800000)
|
||||
uint32_t bits = bit_cast<uint32_t>(f);
|
||||
if(~bits & 0x7f800000)
|
||||
{
|
||||
// When the exponent bits are not all 1s, then the value is zero, normal,
|
||||
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
||||
@@ -140,9 +136,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
|
||||
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
||||
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
||||
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
||||
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
|
||||
bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
|
||||
}
|
||||
else if(u.int32 & 0xffff)
|
||||
else if(bits & 0xffff)
|
||||
{
|
||||
// When all of the exponent bits are 1, the value is Inf or NaN.
|
||||
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
||||
@@ -152,9 +148,9 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
|
||||
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
||||
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
||||
// the bloat16's mantissa bits are all 0.
|
||||
u.int32 |= 0x10000; // Preserve signaling NaN
|
||||
bits |= 0x10000; // Preserve signaling NaN
|
||||
}
|
||||
return uint16_t(u.int32 >> 16);
|
||||
return uint16_t(bits >> 16);
|
||||
}
|
||||
|
||||
CK_TILE_HOST
|
||||
@@ -225,24 +221,16 @@ uint16_t float_to_bf16_rta_asm(float f)
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {f};
|
||||
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
|
||||
uint32_t bits = bit_cast<uint32_t>(f);
|
||||
return static_cast<uint16_t>(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff));
|
||||
}
|
||||
|
||||
// Fast truncate instead of rounding, RTZ
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_truc_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {f};
|
||||
return uint16_t(u.int32 >> 16);
|
||||
uint32_t bits = bit_cast<uint32_t>(f);
|
||||
return static_cast<uint16_t>(bits >> 16);
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding>
|
||||
@@ -287,7 +275,7 @@ template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
return static_cast<bfloat16_t>(f);
|
||||
#else
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
|
||||
@@ -7,9 +7,26 @@
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF16 1
|
||||
#define CONSTEXPR_LOOKUP_TABLE_FOR_FP8 0
|
||||
#define CONSTEXPR_LOOKUP_TABLE_FOR_BF8 0
|
||||
|
||||
namespace ck_tile {
|
||||
namespace element_wise {
|
||||
|
||||
// Generalized constexpr lookup table generator
|
||||
template <typename T, std::size_t N, typename F, std::size_t... Is>
|
||||
constexpr std::array<T, N> make_lookup_table_impl(F&& func, std::index_sequence<Is...>)
|
||||
{
|
||||
return {func(Is)...};
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N, typename F>
|
||||
constexpr std::array<T, N> make_lookup_table(F&& func)
|
||||
{
|
||||
return make_lookup_table_impl<T, N>(std::forward<F>(func), std::make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Fast int4x4 to fp16x8_t data type conversion based on paper
|
||||
* "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"
|
||||
@@ -121,6 +138,8 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
|
||||
*/
|
||||
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
|
||||
{
|
||||
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF16
|
||||
// This approach fails validation in GEMM tests.
|
||||
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
@@ -146,8 +165,19 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
|
||||
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
|
||||
|
||||
return res;
|
||||
#else
|
||||
// Lookup table for bf16_t values corresponding to int4 values -8 to 7
|
||||
constexpr auto bf16_lookup_table = make_lookup_table<bf16_t, 16>(
|
||||
[](int i) { return bit_cast<bf16_t>(float_to_bf16_rtn_raw(i - 8)); });
|
||||
|
||||
return bf16x4_t{bf16_lookup_table[(q >> 0) & 0xf],
|
||||
bf16_lookup_table[(q >> 16) & 0xf],
|
||||
bf16_lookup_table[(q >> 4) & 0xf],
|
||||
bf16_lookup_table[(q >> 20) & 0xf]};
|
||||
#endif
|
||||
}
|
||||
|
||||
#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
|
||||
/**
|
||||
* @brief This function converts 8 packed 4-bit integers into 8 fp8 values.
|
||||
*
|
||||
@@ -209,6 +239,21 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
|
||||
|
||||
return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
|
||||
}
|
||||
#else
|
||||
CK_TILE_DEVICE fp8x4_t i4_to_fp8x4(int q)
|
||||
{
|
||||
// The approach below can be used once this compiler issue is resolved:
|
||||
// "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
|
||||
// Lookup table for fp8_t values corresponding to int4 values -8 to 7
|
||||
constexpr auto fp8_lookup_table = make_lookup_table<fp8_t, 16>(
|
||||
[](int i) { return impl::cast_to_f8<float, fp8_t, true, false>(i - 8, 0); });
|
||||
|
||||
return fp8x4_t{fp8_lookup_table[(q >> 0) & 0xf],
|
||||
fp8_lookup_table[(q >> 16) & 0xf],
|
||||
fp8_lookup_table[(q >> 4) & 0xf],
|
||||
fp8_lookup_table[(q >> 20) & 0xf]};
|
||||
}
|
||||
#endif
|
||||
|
||||
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
|
||||
{
|
||||
@@ -224,6 +269,7 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
|
||||
return res;
|
||||
}
|
||||
|
||||
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
|
||||
/**
|
||||
* @brief This function converts 8 packed 4-bit integers into 8 bf8 values.
|
||||
*
|
||||
@@ -285,6 +331,21 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
|
||||
|
||||
return bit_cast<bf8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
|
||||
}
|
||||
#else
|
||||
CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
|
||||
{
|
||||
// The approach below can be used once this compiler issue is resolved:
|
||||
// "constexpr bit cast involving type 'unsigned _BitInt(8)' is not yet supported"
|
||||
// Lookup table for bf8_t values corresponding to int4 values -8 to 7
|
||||
constexpr auto bf8_lookup_table = make_lookup_table<bf8_t, 16>(
|
||||
[](int i) { return impl::cast_to_f8<float, bf8_t, true, false>(i - 8, 0); });
|
||||
|
||||
return bf8x4_t{bf8_lookup_table[(q >> 0) & 0xf],
|
||||
bf8_lookup_table[(q >> 16) & 0xf],
|
||||
bf8_lookup_table[(q >> 4) & 0xf],
|
||||
bf8_lookup_table[(q >> 20) & 0xf]};
|
||||
}
|
||||
#endif
|
||||
|
||||
struct PassThroughPack8
|
||||
{
|
||||
@@ -300,17 +361,27 @@ struct PassThroughPack8
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y.lo = i4_to_bhalf4(bit_cast<int>(x));
|
||||
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
|
||||
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
#if !CONSTEXPR_LOOKUP_TABLE_FOR_FP8
|
||||
y = amd_assembly_i4_to_fp8x8(bit_cast<uint32_t>(x));
|
||||
#else
|
||||
y.lo = i4_to_fp8x4(bit_cast<int>(x));
|
||||
y.hi = i4_to_fp8x4(bit_cast<int>(x) >> 8);
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
#if !CONSTEXPR_LOOKUP_TABLE_FOR_BF8
|
||||
y = amd_assembly_i4_to_bf8x8(bit_cast<uint32_t>(x));
|
||||
#else
|
||||
y.lo = i4_to_bf8x4(bit_cast<int>(x));
|
||||
y.hi = i4_to_bf8x4(bit_cast<int>(x) >> 8);
|
||||
#endif
|
||||
}
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user