diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index 3a5b5b4603..d861b351d4 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -763,15 +763,21 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_acc_host); dq_buf.ToDevice(dq_host.data()); dk_buf.ToDevice(dk_host.data()); dv_buf.ToDevice(dv_host.data()); + dq_acc_buf.ToDevice(dq_acc_host.data()); o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); - dq_buf.SetZero(); dbias_buf.SetZero(); - dq_acc_buf.SetZero(); + + // Non-deterministic kernels use atomic add to write dq, so dq_acc must start from zero. + // With a mask (e.g. causal) some blocks may be skipped and their dq is never written. + // In both cases dq_acc has to be zeroed first; otherwise the infinity fill above remains. + if(!deterministic || mask.type != mask_enum::no_mask) + dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; fmha_bwd(fmha_traits, fmha_args, stream_config_v); diff --git a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp index cd143e8e83..077e45a10d 100644 --- a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp @@ -16,6 +16,6 @@ const auto HDimValues = const auto ModeValues = Values(mode_enum::batch, mode_enum::group); -constexpr std::string init_method = "uf"; +constexpr auto init_method = "uf"; #include "test_fmha_bwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp index 4bb1e04ad0..86621b0494 100644 --- a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp @@ -16,6 +16,6 @@ const auto HDimValues = const auto ModeValues = Values(mode_enum::batch, mode_enum::group); 
-constexpr std::string init_method = "uf"; +constexpr auto init_method = "uf"; #include "test_fmha_bwd.inc"