From 6fe4f27c0511c7d89b9251c3a7eff846946d64ec Mon Sep 17 00:00:00 2001 From: zjing14 Date: Mon, 16 Oct 2023 17:42:59 -0500 Subject: [PATCH] workaround with float (#992) Co-authored-by: Jing Zhang [ROCm/composable_kernel commit: 39430bfdeb7247f0fdf4ba07559619ab1ab9a415] --- include/ck/utility/type_convert.hpp | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 8b70a6bfb4..ccbd5db644 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -146,7 +146,7 @@ inline __host__ __device__ f8_t type_convert(half_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion return type_convert(type_convert(x)); -#else +#elif 0 constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; @@ -154,6 +154,8 @@ inline __host__ __device__ f8_t type_convert(half_t x) return utils:: cast_to_f8( x, rng); +#else + return type_convert(type_convert(x)); #endif } @@ -164,9 +166,11 @@ inline __host__ __device__ half_t type_convert(f8_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // use native conversion to float and convert to fp16 return type_convert(type_convert(x)); -#else +#elif 0 constexpr bool negative_zero_nan = true; return utils::cast_from_f8(x); +#else + return type_convert(type_convert(x)); #endif } #endif @@ -222,7 +226,7 @@ inline __host__ __device__ bf8_t type_convert(half_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion return type_convert(type_convert(x)); -#else +#elif 0 constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; @@ -230,6 +234,8 @@ inline __host__ __device__ bf8_t type_convert(half_t x) return utils:: cast_to_f8( x, rng); +#else + return type_convert(type_convert(x)); #endif } @@ -240,9 +246,11 @@ inline __host__ __device__ half_t type_convert(bf8_t x) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // use native conversion to float and convert to fp16 return type_convert(type_convert(x)); -#else +#elif 0 constexpr bool negative_zero_nan = true; return utils::cast_from_f8(x); +#else + return type_convert(type_convert(x)); #endif } #endif @@ -354,7 +362,7 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) cast_to_f8( x, rng); #else - return type_convert(type_convert(x)); + return f8_convert_sr(type_convert(x)); #endif } #endif @@ -406,7 +414,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) cast_to_f8( x, rng); #else - return type_convert(type_convert(x)); + return f8_convert_sr(type_convert(x)); #endif } #endif