mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Fix build issues when __gfx950__ macro is enabled.
This commit is contained in:
@@ -19,13 +19,13 @@
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \
|
||||
__HIP_DEVICE_COMPILE__
|
||||
(defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__)
|
||||
#define CK_FP8_CVT_FAST_PATH 1
|
||||
#else
|
||||
#define CK_FP8_CVT_FAST_PATH 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
|
||||
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && (defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__)
|
||||
#define CK_OCP_FP8_CVT_FAST_PATH 1
|
||||
#else
|
||||
#define CK_OCP_FP8_CVT_FAST_PATH 0
|
||||
@@ -364,7 +364,8 @@ struct bf8_ocp_t
|
||||
__host__ explicit operator float() const
|
||||
#endif
|
||||
{
|
||||
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
|
||||
//#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<float, wm, we, false>(
|
||||
@@ -378,7 +379,8 @@ struct bf8_ocp_t
|
||||
__host__ explicit operator _Float16() const
|
||||
#endif
|
||||
{
|
||||
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
|
||||
//#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if CK_OCP_FP8_CVT_FAST_PATH
|
||||
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
|
||||
#else
|
||||
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
|
||||
|
||||
@@ -71,7 +71,7 @@ inline __device__ bhalf2_t static_cast_float2_to_bhalf2_rne(float2_t x)
|
||||
// TODO: Why do we need the host instance?
|
||||
inline __host__ __device__ void static_cast_float_to_bhalf_packed(float& x, float& y)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx950__) && defined(__HIP_DEVICE_COMPILE__)
|
||||
uint32_t result;
|
||||
asm volatile("v_cvt_pk_bf16_f32 %0, %1, %2"
|
||||
: "=v"(result)
|
||||
@@ -89,6 +89,7 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed(float& x, floa
|
||||
y_parts[1] = bf16_values[1];
|
||||
#else
|
||||
// Skip conversion for non-GFX950 architectures
|
||||
// TODO: Implement the conversion.
|
||||
x = static_cast<float>(static_cast<bhalf_t>(x));
|
||||
y = static_cast<float>(static_cast<bhalf_t>(y));
|
||||
#endif
|
||||
@@ -106,7 +107,7 @@ __host__ __device__ constexpr Y bf16_convert_rtn(X x);
|
||||
template <>
|
||||
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
#if defined(__gfx950__) && defined(__HIP_DEVICE_COMPILE__)
|
||||
return static_cast_float_to_bf16(x);
|
||||
#else
|
||||
// Nan check
|
||||
|
||||
Reference in New Issue
Block a user