mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Fix uneven split checking logic
This commit is contained in:
@@ -108,7 +108,8 @@ template<>
|
||||
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{{
|
||||
if constexpr({F_mode} == false) {{ // batch mode
|
||||
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
|
||||
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0
|
||||
&& a.seqlen_k % (a.num_splits * {F_bk1}) == 0) {{
|
||||
kernel_runner<false>::run(s, a);
|
||||
}} else {{
|
||||
kernel_runner<true>::run(s, a);
|
||||
@@ -526,7 +527,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
# splitkv kernels do not support dropout
|
||||
for mask, bias, lse, dropout, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
if hdim == 256 or hdim == 32 or hdim == 64 or hdim == 128:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, pagedkv, mask))
|
||||
|
||||
Reference in New Issue
Block a user