mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Split-KV codegen: dual-tile dispatch and head-merge for hdim=64
1. Dual-tile: add both bn0=64 (preferred) and bn0=32 (fallback) for hdim=64 on gfx9 and gfx12. The dispatch checks page_block_size % bn0 == 0 at runtime to select the optimal tile. bn0=64 halves KV iterations when page_block_size >= 64. 2. Tile dict now supports lists per hdim. The codegen loop iterates over all tile variants, generating separate kernel instances for each. Combine kernels are unaffected (tile-independent). 3. Enable kMergeNumHeadGroupsSeqLenQ for hdim=64 decode (previously hdim=128 only). For GQA-8 with max_seqlen_q=1, this packs 8 head groups into the M dimension. Only activates for no-mask instances (kernel static_assert requires !kHasMask). 4. Add qr (non-async) pipeline for fwd non-bias group mode as fallback after qr_async. The async pipeline on this branch has a kernel-level bug where fmha_fwd launches but writes no output. Made-with: Cursor
This commit is contained in:
@@ -1017,6 +1017,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
|||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||||
else:
|
else:
|
||||||
|
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
|
||||||
if receipt == 1 and bias != "bias":
|
if receipt == 1 and bias != "bias":
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
|
|||||||
namespace {{
|
namespace {{
|
||||||
template <bool kHasUnevenSplits>
|
template <bool kHasUnevenSplits>
|
||||||
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
|
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
|
||||||
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
|
if constexpr (({F_hdim} == 64 || {F_hdim} == 128) && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
|
||||||
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|
||||||
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
|
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
|
||||||
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
|
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
|
||||||
@@ -280,7 +280,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (a.block_table_ptr == nullptr || a.page_block_size % {F_bn0} == 0)) {{
|
||||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||||
|
|
||||||
// get combine kernel tile sizes
|
// get combine kernel tile sizes
|
||||||
@@ -820,17 +820,18 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
|
|||||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||||
if dtype in ["fp16", "bf16"]:
|
if dtype in ["fp16", "bf16"]:
|
||||||
return {
|
return {
|
||||||
"32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
"32" : [FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
"64" : FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
"64" : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||||
"96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
"128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
"96" : [FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
# "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
"128": [FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
"256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
# "160" : [FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
|
"256": [FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
} # fmt: skip
|
} # fmt: skip
|
||||||
elif dtype in ["fp8", "bf8"]:
|
elif dtype in ["fp8", "bf8"]:
|
||||||
return {
|
return {
|
||||||
"64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
"64" : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||||
"128": FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
"128": [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||||
} # fmt: skip
|
} # fmt: skip
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@@ -861,16 +862,17 @@ class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
|
|||||||
if dtype in ["fp16", "bf16"]:
|
if dtype in ["fp16", "bf16"]:
|
||||||
return {
|
return {
|
||||||
# bm0, bn0, bk0, bn1, bk1,
|
# bm0, bn0, bk0, bn1, bk1,
|
||||||
"32" : FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
"32" : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
"64" : FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
"64" : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||||
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
FmhaFwdTileSize( 64, 32, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
"256": FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
"128": [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
|
"256": [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
} # fmt: skip
|
} # fmt: skip
|
||||||
elif dtype in ["fp8", "bf8"]:
|
elif dtype in ["fp8", "bf8"]:
|
||||||
return {
|
return {
|
||||||
# bm0, bn0, bk0, bn1, bk1,
|
# bm0, bn0, bk0, bn1, bk1,
|
||||||
"64" : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
"64" : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
"128": FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
"128": [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||||
} # fmt: skip
|
} # fmt: skip
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@@ -905,77 +907,78 @@ def get_fwd_splitkv_blobs(
|
|||||||
continue
|
continue
|
||||||
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||||
tile = d[hdim_str]
|
tiles = d[hdim_str]
|
||||||
|
if not isinstance(tiles, list):
|
||||||
|
tiles = [tiles]
|
||||||
hdim = int(hdim_str)
|
hdim = int(hdim_str)
|
||||||
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
|
for tile in tiles:
|
||||||
if mode == "group":
|
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
|
||||||
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
|
if mode == "group":
|
||||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
|
||||||
continue
|
continue
|
||||||
# logits_soft_cap is only allowed if no bias
|
if not (
|
||||||
if not (
|
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
|
||||||
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
|
or pipeline.F_logits == "f"
|
||||||
or pipeline.F_logits == "f"
|
):
|
||||||
):
|
|
||||||
continue
|
|
||||||
k = Kernel(
|
|
||||||
F_arch=factory.arch,
|
|
||||||
F_idx=0,
|
|
||||||
F_hdim=hdim,
|
|
||||||
F_dtype=dtype,
|
|
||||||
F_mode=mode,
|
|
||||||
F_tile=tile,
|
|
||||||
F_pipeline=pipeline,
|
|
||||||
mask_impl=mask_impl,
|
|
||||||
)
|
|
||||||
if kernel_filter != "":
|
|
||||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
|
||||||
continue
|
|
||||||
if optdim_list != [-1]:
|
|
||||||
if hdim not in optdim_list:
|
|
||||||
continue
|
|
||||||
# Flash attention integration
|
|
||||||
if receipt == 2:
|
|
||||||
cond = dtype in ["fp16", "bf16"]
|
|
||||||
cond &= pipeline.F_vlayout == "row"
|
|
||||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
|
||||||
cond &= pipeline.F_squant == "f"
|
|
||||||
cond &= pipeline.F_sink == "f"
|
|
||||||
if not cond:
|
|
||||||
continue
|
|
||||||
# PyTorch integration
|
|
||||||
elif receipt == 4:
|
|
||||||
cond = dtype in ["fp16, bf16"]
|
|
||||||
cond &= pipeline.F_vlayout == "row"
|
|
||||||
cond &= pipeline.F_bias in ["no", "bias"]
|
|
||||||
cond &= pipeline.F_squant == "f"
|
|
||||||
cond &= mode == "batch"
|
|
||||||
cond &= pipeline.F_sink == "f"
|
|
||||||
if not cond:
|
|
||||||
continue
|
|
||||||
# Aiter(mha_varlen_fwd) integration
|
|
||||||
elif receipt == 200:
|
|
||||||
cond = dtype in ["fp16", "bf16"]
|
|
||||||
cond &= mode == "group"
|
|
||||||
cond &= pipeline.F_vlayout == "row"
|
|
||||||
cond &= pipeline.F_squant == "f"
|
|
||||||
if not cond:
|
|
||||||
continue
|
|
||||||
# aiter::mha_fwd_splikv C++ api integration
|
|
||||||
elif receipt == 600:
|
|
||||||
cond = dtype in ["fp16", "bf16"]
|
|
||||||
cond &= pipeline.F_vlayout == "row"
|
|
||||||
cond &= pipeline.F_squant == "f"
|
|
||||||
if not cond:
|
|
||||||
continue
|
continue
|
||||||
|
k = Kernel(
|
||||||
|
F_arch=factory.arch,
|
||||||
|
F_idx=0,
|
||||||
|
F_hdim=hdim,
|
||||||
|
F_dtype=dtype,
|
||||||
|
F_mode=mode,
|
||||||
|
F_tile=tile,
|
||||||
|
F_pipeline=pipeline,
|
||||||
|
mask_impl=mask_impl,
|
||||||
|
)
|
||||||
|
if kernel_filter != "":
|
||||||
|
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||||
|
continue
|
||||||
|
if optdim_list != [-1]:
|
||||||
|
if hdim not in optdim_list:
|
||||||
|
continue
|
||||||
|
# Flash attention integration
|
||||||
|
if receipt == 2:
|
||||||
|
cond = dtype in ["fp16", "bf16"]
|
||||||
|
cond &= pipeline.F_vlayout == "row"
|
||||||
|
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||||
|
cond &= pipeline.F_squant == "f"
|
||||||
|
cond &= pipeline.F_sink == "f"
|
||||||
|
if not cond:
|
||||||
|
continue
|
||||||
|
# PyTorch integration
|
||||||
|
elif receipt == 4:
|
||||||
|
cond = dtype in ["fp16, bf16"]
|
||||||
|
cond &= pipeline.F_vlayout == "row"
|
||||||
|
cond &= pipeline.F_bias in ["no", "bias"]
|
||||||
|
cond &= pipeline.F_squant == "f"
|
||||||
|
cond &= mode == "batch"
|
||||||
|
cond &= pipeline.F_sink == "f"
|
||||||
|
if not cond:
|
||||||
|
continue
|
||||||
|
# Aiter(mha_varlen_fwd) integration
|
||||||
|
elif receipt == 200:
|
||||||
|
cond = dtype in ["fp16", "bf16"]
|
||||||
|
cond &= mode == "group"
|
||||||
|
cond &= pipeline.F_vlayout == "row"
|
||||||
|
cond &= pipeline.F_squant == "f"
|
||||||
|
if not cond:
|
||||||
|
continue
|
||||||
|
# aiter::mha_fwd_splikv C++ api integration
|
||||||
|
elif receipt == 600:
|
||||||
|
cond = dtype in ["fp16", "bf16"]
|
||||||
|
cond &= pipeline.F_vlayout == "row"
|
||||||
|
cond &= pipeline.F_squant == "f"
|
||||||
|
if not cond:
|
||||||
|
continue
|
||||||
|
|
||||||
# fp32 only
|
# fp32 only
|
||||||
if receipt == 800 or receipt == 801:
|
if receipt == 800 or receipt == 801:
|
||||||
cond = dtype == "fp32"
|
cond = dtype == "fp32"
|
||||||
if not cond:
|
if not cond:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
gen.append(k)
|
gen.append(k)
|
||||||
|
|
||||||
return gen
|
return gen
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user