FMHA BWD Avoid SetZero (#2799)

[ROCm/composable_kernel commit: ad259eeae2]
This commit is contained in:
Yi DING
2025-09-23 14:37:48 +08:00
committed by GitHub
parent 0b149c8695
commit 01c6567d4c
3 changed files with 10 additions and 4 deletions

View File

@@ -763,15 +763,21 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
ck_tile::FillConstant<AccDataType>{ck_tile::numeric<AccDataType>::infinity()}(dq_acc_host);
dq_buf.ToDevice(dq_host.data());
dk_buf.ToDevice(dk_host.data());
dv_buf.ToDevice(dv_host.data());
dq_acc_buf.ToDevice(dq_acc_host.data());
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero();
dbias_buf.SetZero();
dq_acc_buf.SetZero();
// non-deterministic kernels use atomic add to write dq
// Some block may be skipped with causal mask and dq are not set to zeros
// In these cases thus we need to zero out it first
if(!deterministic || mask.type == mask_enum::no_mask)
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
fmha_bwd(fmha_traits, fmha_args, stream_config_v);

View File

@@ -16,6 +16,6 @@ const auto HDimValues =
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
constexpr std::string init_method = "uf";
constexpr auto init_method = "uf";
#include "test_fmha_bwd.inc"

View File

@@ -16,6 +16,6 @@ const auto HDimValues =
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
constexpr std::string init_method = "uf";
constexpr auto init_method = "uf";
#include "test_fmha_bwd.inc"