enable larger tile size; upgrade xor pattern

This commit is contained in:
aska-0096
2025-07-31 05:13:27 +00:00
parent 69890afc98
commit 75cba48682
4 changed files with 66 additions and 63 deletions

View File

@@ -38,8 +38,8 @@ K0_MAX_SUBMAX_MAP = {
}
SEQLENQ_MAP = {
"16" : "16",
"32" : "32",
# "16" : "16",
# "32" : "32",
# "64" : "64"
"128" : "128",
}
@@ -132,18 +132,18 @@ using trait_{F_idx} = fmha_fwd_decode_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_
namespace {{
template <bool kHasUnevenSplits>
void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
}} else {{
instance<kHasUnevenSplits>::run(s, a);
}}
}} else {{
instance<kHasUnevenSplits>::run(s, a);
}}
// instance<kHasUnevenSplits>::run(s, a);
//if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
// && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
// || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
// if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
// instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
// }} else {{
// instance<kHasUnevenSplits>::run(s, a);
// }}
//}} else {{
// instance<kHasUnevenSplits>::run(s, a);
//}}
instance<kHasUnevenSplits>::run(s, a);
}}
}} // anonymous namespace
@@ -152,20 +152,20 @@ void run_instance(const ck_tile::stream_config& s, fmha_fwd_decode_args a) {{
template<>
void fmha_fwd_decode_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_decode_args a)
{{
if constexpr({F_mode} == false) {{ // batch mode
// we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{
run_instance</*kHasUnevenSplits=*/true>(s, a);
// make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
run_instance</*kHasUnevenSplits=*/false>(s, a);
}} else {{
run_instance</*kHasUnevenSplits=*/true>(s, a);
}}
}} else {{
run_instance</*kHasUnevenSplits=*/true>(s, a);
}}
// run_instance</*kHasUnevenSplits=*/true>(s, a);
//if constexpr({F_mode} == false) {{ // batch mode
// // we don't check every seqlen_k values for kvcache
// if (a.seqlen_k_ptr != nullptr) {{
// run_instance</*kHasUnevenSplits=*/true>(s, a);
// // make sure F_bn0 is divisible by F_bk1
// }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
// run_instance</*kHasUnevenSplits=*/false>(s, a);
// }} else {{
// run_instance</*kHasUnevenSplits=*/true>(s, a);
// }}
//}} else {{
// run_instance</*kHasUnevenSplits=*/true>(s, a);
//}}
run_instance</*kHasUnevenSplits=*/false>(s, a);
}}
template<>
@@ -658,16 +658,16 @@ class FmhaFwdSplitKVCombineKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'64': {
# Specialize for different SeqQ
'16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
'32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
},
# '64': {
# # Specialize for different SeqQ
# '16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# '32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '128': FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# },
'128': {
'16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
'32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128': FmhaFwdTileSize(128, 32, 128, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# '16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
# '32': FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128': FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
},
}
else: