From 53704854592cf1837eec3df369be46b460195425 Mon Sep 17 00:00:00 2001
From: Hosang Yoon <156028780+hyoon1@users.noreply.github.com>
Date: Thu, 2 Apr 2026 20:16:32 -0400
Subject: [PATCH] [CK_TILE] Restrict FMHA codegen to the kernel subset used by
 FlashAttention (#6038)

## Motivation

Currently, the CK FlashAttention integration generates a broader FMHA kernel
set than the FlashAttention wrappers can actually dispatch, which increases
compile time without improving runtime coverage.

## Technical Details

The FlashAttention CK wrappers do not use all logits/LSE variants emitted by
the default FMHA codegen. The direct `fmha_fwd` path always uses
softcap-disabled, LSE-enabled kernels, and the `fmha_fwd_splitkv` path only
uses softcap-disabled kernels. This change trims codegen to that subset and
stops generating the unused logits/LSE variants.

This reduces the generated forward kernel set without changing
`fmha_fwd_appendkv` or `fmha_bwd`. The reduced kernel set was validated by
building and running the
[FlashAttention](https://github.com/Dao-AILab/flash-attention) CK backend.

Across targets, the total generated FMHA kernel count is reduced by:

- `gfx942`: 29.3%
- `gfx1100`: 33.7%
- `gfx1201`: 31.3%

## Test Plan

`pytest test/test_flash_attn_ck.py` from
https://github.com/Dao-AILab/flash-attention

## Test Result

All tests passed.

## Submission Checklist

- [ ] Look over the contributing guidelines at
      https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---
 .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py    | 18 +++++++++++++++++-
 .../01_fmha/codegen/ops/fmha_fwd_splitkv.py    |  7 ++++++-
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
index d6f158116e..a5fffb5159 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
@@ -1304,7 +1304,23 @@ class Product:
 
 def get_product(receipt: int) -> Product:
     # Flash attention integration
-    if receipt in (2, 3):
+    if receipt == 2:
+
+        def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
+            cond = problem_ctx.dtype in ["fp16", "bf16"]
+            cond &= kernel_ctx.pipeline.F_vlayout == "row"
+            cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"]
+            cond &= kernel_ctx.pipeline.F_qscale == "no"
+            cond &= kernel_ctx.pipeline.F_skip == "f"
+            cond &= kernel_ctx.pipeline.F_sink == "f"
+            # FlashAttention direct fwd wrappers always use softcap disabled and LSE enabled.
+            cond &= kernel_ctx.pipeline.F_logits == "f"
+            cond &= kernel_ctx.pipeline.F_lse == "t"
+            return cond
+
+        return Product(name="Flash attention integration", rule=fit)
+    # Receipt 3 forward coverage used by CK library / smoke tests
+    elif receipt == 3:
 
         def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
             cond = problem_ctx.dtype in ["fp16", "bf16"]
diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
index e0ccde8a6b..c9bac50da1 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
@@ -939,6 +939,8 @@ def get_fwd_splitkv_blobs(
         cond = dtype in ["fp16", "bf16"]
         cond &= pipeline.F_vlayout == "row"
         cond &= pipeline.F_bias in ["no", "alibi"]
+        # FlashAttention splitkv paths use softcap-disabled kernels only.
+        cond &= pipeline.F_logits == "f"
         cond &= pipeline.F_squant == "f"
         cond &= pipeline.F_sink == "f"
         if not cond:
@@ -1142,4 +1144,7 @@ def list_blobs(
         )
         for kernel in kernels:
             f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
-        f.write((file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + "\n")
+        f.write(
+            (file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix()
+            + "\n"
+        )
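For reviewers who want to sanity-check the new receipt-2 filter without running the full codegen, here is a minimal, self-contained sketch. It repeats the `fit` predicate from the patch against simplified stand-in contexts; `Pipeline`, `ProblemContext`, and `KernelContext` below are illustrative dataclasses invented for this sketch, not the real codegen types, which live in `fmha_fwd.py` and carry many more fields.

```python
# Minimal sketch (hypothetical stand-in types, not the real codegen classes).
from dataclasses import dataclass


@dataclass
class Pipeline:
    F_vlayout: str = "row"
    F_bias: str = "no"
    F_qscale: str = "no"
    F_skip: str = "f"
    F_sink: str = "f"
    F_logits: str = "f"  # "f" = softcap disabled
    F_lse: str = "t"     # "t" = LSE output enabled


@dataclass
class ProblemContext:
    dtype: str = "fp16"


@dataclass
class KernelContext:
    pipeline: Pipeline


def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
    # Same predicate body as the receipt == 2 branch in the patch above.
    cond = problem_ctx.dtype in ["fp16", "bf16"]
    cond &= kernel_ctx.pipeline.F_vlayout == "row"
    cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"]
    cond &= kernel_ctx.pipeline.F_qscale == "no"
    cond &= kernel_ctx.pipeline.F_skip == "f"
    cond &= kernel_ctx.pipeline.F_sink == "f"
    cond &= kernel_ctx.pipeline.F_logits == "f"
    cond &= kernel_ctx.pipeline.F_lse == "t"
    return cond


# The softcap-disabled, LSE-enabled variant is kept; softcap-enabled or
# LSE-disabled variants are filtered out and never generated.
assert fit(ProblemContext(), KernelContext(Pipeline()))
assert not fit(ProblemContext(), KernelContext(Pipeline(F_logits="t")))
assert not fit(ProblemContext(), KernelContext(Pipeline(F_lse="f")))
```

The asserts mirror the claim in Technical Details: under receipt 2 the direct `fmha_fwd` path only ever dispatches softcap-disabled, LSE-enabled kernels, so the other variants can be dropped from codegen.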