diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 2edbb7c789..37c9f3693b 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -19,13 +19,13 @@ #endif #if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \ - __HIP_DEVICE_COMPILE__ + (defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__) #define CK_FP8_CVT_FAST_PATH 1 #else #define CK_FP8_CVT_FAST_PATH 0 #endif -#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__ +#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && (defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__) #define CK_OCP_FP8_CVT_FAST_PATH 1 #else #define CK_OCP_FP8_CVT_FAST_PATH 0 @@ -364,7 +364,8 @@ struct bf8_ocp_t __host__ explicit operator float() const #endif { -#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) +//#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) +#if CK_OCP_FP8_CVT_FAST_PATH return fp8_impl::cast_to_f32_from_f8(this->data); #else return fp8_impl::cast_from_f8( @@ -378,7 +379,8 @@ struct bf8_ocp_t __host__ explicit operator _Float16() const #endif { -#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) +//#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) +#if CK_OCP_FP8_CVT_FAST_PATH return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); #else return fp8_impl::cast_from_f8<_Float16, wm, we, false>( diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 53a905a7cd..30073efc99 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -71,7 +71,7 @@ inline __device__ bhalf2_t static_cast_float2_to_bhalf2_rne(float2_t x) // TODO: Why do we need the host instance? inline __host__ __device__ void static_cast_float_to_bhalf_packed(float& x, float& y) { -#if defined(__gfx950__) +#if defined(__gfx950__) && defined(__HIP_DEVICE_COMPILE__) uint32_t result; asm volatile("v_cvt_pk_bf16_f32 %0, %1, %2" : "=v"(result) @@ -89,6 +89,7 @@ inline __host__ __device__ void static_cast_float_to_bhalf_packed(float& x, floa y_parts[1] = bf16_values[1]; #else // Skip conversion for non-GFX950 architectures + // TODO: Implement the conversion. x = static_cast(static_cast(x)); y = static_cast(static_cast(y)); #endif @@ -106,7 +107,7 @@ __host__ __device__ constexpr Y bf16_convert_rtn(X x); template <> inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) { -#if defined(__gfx950__) +#if defined(__gfx950__) && defined(__HIP_DEVICE_COMPILE__) return static_cast_float_to_bf16(x); #else // Nan check