mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Fix fmha fwd precision issue on MI3XX series (#2285)
* Fix fmha fwd precision issue on MI3XX series
For fmha fwd fp16 cases, we found that using
impl::cast_tile_pk_fp16_fp32 for casting P would lead to precision
issues, since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero.
For examaple, fixing K,V to be all 1, and Q is random, which outputs are
expected to be all 1. But we found that it would have some incorrect
outputs 0.9995, which are smaller than the atol 0.001. (1 - 0.9995 =
0.0005 < 0.001) Thus, ck do not report this error.
* Add option to switch rtn/rtz for fmha fwd
[ROCm/composable_kernel commit: 9fcf21a4ec]
This commit is contained in:
@@ -223,6 +223,10 @@
|
||||
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
#define CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
|
||||
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
|
||||
#endif
|
||||
|
||||
@@ -702,12 +702,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
// For fp32 to fp16,
|
||||
// impl::cast_tile_pk_fp16_fp32 would cause precision issue,
|
||||
// since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero.
|
||||
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#else
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
else
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
}();
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
|
||||
@@ -653,12 +653,19 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
// For fp32 to fp16,
|
||||
// impl::cast_tile_pk_fp16_fp32 would cause precision issue,
|
||||
// since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero.
|
||||
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#else
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
else
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
}();
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
|
||||
Reference in New Issue
Block a user