Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-03 13:11:25 +00:00.
[CK_TILE] Restrict FMHA codegen to the kernel subset used by FlashAttention (#6038)

## Motivation

Currently, the CK FlashAttention integration generates a broader set of FMHA kernels than the FlashAttention wrappers can actually dispatch. This increases compile time without improving runtime coverage.

## Technical Details

The FlashAttention CK wrappers do not use all of the logits/LSE variants emitted by the default FMHA codegen:

- The direct `fmha_fwd` path always uses softcap-disabled, LSE-enabled kernels.
- The `fmha_fwd_splitkv` path only uses softcap-disabled kernels.

This change trims codegen to that subset and stops generating the unused logits/LSE variants. It reduces the generated forward kernel set without changing `fmha_fwd_appendkv` or `fmha_bwd`.

The reduced kernel set was validated by building and running the [FlashAttention](https://github.com/Dao-AILab/flash-attention) CK backend. The total number of generated FMHA kernels is reduced per target as follows:

- `gfx942`: 29.3%
- `gfx1100`: 33.7%
- `gfx1201`: 31.3%

## Test Plan

Ran `pytest test/test_flash_attn_ck.py` from https://github.com/Dao-AILab/flash-attention.

## Test Result

All tests passed.

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
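The variant-selection rule described under Technical Details can be expressed as a single predicate over forward-kernel variants. The following is a minimal Python sketch, not the actual CK codegen: the `FwdVariant` descriptor, its `api`/`softcap`/`lse` fields, and `used_by_flash_attention` are hypothetical names chosen for illustration. The toy enumeration below does not reflect the real kernel counts or the per-target percentages reported above.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class FwdVariant:
    # Hypothetical descriptor of one generated forward-kernel variant.
    api: str       # "fwd" or "fwd_splitkv" (illustrative labels)
    softcap: bool  # logits soft-capping enabled
    lse: bool      # log-sum-exp output enabled


def used_by_flash_attention(v: FwdVariant) -> bool:
    """Return True if the FlashAttention CK wrappers could dispatch this variant."""
    if v.api == "fwd":
        # Direct fmha_fwd path: softcap disabled and LSE enabled.
        return not v.softcap and v.lse
    if v.api == "fwd_splitkv":
        # Split-KV path: only softcap is constrained; LSE is left as-is.
        return not v.softcap
    # Other APIs (e.g. appendkv, bwd) are outside this filter and kept unchanged.
    return True


# Toy enumeration: every (api, softcap, lse) combination for the two forward APIs.
all_variants = [
    FwdVariant(api, softcap, lse)
    for api, softcap, lse in product(["fwd", "fwd_splitkv"], [False, True], [False, True])
]
kept = [v for v in all_variants if used_by_flash_attention(v)]

for v in kept:
    print(v)
```

Keeping the FlashAttention-specific subset in one predicate means the `fmha_fwd_appendkv` and `fmha_bwd` paths simply pass through unfiltered, matching the statement that those kernel sets are unchanged.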