From bfa145c4180ceff75a32a36b42556db9a407897f Mon Sep 17 00:00:00 2001 From: Yi DING Date: Tue, 23 Sep 2025 14:37:48 +0800 Subject: [PATCH] FMHA BWD Avoid SetZero (#2799) [ROCm/composable_kernel commit: ad259eeae2d2533c63ca16305acf85d7d200b833] --- example/ck_tile/01_fmha/fmha_bwd_runner.hpp | 10 ++++++++-- test/ck_tile/fmha/test_fmha_bwd_bf16.cpp | 2 +- test/ck_tile/fmha/test_fmha_bwd_fp16.cpp | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index 3a5b5b4603..d861b351d4 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -763,15 +763,21 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_acc_host); dq_buf.ToDevice(dq_host.data()); dk_buf.ToDevice(dk_host.data()); dv_buf.ToDevice(dv_host.data()); + dq_acc_buf.ToDevice(dq_acc_host.data()); o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); - dq_buf.SetZero(); dbias_buf.SetZero(); - dq_acc_buf.SetZero(); + + // non-deterministic kernels use atomic add to write dq + // Some block may be skipped with causal mask and dq are not set to zeros + // In these cases thus we need to zero out it first + if(!deterministic || mask.type == mask_enum::no_mask) + dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; fmha_bwd(fmha_traits, fmha_args, stream_config_v); diff --git a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp index cd143e8e83..077e45a10d 100644 --- a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp @@ -16,6 +16,6 @@ const auto HDimValues = const auto ModeValues = Values(mode_enum::batch, mode_enum::group); -constexpr std::string init_method = "uf"; +constexpr auto init_method = "uf"; #include "test_fmha_bwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp index 4bb1e04ad0..86621b0494 100644 --- a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp @@ -16,6 +16,6 @@ const auto HDimValues = const auto ModeValues = Values(mode_enum::batch, mode_enum::group); -constexpr std::string init_method = "uf"; +constexpr auto init_method = "uf"; #include "test_fmha_bwd.inc"