mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Fix Int32 Overflow in Deterministic FMHA BWD (#3615)
This commit is contained in:
@@ -189,7 +189,7 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
-ck_tile::index_t nhead_stride_dq_acc;
|
||||
+ck_tile::long_index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
@@ -202,7 +202,7 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
-ck_tile::index_t batch_stride_dq_acc;
|
||||
+ck_tile::long_index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
|
||||
@@ -287,9 +287,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<AccDataType> dq_acc_host(
|
||||
-i_perm
|
||||
-? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
|
||||
-: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
|
||||
+std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
|
||||
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
@@ -433,6 +431,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
|
||||
+const auto split_stride_dq_acc = (shape_seqlen_q * hdim_q);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
@@ -444,6 +443,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_dbias =
|
||||
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
|
||||
+const auto nhead_stride_dq_acc =
|
||||
+static_cast<ck_tile::long_index_t>(split_stride_dq_acc) * nsplits;
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
@@ -456,8 +457,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
|
||||
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
-const ck_tile::index_t split_stride_dq_acc =
|
||||
-(shape_batch * nhead * shape_seqlen_q * hdim_q);
|
||||
+const auto batch_stride_dq_acc = nhead * nhead_stride_dq_acc;
|
||||
|
||||
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
|
||||
if(drop_prefs)
|
||||
@@ -513,7 +513,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
-stride_q, // stride_dq_acc
|
||||
+hdim_q, // stride_dq_acc
|
||||
stride_q, // stride_dq
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
@@ -526,7 +526,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
-nhead_stride_q, // nhead_stride_dq_acc
|
||||
+nhead_stride_dq_acc,
|
||||
nhead_stride_q, // nhead_stride_dq
|
||||
nhead_stride_k, // nhead_stride_dk
|
||||
nhead_stride_v, // nhead_stride_dv
|
||||
@@ -539,7 +539,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
-batch_stride_q, // batch_stride_dq_acc
|
||||
+batch_stride_dq_acc,
|
||||
batch_stride_q, // batch_stride_dq
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
|
||||
Reference in New Issue
Block a user