Fix build issues when __gfx950__ macro is enabled.

This commit is contained in:
Ville Pietilä
2025-08-08 08:01:42 +00:00
parent 4b8a559da9
commit c47b80580d
2 changed files with 9 additions and 6 deletions

View File

@@ -19,13 +19,13 @@
#endif
#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \
__HIP_DEVICE_COMPILE__
(defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__)
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && (defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__)
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
@@ -364,7 +364,8 @@ struct bf8_ocp_t
__host__ explicit operator float() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
//#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
#if CK_OCP_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
@@ -378,7 +379,8 @@ struct bf8_ocp_t
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
//#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
#if CK_OCP_FP8_CVT_FAST_PATH
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(

View File

@@ -71,7 +71,7 @@ inline __device__ bhalf2_t static_cast_float2_to_bhalf2_rne(float2_t x)
// TODO: Why do we need the host instance?
inline __host__ __device__ void static_cast_float_to_bhalf_packed(float& x, float& y)
{
#if defined(__gfx950__)
#if defined(__gfx950__) && defined(__HIP_DEVICE_COMPILE__)
uint32_t result;
asm volatile("v_cvt_pk_bf16_f32 %0, %1, %2"
: "=v"(result)
@@ -89,6 +89,7 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed(float& x, floa
y_parts[1] = bf16_values[1];
#else
// Skip conversion for non-GFX950 architectures
// TODO: Implement the conversion.
x = static_cast<float>(static_cast<bhalf_t>(x));
y = static_cast<float>(static_cast<bhalf_t>(y));
#endif
@@ -106,7 +107,7 @@ __host__ __device__ constexpr Y bf16_convert_rtn(X x);
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
#if defined(__gfx950__)
#if defined(__gfx950__) && defined(__HIP_DEVICE_COMPILE__)
return static_cast_float_to_bf16(x);
#else
// Nan check