mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
workaround with float (#992)
Co-authored-by: Jing Zhang <jizha@amd.com>
[ROCm/composable_kernel commit: 39430bfdeb]
This commit is contained in:
@@ -146,7 +146,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native conversion
|
||||
return type_convert<f8_t>(type_convert<float>(x));
|
||||
#else
|
||||
#elif 0
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
@@ -154,6 +154,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#else
|
||||
return type_convert<f8_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -164,9 +166,11 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#else
|
||||
#elif 0
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
|
||||
#else
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
@@ -222,7 +226,7 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native conversion
|
||||
return type_convert<bf8_t>(type_convert<float>(x));
|
||||
#else
|
||||
#elif 0
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
|
||||
@@ -230,6 +234,8 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#else
|
||||
return type_convert<bf8_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -240,9 +246,11 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// use native conversion to float and convert to fp16
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#else
|
||||
#elif 0
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
|
||||
#else
|
||||
return type_convert<half_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
@@ -354,7 +362,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#else
|
||||
return type_convert<f8_t>(type_convert<float>(x));
|
||||
return f8_convert_sr<f8_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
@@ -406,7 +414,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#else
|
||||
return type_convert<bf8_t>(type_convert<float>(x));
|
||||
return f8_convert_sr<bf8_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user