[rocm-libraries] ROCm/rocm-libraries#6038 (commit d7041a2)

[CK_TILE] Restrict FMHA codegen to the kernel subset used by FlashAttention (#6038)

## Motivation

Currently, the CK FlashAttention integration generates a broader FMHA
kernel set than the FlashAttention wrappers can actually dispatch, which
increases compile time without improving runtime coverage.

## Technical Details

The FlashAttention CK wrappers do not exercise every softcap (logits) / LSE variant
emitted by the default FMHA codegen: the direct `fmha_fwd` path always uses
softcap-disabled, LSE-enabled kernels, and the `fmha_fwd_splitkv` path only uses
softcap-disabled kernels. This change trims codegen to that subset and stops
generating the unused softcap/LSE variants.
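To make the subset concrete, here is a minimal, self-contained sketch of the filtering idea; `KernelVariant` and its fields are simplified stand-ins for the codegen's per-kernel traits (`F_logits`, `F_lse`), not the actual generator types:

```python
from dataclasses import dataclass


@dataclass
class KernelVariant:
    # Simplified stand-in for a generated kernel's traits (illustrative only).
    dtype: str   # "fp16" / "bf16"
    logits: str  # "t" = softcap enabled, "f" = softcap disabled
    lse: str     # "t" = LSE written out, "f" = LSE skipped


def used_by_flash_attention_fwd(v: KernelVariant) -> bool:
    # Direct fmha_fwd wrappers: softcap disabled, LSE enabled.
    return v.dtype in ("fp16", "bf16") and v.logits == "f" and v.lse == "t"


variants = [
    KernelVariant("fp16", "f", "t"),  # kept
    KernelVariant("fp16", "t", "t"),  # softcap variant: no longer generated
    KernelVariant("bf16", "f", "f"),  # LSE-disabled variant: no longer generated
]
print([v for v in variants if used_by_flash_attention_fwd(v)])
```

The `fmha_fwd_splitkv` filter in the second diff below applies only the softcap restriction and leaves both LSE settings available.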

This reduces the generated forward kernel set without changing
`fmha_fwd_appendkv` or `fmha_bwd`. The reduced kernel set was validated
by building and running the
[FlashAttention](https://github.com/Dao-AILab/flash-attention) CK
backend.

Across targets, the total generated FMHA kernel count is reduced by the following amounts (a spot-check sketch follows the list):

- `gfx942`: 29.3%
- `gfx1100`: 33.7%
- `gfx1201`: 31.3%
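The percentages above come from the PR itself; as a rough local spot-check, one could count the generated forward-kernel sources per target. The output directory and filename pattern below are assumptions about a particular build layout, not something defined by this change:

```python
# Hedged sketch: count generated FMHA forward kernel sources in a codegen output directory.
from pathlib import Path

generated_dir = Path("build/fmha_generated")  # assumed location; adjust to your build tree
count = sum(1 for _ in generated_dir.glob("fmha_fwd*.cpp"))
print(f"generated forward kernel sources: {count}")
```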

## Test Plan

`pytest test/test_flash_attn_ck.py` from a checkout of
https://github.com/Dao-AILab/flash-attention
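
A minimal way to drive that suite from Python, assuming flash-attention has already been built and installed for ROCm with its CK backend and that this is run from the repository root (build steps vary by environment and are not shown):

```python
# Hedged sketch: run the FlashAttention CK test suite from a flash-attention checkout.
import pytest

exit_code = pytest.main(["test/test_flash_attn_ck.py", "-q"])
raise SystemExit(exit_code)
```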

## Test Result
All tests passed.

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

Authored by Hosang Yoon, 2026-04-03 00:18:21 +00:00; committed by assistant-librarian[bot]
parent 144854dba1
commit 1dc35ff4ae
2 changed files with 23 additions and 2 deletions


```diff
@@ -1304,7 +1304,23 @@ class Product:
 def get_product(receipt: int) -> Product:
     # Flash attention integration
-    if receipt in (2, 3):
+    if receipt == 2:
+        def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
+            cond = problem_ctx.dtype in ["fp16", "bf16"]
+            cond &= kernel_ctx.pipeline.F_vlayout == "row"
+            cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"]
+            cond &= kernel_ctx.pipeline.F_qscale == "no"
+            cond &= kernel_ctx.pipeline.F_skip == "f"
+            cond &= kernel_ctx.pipeline.F_sink == "f"
+            # FlashAttention direct fwd wrappers always use softcap disabled and LSE enabled.
+            cond &= kernel_ctx.pipeline.F_logits == "f"
+            cond &= kernel_ctx.pipeline.F_lse == "t"
+            return cond
+        return Product(name="Flash attention integration", rule=fit)
+    # Receipt 3 forward coverage used by CK library / smoke tests
+    elif receipt == 3:
         def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
             cond = problem_ctx.dtype in ["fp16", "bf16"]
```


```diff
@@ -939,6 +939,8 @@ def get_fwd_splitkv_blobs(
     cond = dtype in ["fp16", "bf16"]
     cond &= pipeline.F_vlayout == "row"
     cond &= pipeline.F_bias in ["no", "alibi"]
+    # FlashAttention splitkv paths use softcap-disabled kernels only.
+    cond &= pipeline.F_logits == "f"
     cond &= pipeline.F_squant == "f"
     cond &= pipeline.F_sink == "f"
     if not cond:
@@ -1142,4 +1144,7 @@ def list_blobs(
     )
     for kernel in kernels:
         f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
-    f.write((file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + "\n")
+    f.write(
+        (file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix()
+        + "\n"
+    )
```