diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index fcbf6dfc2d..23c73e5f43 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -983,6 +983,9 @@ struct FmhaBwdDQDKDVKernel long_index_t batch_offset_dk = 0; long_index_t batch_offset_dv = 0; long_index_t batch_offset_dbias = 0; + // dq_acc per-nhead stride uses padded seqlen_q in group mode; equals kargs.seqlen_q + // in batch mode. See FmhaBwdWorkspaceManager doc. + index_t physical_seqlen_q = kargs.seqlen_q; if constexpr(kIsGroupMode) { @@ -990,6 +993,9 @@ struct FmhaBwdDQDKDVKernel const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + physical_seqlen_q = + static_cast(kargs.seqstart_q_ptr[i_batch + 1] - query_start); + batch_offset_q = query_start * kargs.stride_q; batch_offset_k = key_start * kargs.stride_k; batch_offset_v = key_start * kargs.stride_v; @@ -1030,10 +1036,6 @@ struct FmhaBwdDQDKDVKernel } else { - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - const ck_tile::index_t physical_seqlen_q = - adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; kargs.seqlen_q = kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q; } @@ -1212,12 +1214,13 @@ struct FmhaBwdDQDKDVKernel else if constexpr(!kIsDeterministic) { return batch_offset_dq_acc + - static_cast(i_nhead_) * kargs.seqlen_q * kargs.hdim_q; + static_cast(i_nhead_) * physical_seqlen_q * kargs.hdim_q; } else { - const long_index_t split_stride = kargs.seqlen_q * kargs.hdim_q; - const auto nsplits = [&]() { + const long_index_t split_stride = + static_cast(physical_seqlen_q) * kargs.hdim_q; + const auto nsplits = [&]() { if constexpr(!kIsGroupMode) return n_splits; else