diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index bace93b70e..507b65ef9d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1067,6 +1067,13 @@ struct FmhaFwdKernel const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + if constexpr(kIsChunkedPrefill) { if(kargs.seqlen_q <= kargs.min_seqlen_q) @@ -1075,13 +1082,6 @@ struct FmhaFwdKernel } } - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier - if(kargs.seqlen_q <= i_m0) - { - return; - } - if(kargs.seqlen_k_ptr != nullptr) { kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];