From 7cb7aa8e7527ec831b5cb48c86fc65673b0a1704 Mon Sep 17 00:00:00 2001 From: MHYangAMD Date: Tue, 10 Jun 2025 15:03:23 +0800 Subject: [PATCH] Fix fmha fwd precision issue on MI3XX series (#2285) * Fix fmha fwd precision issue on MI3XX series For fmha fwd fp16 cases, we found that using impl::cast_tile_pk_fp16_fp32 for casting P would lead to precision issues, since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero. For examaple, fixing K,V to be all 1, and Q is random, which outputs are expected to be all 1. But we found that it would have some incorrect outputs 0.9995, which are smaller than the atol 0.001. (1 - 0.9995 = 0.0005 < 0.001) Thus, ck do not report this error. * Add option to switch rtn/rtz for fmha fwd [ROCm/composable_kernel commit: 9fcf21a4ec4698209c4ed7b859574cc1e1986aa3] --- include/ck_tile/core/config.hpp | 4 ++++ .../block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 7 +++++++ .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 7 +++++++ 3 files changed, 18 insertions(+) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 27133fa847..14b33aea77 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -223,6 +223,10 @@ #define CK_TILE_FMHA_FWD_FAST_EXP2 0 #endif +#ifndef CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN +#define CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN 0 +#endif + #ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA #define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1 #endif diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 8691622bb0..6398bf316e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -702,12 +702,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } const auto p = [&]() { +#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN + // For fp32 to fp16, + // impl::cast_tile_pk_fp16_fp32 would cause precision issue, + // since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero. + return cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#else if constexpr(std::is_same_v) return impl::cast_tile_pk_fp16_fp32( tile_elementwise_in(p_compute_element_func, p_compute)); else return cast_tile( tile_elementwise_in(p_compute_element_func, p_compute)); +#endif }(); // STAGE 3, KV gemm diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 7af3902dc5..ba788c7f1e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -653,12 +653,19 @@ struct BlockFmhaPipelineQRKSVSAsync } const auto p = [&]() { +#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN + // For fp32 to fp16, + // impl::cast_tile_pk_fp16_fp32 would cause precision issue, + // since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero. + return cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); +#else if constexpr(std::is_same_v) return impl::cast_tile_pk_fp16_fp32( tile_elementwise_in(p_compute_element_func, p_compute)); else return cast_tile( tile_elementwise_in(p_compute_element_func, p_compute)); +#endif }(); // STAGE 3, KV gemm