[CK_TILE] fix pk_fp4 compilation for non-gfx950 GPUs (#2983)

See build error log from
https://github.com/ROCm/composable_kernel/issues/2271#issuecomment-3150218542

This PR make vector element access constexpr-safe by avoiding operator[] on
ext_vector_type(2) and replace those sites in the pk_fp4 conversions so they
can be used in constant expressions, as The operator[] on ext_vector_type(2)
isn't allowed in constant expressions, which caused "constexpr function never
produces a constant expression" with a note at x[0]. Using `bit_cast` to a
trivial array representation keeps it constexpr-compatible.

Signed-off-by: Hollow Man <hollowman@opensuse.org>
This commit is contained in:
ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟
2025-10-09 17:43:41 +03:00
committed by GitHub
parent 7b6451b68e
commit fb66b4f5e4

View File

@@ -23,6 +23,51 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
// Helpers: constexpr-safe access to elements of ext_vector_type(2)
// Some compilers don't allow operator[] in constant expressions for vector types.
// We use bit_cast to a trivially copyable representation to extract lanes.
namespace detail {
struct fp16x2_repr
{
_Float16 e[2];
};
struct bf16x2_repr
{
bfloat16_t e[2];
};
struct fp32x2_repr
{
float e[2];
};
CK_TILE_HOST_DEVICE constexpr _Float16 lane0(const fp16x2_t& v)
{
return ck_tile::bit_cast<fp16x2_repr>(v).e[0];
}
CK_TILE_HOST_DEVICE constexpr _Float16 lane1(const fp16x2_t& v)
{
return ck_tile::bit_cast<fp16x2_repr>(v).e[1];
}
CK_TILE_HOST_DEVICE constexpr bfloat16_t lane0(const bf16x2_t& v)
{
return ck_tile::bit_cast<bf16x2_repr>(v).e[0];
}
CK_TILE_HOST_DEVICE constexpr bfloat16_t lane1(const bf16x2_t& v)
{
return ck_tile::bit_cast<bf16x2_repr>(v).e[1];
}
CK_TILE_HOST_DEVICE constexpr float lane0(const fp32x2_t& v)
{
return ck_tile::bit_cast<fp32x2_repr>(v).e[0];
}
CK_TILE_HOST_DEVICE constexpr float lane1(const fp32x2_t& v)
{
return ck_tile::bit_cast<fp32x2_repr>(v).e[1];
}
} // namespace detail
struct pk_float4_e2m1_t;
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float& x, float scale = 1.f);
@@ -166,15 +211,24 @@ template <typename T>
CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
{
if constexpr(std::is_same_v<T, fp32_t>)
return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
{
fp32x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
return detail::lane0(tmp);
}
else if constexpr(std::is_same_v<T, fp32x2_t>)
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
else if constexpr(std::is_same_v<T, fp16_t>)
return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
{
fp16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
return detail::lane0(tmp);
}
else if constexpr(std::is_same_v<T, fp16x2_t>)
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
else if constexpr(std::is_same_v<T, bf16_t>)
return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
{
bf16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
return detail::lane0(tmp);
}
else if constexpr(std::is_same_v<T, bf16x2_t>)
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
else
@@ -192,7 +246,8 @@ CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
if constexpr(std::is_same_v<T, fp32_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
else if constexpr(std::is_same_v<T, fp32x2_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
cvt.u32, detail::lane0(src), detail::lane1(src), scale, 0);
else if constexpr(std::is_same_v<T, fp16_t>)
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
else if constexpr(std::is_same_v<T, fp16x2_t>)
@@ -269,7 +324,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
float_to_mxfp4(detail::lane1(x), scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
@@ -277,7 +333,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
float_to_mxfp4(detail::lane1(x), scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
@@ -285,7 +342,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
#else
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
float_to_mxfp4(detail::lane1(x), scale));
#endif
}