From 39fc3d4b2e080e5debe464802d221deb2938f9fe Mon Sep 17 00:00:00 2001 From: danyao12 Date: Thu, 11 Jul 2024 17:34:59 +0800 Subject: [PATCH] fix group deterministic bugs --- example/ck_tile/01_fmha/fmha_bwd.cpp | 5 ----- example/ck_tile/01_fmha/fmha_bwd.hpp | 2 +- include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp | 9 +++++---- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index f5aab919d5..2f96ccb4fd 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -132,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); if(hdim_v < 0) hdim_v = hdim_q; - if(hdim_q % 2 != 0 || hdim_v % 2 != 0) - { - std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl; - return false; - } bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index f32557b40c..fecc85aaaf 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -297,7 +297,7 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, args.dq_ptr, args.seqstart_q_ptr, - args.seqlen_k_ptr, + args.seqstart_k_ptr, args.hdim_q, args.stride_q, args.nhead_stride_q, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index dfca858024..df30e8b163 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1375,7 +1375,7 @@ struct FmhaBwdConvertQGradKernel FmhaBwdConvertQGradEmptyKargs<0>> { const int32_t* seqstart_q_ptr; - const int32_t* seqlen_k_ptr; + const int32_t* seqstart_k_ptr; }; using Kargs = std::conditional_t(seqstart_q_ptr), - reinterpret_cast(seqlen_k_ptr)}; + reinterpret_cast(seqstart_k_ptr)}; if constexpr(kIsDeterministic) { @@ -1463,7 +1463,8 @@ struct FmhaBwdConvertQGradKernel // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0)