mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
* Fix fmha fwd precision issue on MI3XX series

  For fmha fwd fp16 cases, we found that using impl::cast_tile_pk_fp16_fp32 to cast P leads to precision issues, because it uses __builtin_amdgcn_cvt_pkrtz, which rounds toward zero. For example, with K and V fixed to all 1s and Q random, every output is expected to be 1. However, some outputs came out as 0.9995. That error (1 - 0.9995 = 0.0005) is smaller than the atol of 0.001, so CK does not report it.

* Add option to switch between rtn (round to nearest) and rtz (round toward zero) for fmha fwd
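The bias described above can be reproduced on the CPU with a small sketch. It does not call __builtin_amdgcn_cvt_pkrtz or any CK code. Instead it emulates the fp32 -> fp16 cast of P by rounding each value to fp16's 11-bit significand, either toward zero (rtz) or to nearest-even (rtn). The names `Round`, `round_to_fp16_precision` and `attention_output_with_ones_v` are hypothetical, the scores come from a simple LCG rather than a real Q*K^T, and the emulation ignores fp16's exponent range (harmless here because every P lies in roughly [0.0025, 1]).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical rounding-mode switch for the P cast (not CK's API).
enum class Round { TowardZero, NearestEven };

// Emulates rounding an fp32 value to fp16's 11-bit significand.
// Assumption: fp16 exponent range is ignored (no subnormals, no overflow).
float round_to_fp16_precision(float x, Round mode) {
    if (x == 0.0f || !std::isfinite(x)) return x;
    int e = 0;
    float m = std::frexp(x, &e);        // x = m * 2^e with 0.5 <= |m| < 1
    float scaled = std::ldexp(m, 11);   // keep 11 significant bits in the integer part
    float r = (mode == Round::TowardZero)
                  ? std::trunc(scaled)        // rtz: always drops the remainder
                  : std::nearbyint(scaled);   // rtn: default FE_TONEAREST, ties to even
    return std::ldexp(r, e - 11);
}

// One output element of softmax(S) * V with V == 1 everywhere, so the exact answer is 1.
// P is rounded to "fp16" before the P*V product, while the row sum l stays in fp32.
float attention_output_with_ones_v(std::size_t n, Round mode) {
    assert(n > 0);
    std::vector<float> s(n);
    std::uint32_t state = 12345u;  // deterministic LCG so runs are reproducible
    for (auto& v : s) {
        state = state * 1664525u + 1013904223u;
        v = 6.0f * (static_cast<float>(state >> 8) / 16777216.0f) - 3.0f;  // in [-3, 3)
    }
    float mx = s[0];
    for (float v : s) mx = std::fmax(mx, v);

    float l = 0.0f;
    float acc = 0.0f;
    for (float v : s) {
        float p = std::exp(v - mx);                       // unnormalised probability in (0, 1]
        l += p;                                           // fp32 row sum
        acc += round_to_fp16_precision(p, mode) * 1.0f;   // rounded P times V == 1
    }
    return acc / l;
}
```

With 1024 keys, `attention_output_with_ones_v(1024, Round::TowardZero)` lands slightly below 1 because every truncated P is less than or equal to its fp32 value, so the errors all push the same way. Each relative error is below 2^-10 (about 0.00098), which is why the total deviation stays under an atol of 0.001 and goes unreported. With `Round::NearestEven` the per-element errors have both signs and largely cancel, so the result is much closer to 1.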