diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 64643f157e..613f3b6e66 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -108,7 +108,8 @@ template<> void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ if constexpr({F_mode} == false) {{ // batch mode - if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ + if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0 + && a.seqlen_k % (a.num_splits * {F_bk1}) == 0) {{ kernel_runner::run(s, a); }} else {{ kernel_runner::run(s, a); @@ -526,7 +527,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if dtype in ['fp16', 'bf16']: # splitkv kernel donot support dropout for mask, bias, lse, dropout, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"], ["t", "f"]): - if hdim == 256: + if hdim == 256 or hdim == 32 or hdim == 64 or hdim == 128: # if True: pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask))