Split-KV codegen: dual-tile dispatch and head-merge for hdim=64

1. Dual-tile: add both bn0=64 (preferred) and bn0=32 (fallback) for
   hdim=64 (fp16/bf16 tile tables) on gfx9 and gfx12. The dispatch
   checks page_block_size % bn0 == 0 at runtime to select the optimal
   tile; the check is skipped when block_table_ptr is null (non-paged
   KV). bn0=64 halves KV iterations when page_block_size is a
   multiple of 64.

2. Tile dict values are now lists of tiles per hdim (a bare, non-list
   value is still accepted and wrapped in a single-element list). The
   codegen loop iterates over all tile variants, generating separate
   kernel instances for each. Combine kernels are unaffected
   (tile-independent).

3. Enable kMergeNumHeadGroupsSeqLenQ for hdim=64 decode (previously
   hdim=128 only). For GQA-8 with max_seqlen_q=1, this packs 8 head
   groups into the M dimension. Only activates for no-mask instances
   (kernel static_assert requires !kHasMask).

4. Add qr (non-async) pipeline for fwd non-bias group mode as
   fallback after qr_async. The async pipeline on this branch has a
   kernel-level bug where fmha_fwd launches but writes no output.

Made-with: Cursor
This commit is contained in:
root
2026-04-01 15:03:39 +00:00
parent 6729989b97
commit cb6fb2802d
2 changed files with 87 additions and 83 deletions

View File

@@ -1017,6 +1017,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
else:
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
if receipt == 1 and bias != "bias":

View File

@@ -127,7 +127,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
namespace {{
template <bool kHasUnevenSplits>
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
if constexpr (({F_hdim} == 64 || {F_hdim} == 128) && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
@@ -280,7 +280,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (a.block_table_ptr == nullptr || a.page_block_size % {F_bn0} == 0)) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
@@ -820,17 +820,18 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
"32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"32" : [FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"64" : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"96" : [FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"128": [FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
# "160" : [FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"256": [FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
elif dtype in ["fp8", "bf8"]:
return {
"64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"64" : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
"128": [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
} # fmt: skip
else:
return None
@@ -861,16 +862,17 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
if dtype in ["fp16", "bf16"]:
return {
# bm0, bn0, bk0, bn1, bk1,
"32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"32" : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"64" : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"128": [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"256": [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
elif dtype in ["fp8", "bf8"]:
return {
# bm0, bn0, bk0, bn1, bk1,
"64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
"128": [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
} # fmt: skip
else:
return None
@@ -905,77 +907,78 @@ def get_fwd_splitkv_blobs(
continue
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
tiles = d[hdim_str]
if not isinstance(tiles, list):
tiles = [tiles]
hdim = int(hdim_str)
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
if mode == "group":
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
# logits_soft_cap is only allowed if no bias
if not (
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
or pipeline.F_logits == "f"
):
continue
k = Kernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# Flash attention integration
if receipt == 2:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16, bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= mode == "batch"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# aiter::mha_fwd_splikv C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
for tile in tiles:
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
if mode == "group":
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
continue
if not (
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
or pipeline.F_logits == "f"
):
continue
k = Kernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# Flash attention integration
if receipt == 2:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16, bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= mode == "batch"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# aiter::mha_fwd_splikv C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
gen.append(k)
gen.append(k)
return gen