fix group deterministic bugs

This commit is contained in:
danyao12
2024-07-11 17:34:59 +08:00
parent 8c967d76d1
commit 39fc3d4b2e
3 changed files with 6 additions and 10 deletions

View File

@@ -132,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0)
hdim_v = hdim_q;
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
{
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
return false;
}
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim

View File

@@ -297,7 +297,7 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqstart_q_ptr,
args.seqlen_k_ptr,
args.seqstart_k_ptr,
args.hdim_q,
args.stride_q,
args.nhead_stride_q,