Add decode tiles (b16x32, b32x32) to pagedkv_prefill codegen with max_seqlen_q dispatch

Made-with: Cursor
This commit is contained in:
root
2026-04-01 18:30:06 +00:00
parent 65a3f88ad8
commit c5600bc8ae

View File

@@ -131,7 +131,7 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con
""" """
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) && FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (a.max_seqlen_q <= {F_bm0} || {F_bm0} >= 128)) {{
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>; using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a); return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
}} }}
@@ -581,7 +581,9 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
if dtype in ["fp16", "bf16"]: if dtype in ["fp16", "bf16"]:
return { return {
# "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), # "32": FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
"64": FmhaFwdTileSize(128, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), "64": [FmhaFwdTileSize( 16, 32, 32, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 32, 32, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), # "96": FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), "128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), # "192": FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
@@ -667,93 +669,87 @@ def get_fwd_blobs(
if d is None: if d is None:
continue continue
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str] tiles = d[hdim_str]
if not isinstance(tiles, list):
tiles = [tiles]
hdim = int(hdim_str) hdim = int(hdim_str)
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): for tile in tiles:
# if pipeline.F_pagedkv == "f": for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
# continue if mode == "group":
if mode == "group": if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
if pipeline.F_spad != "t" or pipeline.F_skpad != "t": continue
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not if hdim == 192 and tile.F_bn1 == 128:
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training # NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != "no" or pipeline.F_lse == "t": if pipeline.F_bias != "no" or pipeline.F_lse == "t":
continue continue
# logits_soft_cap is only allowed if no bias if not (
if not ( (pipeline.F_logits == "t" and pipeline.F_bias == "no")
(pipeline.F_logits == "t" and pipeline.F_bias == "no") or pipeline.F_logits == "f"
or pipeline.F_logits == "f" ):
):
continue
k = FmhaFwdKernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue continue
k = FmhaFwdKernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
if receipt in (2, 3):
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
elif receipt == 4:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_fwd) integration # Aiter(mha_fwd) integration
elif receipt == 100: elif receipt == 100:
cond = dtype in ["fp16", "bf16"] cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch" cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f" cond &= pipeline.F_squant == "f"
if not cond: if not cond:
continue continue
# Aiter(mha_varlen_fwd) integration elif receipt == 200:
elif receipt == 200: cond = dtype in ["fp16", "bf16"]
cond = dtype in ["fp16", "bf16"] cond &= mode == "group"
cond &= mode == "group" cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_squant == "f"
cond &= pipeline.F_squant == "f" if not cond:
if not cond: continue
continue elif receipt == 600:
# aiter::mha_fwd C++ api integration cond = dtype in ["fp16", "bf16"]
elif receipt == 600: cond &= pipeline.F_vlayout == "row"
cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_squant == "f"
cond &= pipeline.F_vlayout == "row" if not cond:
cond &= pipeline.F_squant == "f" continue
if not cond:
continue
# fp32 only if receipt == 800 or receipt == 801:
if receipt == 800 or receipt == 801: cond = dtype == "fp32"
cond = dtype == "fp32" if not cond:
if not cond: continue
continue
api_pool.register_traits(k.api_trait()) api_pool.register_traits(k.api_trait())
gen.append(k) gen.append(k)
return (api_pool, gen) return (api_pool, gen)