mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fix group deterministic bugs
This commit is contained in:
@@ -132,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
|
||||
{
|
||||
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
|
||||
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
|
||||
|
||||
@@ -297,7 +297,7 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_q,
|
||||
args.nhead_stride_q,
|
||||
|
||||
@@ -1375,7 +1375,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
FmhaBwdConvertQGradEmptyKargs<0>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode,
|
||||
@@ -1411,7 +1411,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
MakeKargs(const void* dq_acc_ptr,
|
||||
void* dq_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t stride_dq,
|
||||
ck_tile::index_t nhead_stride_dq,
|
||||
@@ -1426,7 +1426,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
nhead_stride_dq},
|
||||
{},
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
@@ -1463,7 +1463,8 @@ struct FmhaBwdConvertQGradKernel
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
|
||||
Reference in New Issue
Block a user