diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index b7dca9dd0a..0dee750b69 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -55,8 +55,8 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const; template - CK_TILE_HOST_DEVICE raw_type unpack(number) const; - CK_TILE_HOST_DEVICE static pk_float4_e2m1_t pack(const type x0, const type x1) + CK_TILE_HOST_DEVICE constexpr raw_type unpack(number) const; + CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const type x0, const type x1) { return (x1 << 4) | (x0 & 0b00001111); } @@ -130,7 +130,7 @@ struct numeric }; template -CK_TILE_HOST_DEVICE pk_fp4_raw_t pk_fp4_t::unpack(number) const +CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number) const { static_assert(I < 2, "Index is out of range."); if constexpr(I == 1) @@ -147,7 +147,6 @@ namespace impl { template CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f) { - // TODO: check the order if constexpr(std::is_same_v) return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0]; else if constexpr(std::is_same_v) @@ -167,7 +166,6 @@ CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f) template CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f) { - // TODO: check the order union { uint32_t u32;