Fix wrong uneven split checking logics

This commit is contained in:
PoYen, Chen
2024-08-19 08:14:45 +00:00
parent 3f0dab6a77
commit 8166aa58aa

View File

@@ -114,8 +114,11 @@ template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if constexpr({F_mode} == false) {{ // batch mode
// we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{
kernel_runner<true>::run(s, a);
// make sure F_bn0 is divisible by F_bk1
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<false>::run(s, a);
}} else {{
kernel_runner<true>::run(s, a);