add constexpr to pk_fp4::pack/unpack() (#2586)

This commit is contained in:
Gino Lu
2025-07-30 22:29:04 +08:00
committed by GitHub
parent 61e21f5567
commit b25d512e8a

View File

@@ -55,8 +55,8 @@ struct pk_float4_e2m1_t
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const;
template <index_t I>
CK_TILE_HOST_DEVICE raw_type unpack(number<I>) const;
CK_TILE_HOST_DEVICE static pk_float4_e2m1_t pack(const type x0, const type x1)
CK_TILE_HOST_DEVICE constexpr raw_type unpack(number<I>) const;
CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
}
@@ -130,7 +130,7 @@ struct numeric<pk_fp4_t>
};
template <index_t I>
CK_TILE_HOST_DEVICE pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 1)
@@ -147,7 +147,6 @@ namespace impl {
template <typename T>
CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
{
// TODO: check the order
if constexpr(std::is_same_v<T, fp32_t>)
return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
else if constexpr(std::is_same_v<T, fp32x2_t>)
@@ -167,7 +166,6 @@ CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
template <typename T>
CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
{
// TODO: check the order
union
{
uint32_t u32;