Fix uneven split checking logic

This commit is contained in:
PoYen, Chen
2024-08-06 01:17:14 +00:00
parent 77dac7775c
commit 8779716403

View File

@@ -108,7 +108,8 @@ template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
{{
if constexpr({F_mode} == false) {{ // batch mode
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0
&& a.seqlen_k % (a.num_splits * {F_bk1}) == 0) {{
kernel_runner<false>::run(s, a);
}} else {{
kernel_runner<true>::run(s, a);
@@ -526,7 +527,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if dtype in ['fp16', 'bf16']:
# splitkv kernel donot support dropout
for mask, bias, lse, dropout, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"], ["t", "f"]):
if hdim == 256:
if hdim == 256 or hdim == 32 or hdim == 64 or hdim == 128:
# if True:
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask))