From c81483b230cc2e3252f1805704c9ff227f2882bc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=84=8D=F0=9D=95=A0=F0=9D=95=9D=F0=9D=95=9D=F0=9D=95=A0?=
 =?UTF-8?q?=F0=9D=95=A8=20=F0=9D=95=84=F0=9D=95=92=F0=9D=95=9F?=
Date: Thu, 9 Oct 2025 17:43:41 +0300
Subject: [PATCH] [CK_TILE] fix pk_fp4 compilation for non-gfx950 GPUs (#2983)

See build error log from
https://github.com/ROCm/composable_kernel/issues/2271#issuecomment-3150218542

This PR makes vector element access constexpr-safe by avoiding operator[]
on ext_vector_type(2), and replaces those sites in the pk_fp4 conversions so
they can be used in constant expressions. The operator[] on
ext_vector_type(2) isn't allowed in constant expressions, which caused
"constexpr function never produces a constant expression" errors with a
note at x[0]. Using `bit_cast` to a trivial array representation keeps it
constexpr-compatible.

Signed-off-by: Hollow Man

[ROCm/composable_kernel commit: fb66b4f5e4b5b178e3eee04189224e139e939c0c]
---
 include/ck_tile/core/numeric/pk_fp4.hpp | 72 ++++++++++++++++++++++---
 1 file changed, 65 insertions(+), 7 deletions(-)

diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp
index 8b78990d08..4f662095db 100644
--- a/include/ck_tile/core/numeric/pk_fp4.hpp
+++ b/include/ck_tile/core/numeric/pk_fp4.hpp
@@ -23,6 +23,51 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
 using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
 using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
 
+// Helpers: constexpr-safe access to elements of ext_vector_type(2)
+// Some compilers don't allow operator[] in constant expressions for vector types.
+// We use bit_cast to a trivially copyable representation to extract lanes.
+namespace detail {
+struct fp16x2_repr
+{
+    _Float16 e[2];
+};
+struct bf16x2_repr
+{
+    bfloat16_t e[2];
+};
+struct fp32x2_repr
+{
+    float e[2];
+};
+
+CK_TILE_HOST_DEVICE constexpr _Float16 lane0(const fp16x2_t& v)
+{
+    return ck_tile::bit_cast<fp16x2_repr>(v).e[0];
+}
+CK_TILE_HOST_DEVICE constexpr _Float16 lane1(const fp16x2_t& v)
+{
+    return ck_tile::bit_cast<fp16x2_repr>(v).e[1];
+}
+
+CK_TILE_HOST_DEVICE constexpr bfloat16_t lane0(const bf16x2_t& v)
+{
+    return ck_tile::bit_cast<bf16x2_repr>(v).e[0];
+}
+CK_TILE_HOST_DEVICE constexpr bfloat16_t lane1(const bf16x2_t& v)
+{
+    return ck_tile::bit_cast<bf16x2_repr>(v).e[1];
+}
+
+CK_TILE_HOST_DEVICE constexpr float lane0(const fp32x2_t& v)
+{
+    return ck_tile::bit_cast<fp32x2_repr>(v).e[0];
+}
+CK_TILE_HOST_DEVICE constexpr float lane1(const fp32x2_t& v)
+{
+    return ck_tile::bit_cast<fp32x2_repr>(v).e[1];
+}
+} // namespace detail
+
 struct pk_float4_e2m1_t;
 
 CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float& x, float scale = 1.f);
@@ -166,15 +211,24 @@ template <typename T>
 CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
 {
     if constexpr(std::is_same_v<T, float>)
-        return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
+    {
+        fp32x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
+        return detail::lane0(tmp);
+    }
     else if constexpr(std::is_same_v<T, fp32x2_t>)
         return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
     else if constexpr(std::is_same_v<T, fp16_t>)
-        return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
+    {
+        fp16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
+        return detail::lane0(tmp);
+    }
     else if constexpr(std::is_same_v<T, fp16x2_t>)
         return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
     else if constexpr(std::is_same_v<T, bf16_t>)
-        return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
+    {
+        bf16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
+        return detail::lane0(tmp);
+    }
     else if constexpr(std::is_same_v<T, bf16x2_t>)
         return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
     else
@@ -192,7 +246,8 @@ CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
     if constexpr(std::is_same_v<T, float>)
         cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
     else if constexpr(std::is_same_v<T, fp32x2_t>)
-        cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
+        cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
+            cvt.u32, detail::lane0(src), detail::lane1(src), scale, 0);
     else if constexpr(std::is_same_v<T, fp16_t>)
         cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
     else if constexpr(std::is_same_v<T, fp16x2_t>)
@@ -269,7 +324,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
 #if CK_TILE_FP4_CVT_DEVICE
     return impl::_to_f4(x, scale);
 #else
-    return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
+    return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
+                           float_to_mxfp4(detail::lane1(x), scale));
 #endif
 }
 CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
@@ -277,7 +333,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
 #if CK_TILE_FP4_CVT_DEVICE
     return impl::_to_f4(x, scale);
 #else
-    return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
+    return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
+                           float_to_mxfp4(detail::lane1(x), scale));
 #endif
 }
 CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
@@ -285,7 +342,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float
 #if CK_TILE_FP4_CVT_DEVICE
     return impl::_to_f4(x, scale);
 #else
-    return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
+    return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
+                           float_to_mxfp4(detail::lane1(x), scale));
 #endif
 }