mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] fix pk_fp4 compilation for non-gfx950 GPUs (#2983)
See build error log from https://github.com/ROCm/composable_kernel/issues/2271#issuecomment-3150218542 This PR make vector element access constexpr-safe by avoiding operator[] on ext_vector_type(2) and replace those sites in the pk_fp4 conversions so they can be used in constant expressions, as The operator[] on ext_vector_type(2) isn't allowed in constant expressions, which caused "constexpr function never produces a constant expression" with a note at x[0]. Using `bit_cast` to a trivial array representation keeps it constexpr-compatible. Signed-off-by: Hollow Man <hollowman@opensuse.org>
This commit is contained in:
@@ -23,6 +23,51 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
// Helpers: constexpr-safe access to elements of ext_vector_type(2)
|
||||
// Some compilers don't allow operator[] in constant expressions for vector types.
|
||||
// We use bit_cast to a trivially copyable representation to extract lanes.
|
||||
namespace detail {
|
||||
struct fp16x2_repr
|
||||
{
|
||||
_Float16 e[2];
|
||||
};
|
||||
struct bf16x2_repr
|
||||
{
|
||||
bfloat16_t e[2];
|
||||
};
|
||||
struct fp32x2_repr
|
||||
{
|
||||
float e[2];
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr _Float16 lane0(const fp16x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<fp16x2_repr>(v).e[0];
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr _Float16 lane1(const fp16x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<fp16x2_repr>(v).e[1];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t lane0(const bf16x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<bf16x2_repr>(v).e[0];
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t lane1(const bf16x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<bf16x2_repr>(v).e[1];
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float lane0(const fp32x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<fp32x2_repr>(v).e[0];
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float lane1(const fp32x2_t& v)
|
||||
{
|
||||
return ck_tile::bit_cast<fp32x2_repr>(v).e[1];
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
struct pk_float4_e2m1_t;
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float& x, float scale = 1.f);
|
||||
|
||||
@@ -166,15 +211,24 @@ template <typename T>
|
||||
CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
|
||||
{
|
||||
if constexpr(std::is_same_v<T, fp32_t>)
|
||||
return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
|
||||
{
|
||||
fp32x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
|
||||
return detail::lane0(tmp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, fp32x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16_t>)
|
||||
return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
|
||||
{
|
||||
fp16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
|
||||
return detail::lane0(tmp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, fp16x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, bf16_t>)
|
||||
return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
|
||||
{
|
||||
bf16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
|
||||
return detail::lane0(tmp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, bf16x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
|
||||
else
|
||||
@@ -192,7 +246,8 @@ CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
|
||||
if constexpr(std::is_same_v<T, fp32_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp32x2_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
|
||||
cvt.u32, detail::lane0(src), detail::lane1(src), scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16x2_t>)
|
||||
@@ -269,7 +324,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
|
||||
return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
|
||||
float_to_mxfp4(detail::lane1(x), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
|
||||
@@ -277,7 +333,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
|
||||
return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
|
||||
float_to_mxfp4(detail::lane1(x), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
|
||||
@@ -285,7 +342,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
|
||||
return pk_fp4_t::_pack(float_to_mxfp4(detail::lane0(x), scale),
|
||||
float_to_mxfp4(detail::lane1(x), scale));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user