[CK_TILE] Fix Int32 Overflow in Deterministic FMHA BWD (#3615)

This commit is contained in:
Yi DING
2026-01-21 09:54:46 +08:00
committed by GitHub
parent d5ae81b292
commit fcc9372c00
3 changed files with 20 additions and 20 deletions

View File

@@ -189,7 +189,7 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
-ck_tile::index_t nhead_stride_dq_acc;
+ck_tile::long_index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
@@ -202,7 +202,7 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
-ck_tile::index_t batch_stride_dq_acc;
+ck_tile::long_index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;

View File

@@ -287,9 +287,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
-i_perm
-? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
-: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
+std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
if(init_method == "ui" || init_method == "0")
{
@@ -433,6 +431,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
+const auto split_stride_dq_acc = (shape_seqlen_q * hdim_q);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
@@ -444,6 +443,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
+const auto nhead_stride_dq_acc =
+static_cast<ck_tile::long_index_t>(split_stride_dq_acc) * nsplits;
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
@@ -456,8 +457,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
-const ck_tile::index_t split_stride_dq_acc =
-(shape_batch * nhead * shape_seqlen_q * hdim_q);
+const auto batch_stride_dq_acc = nhead * nhead_stride_dq_acc;
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
if(drop_prefs)
@@ -513,7 +513,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
stride_o,
stride_randval,
stride_do,
-stride_q, // stride_dq_acc
+hdim_q, // stride_dq_acc
stride_q, // stride_dq
stride_dk,
stride_dv,
@@ -526,7 +526,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
-nhead_stride_q, // nhead_stride_dq_acc
+nhead_stride_dq_acc,
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
@@ -539,7 +539,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
-batch_stride_q, // batch_stride_dq_acc
+batch_stride_dq_acc,
batch_stride_q, // batch_stride_dq
batch_stride_dk,
batch_stride_dv,