Fix group-mode deterministic bugs

This commit is contained in:
danyao12
2024-07-11 17:34:59 +08:00
parent 8c967d76d1
commit 39fc3d4b2e
3 changed files with 6 additions and 10 deletions

View File

@@ -1375,7 +1375,7 @@ struct FmhaBwdConvertQGradKernel
FmhaBwdConvertQGradEmptyKargs<0>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqlen_k_ptr;
const int32_t* seqstart_k_ptr;
};
using Kargs = std::conditional_t<kIsGroupMode,
@@ -1411,7 +1411,7 @@ struct FmhaBwdConvertQGradKernel
MakeKargs(const void* dq_acc_ptr,
void* dq_ptr,
const void* seqstart_q_ptr,
const void* seqlen_k_ptr,
const void* seqstart_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t nhead_stride_dq,
@@ -1426,7 +1426,7 @@ struct FmhaBwdConvertQGradKernel
nhead_stride_dq},
{},
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
if constexpr(kIsDeterministic)
{
@@ -1463,7 +1463,8 @@ struct FmhaBwdConvertQGradKernel
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
// # of required blocks differs between groups, so terminate unnecessary blocks
// early
if(kargs.seqlen_q <= i_m0)