From 08754a58e48ea9db268efa14fee10d45db9abb64 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Wed, 13 Dec 2023 14:27:31 -0600 Subject: [PATCH] Fix the bugs (#1099) [ROCm/composable_kernel commit: 6891e4d10965513657d531c3c8c2048aaba34b05] --- include/ck/utility/type_convert.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 70bc6f278c..11db866152 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -182,7 +182,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion - return f8_convert_sr(type_convert(x)); + return f8_convert_sr(type_convert(x)); #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -295,7 +295,7 @@ inline __host__ __device__ bf8_t f8_convert_rne(half_t x) template <> inline __host__ __device__ f8_t type_convert(float x) { -#if defined CK_USE_SR_F8_CONVERSION +#if CK_USE_SR_F8_CONVERSION return f8_convert_sr(x); #else return f8_convert_rne(x); @@ -352,10 +352,10 @@ inline __host__ __device__ half2_t type_convert(float2_t x) template <> inline __host__ __device__ f8_t type_convert(half_t x) { -#if defined CK_USE_SR_F8_CONVERSION +#if CK_USE_SR_F8_CONVERSION return f8_convert_sr(x); #else - return f8_convert_nre(x); + return f8_convert_rne(x); #endif } @@ -376,7 +376,7 @@ inline __host__ __device__ half_t type_convert(f8_t x) template <> inline __host__ __device__ bf8_t type_convert(float x) { -#if defined CK_USE_SR_F8_CONVERSION +#if CK_USE_SR_F8_CONVERSION return f8_convert_sr(x); #else return f8_convert_rne(x); @@ -403,7 +403,7 @@ inline __host__ __device__ float type_convert(bf8_t x) template <> inline __host__ __device__ bf8_t type_convert(half_t x) { -#if defined CK_USE_SR_F8_CONVERSION +#if CK_USE_SR_F8_CONVERSION return f8_convert_sr(x); #else return f8_convert_rne(x);