Fix the fp8 gemm for large tensors on MI300. (#1011)

* Fix the fp8 conversion

* Try clipping value before conversion

* Fix return

* Simplify with a const

* Reduce the GEMM input tensor values to reduce round-off error

* Replace if-else with a lambda

* Fix syntax

---------

Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com>
Author: Illia Silin
Date: 2023-10-27 21:10:47 -07:00
Committed by: GitHub
Parent: 6fe0bc7e72
Commit: f46a6ffad8
4 changed files with 10 additions and 8 deletions

@@ -100,6 +100,8 @@ template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;