Mirror of https://github.com/ROCm/composable_kernel.git
[rocm-libraries] ROCm/rocm-libraries#6038 (commit d7041a2)
# [CK_TILE] Restrict FMHA codegen to the kernel subset used by FlashAttention (#6038)

## Motivation

Currently, the CK FlashAttention integration generates a broader FMHA kernel set than the FlashAttention wrappers can actually dispatch, which increases compile time without improving runtime coverage.

## Technical Details

The FlashAttention CK wrappers do not use all logits/LSE variants emitted by the default FMHA codegen. The direct `fmha_fwd` path always uses softcap-disabled, LSE-enabled kernels, and the `fmha_fwd_splitkv` path uses only softcap-disabled kernels. This change trims codegen to that subset and stops generating the unused logits/LSE variants. It reduces the generated forward kernel set without changing `fmha_fwd_appendkv` or `fmha_bwd`.

The reduced kernel set was validated by building and running the [FlashAttention](https://github.com/Dao-AILab/flash-attention) CK backend. Across targets, the total generated FMHA kernel count is reduced by:

- `gfx942`: 29.3%
- `gfx1100`: 33.7%
- `gfx1201`: 31.3%

## Test Plan

`pytest test/test_flash_attn_ck.py` from https://github.com/Dao-AILab/flash-attention

## Test Result

All tests passed.

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Committed by: assistant-librarian[bot]
Parent: 144854dba1
Commit: 1dc35ff4ae
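The mechanism behind the reduction is a predicate over the codegen variant grid: each receipt keeps only the flag combinations its wrappers can dispatch. Below is a minimal, self-contained sketch of that idea with a hypothetical, much smaller flag grid than the real codegen enumerates (the axis names and values here are illustrative, not the actual generator's):

```python
from itertools import product

# Hypothetical variant axes; the real codegen enumerates many more
# (masking, padding, head-dim tiles, ...).
AXES = {
    "dtype":   ["fp16", "bf16", "fp8"],
    "vlayout": ["row", "col"],
    "bias":    ["no", "alibi", "elementwise"],
    "logits":  ["f", "t"],   # softcap off/on
    "lse":     ["f", "t"],   # log-sum-exp storage off/on
}

def receipt2_fit(v: dict) -> bool:
    # Mirrors the receipt-2 rule: fp16/bf16, row V-layout, no/alibi bias,
    # softcap disabled, LSE enabled.
    return (
        v["dtype"] in ("fp16", "bf16")
        and v["vlayout"] == "row"
        and v["bias"] in ("no", "alibi")
        and v["logits"] == "f"
        and v["lse"] == "t"
    )

variants = [dict(zip(AXES, combo)) for combo in product(*AXES.values())]
kept = [v for v in variants if receipt2_fit(v)]
print(f"kept {len(kept)} of {len(variants)} variants")  # kept 4 of 72 variants
```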
@@ -1304,7 +1304,23 @@ class Product:

```python
def get_product(receipt: int) -> Product:
    # Flash attention integration
    if receipt in (2, 3):
        if receipt == 2:

            def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
                cond = problem_ctx.dtype in ["fp16", "bf16"]
                cond &= kernel_ctx.pipeline.F_vlayout == "row"
                cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"]
                cond &= kernel_ctx.pipeline.F_qscale == "no"
                cond &= kernel_ctx.pipeline.F_skip == "f"
                cond &= kernel_ctx.pipeline.F_sink == "f"
                # FlashAttention direct fwd wrappers always use softcap disabled and LSE enabled.
                cond &= kernel_ctx.pipeline.F_logits == "f"
                cond &= kernel_ctx.pipeline.F_lse == "t"
                return cond

            return Product(name="Flash attention integration", rule=fit)
        # Receipt 3 forward coverage used by CK library / smoke tests
        elif receipt == 3:

            def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
                cond = problem_ctx.dtype in ["fp16", "bf16"]
```
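To make the `Product`/`fit` contract above concrete: the generator asks `get_product` for a rule and keeps only the kernel variants that rule accepts. A minimal sketch of that dispatch, assuming `get_product` and `Product` from the hunk above are in scope and stubbing the contexts with `SimpleNamespace` (the real `ProblemContext`/`KernelContext` carry many more fields):

```python
from types import SimpleNamespace

# Stub contexts with just the fields the receipt-2 rule reads.
problem = SimpleNamespace(dtype="fp16")
kernel = SimpleNamespace(
    pipeline=SimpleNamespace(
        F_vlayout="row", F_bias="no", F_qscale="no",
        F_skip="f", F_sink="f", F_logits="f", F_lse="t",
    )
)

product = get_product(receipt=2)          # "Flash attention integration" rule
assert product.rule(problem, kernel)      # softcap-off, LSE-on variant is kept

kernel.pipeline.F_lse = "f"               # LSE-disabled variant...
assert not product.rule(problem, kernel)  # ...is filtered out of codegen
```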
@@ -939,6 +939,8 @@ def get_fwd_splitkv_blobs(

```python
    cond = dtype in ["fp16", "bf16"]
    cond &= pipeline.F_vlayout == "row"
    cond &= pipeline.F_bias in ["no", "alibi"]
    # FlashAttention splitkv paths use softcap-disabled kernels only.
    cond &= pipeline.F_logits == "f"
    cond &= pipeline.F_squant == "f"
    cond &= pipeline.F_sink == "f"
    if not cond:
```
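As an aside, the accumulate-into-`cond` style in this guard is equivalent to a single `all(...)` over the flag checks. A self-contained rendering of the same splitkv guard (the `splitkv_supported` helper name is illustrative, not part of the codegen script):

```python
from types import SimpleNamespace

def splitkv_supported(dtype: str, pipeline) -> bool:
    # Equivalent form of the guard in get_fwd_splitkv_blobs.
    return all([
        dtype in ("fp16", "bf16"),
        pipeline.F_vlayout == "row",
        pipeline.F_bias in ("no", "alibi"),
        pipeline.F_logits == "f",   # softcap disabled only
        pipeline.F_squant == "f",
        pipeline.F_sink == "f",
    ])

pipeline = SimpleNamespace(F_vlayout="row", F_bias="no",
                           F_logits="f", F_squant="f", F_sink="f")
print(splitkv_supported("fp16", pipeline))  # True
print(splitkv_supported("fp8", pipeline))   # False: dtype filtered out
```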
@@ -1142,4 +1144,7 @@ def list_blobs(

```diff
     )
     for kernel in kernels:
         f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
-    f.write((file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + "\n")
+    f.write(
+        (file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix()
+        + "\n"
+    )
```
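A final note on the path handling in `list_blobs`: the hunk above only reflows the last write; the path arithmetic is unchanged, and `as_posix()` keeps the emitted blob list portable across host platforms. A tiny illustration, with made-up values standing in for the script's constants:

```python
from pathlib import Path

# Illustrative stand-ins; the real GEN_DIR and API filename are defined
# in the codegen script.
file_path = Path("build") / "blob_list.txt"
GEN_DIR = "generated"
FMHA_FWD_SPLITKV_API_FILENAME = "fmha_fwd_splitkv_api.cpp"

entry = (file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix()
print(entry)  # build/generated/fmha_fwd_splitkv_api.cpp
```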