mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Fixed f8_gemm NaN (#975)
* workaround nan problem by changing output to fp16
* enable f8/bf8 gemm tests on MI200
* workaround f16 to f8 conversion
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
[ROCm/composable_kernel commit: ac9595a9f1]
This commit is contained in:
@@ -173,8 +173,7 @@ struct PassThrough
|
||||
template <>
|
||||
__host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
|
||||
{
|
||||
// to-do: fix half_t to bf8_t convert
|
||||
y = ck::type_convert<bf8_t>(ck::type_convert<float>(x));
|
||||
y = ck::type_convert<bf8_t>(x);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -344,7 +344,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native conversion
|
||||
return f8_convert_sr<f8_t>(type_convert<float>(x));
|
||||
#else
|
||||
#elif 0
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
@@ -353,6 +353,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#else
|
||||
return type_convert<f8_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
@@ -393,7 +395,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
// convert to float and use native conversion
|
||||
return f8_convert_sr<f8_t>(type_convert<float>(x));
|
||||
#else
|
||||
#elif 0
|
||||
constexpr bool negative_zero_nan = true;
|
||||
constexpr bool clip = true;
|
||||
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
|
||||
@@ -403,6 +405,8 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
|
||||
return utils::
|
||||
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
|
||||
x, rng);
|
||||
#else
|
||||
return type_convert<bf8_t>(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user