From d931e8703dfa66027ca09da3175b320f6053b128 Mon Sep 17 00:00:00 2001 From: Meekail Zain <34613774+Micky774@users.noreply.github.com> Date: Thu, 14 May 2026 14:34:03 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#6867 (commit 3cb0219) Added custom FMHA codegen receipt for TransformerEngine (#6867) ## Motivation TE uses AITER to build static MHA libraries, which ultimately rely on CK kernels. We use the `600` receipt which generates more kernels than TE truly needs. This bespoke receipt allows us to minimize the kernel count, compile time, and memory footprint of our MHA library. ## Technical Details Extended the receipt mechanism to include a custom `700` receipt for TE's needs ## Test Plan Test by building TE using the same receipt profile ## Test Result Build validated in TE using a custom feature branches of AITER/CK to temporarily apply the patch ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 2 ++ example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 6 ++++++ example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 14 ++++++++++++++ .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 2 ++ .../01_fmha/codegen/ops/fmha_pagedkv_prefill.py | 2 ++ example/ck_tile/01_fmha/generate.py | 4 +++- 6 files changed, 29 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 475631a885..72d5970bbf 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -878,6 +878,8 @@ def get_fwd_blobs( cond &= pipeline.F_qscale == "no" if not cond: continue + elif receipt == 700: + continue # TE does not use this API # fp32 only if receipt == 800 or receipt == 801: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 8da6eed212..abb84a389b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -1149,6 +1149,12 @@ def get_bwd_blobs( cond = dtype in ["fp16", "bf16"] if not cond: continue + # TransformerEngine integration + elif receipt == 700: + cond = dtype in ["fp16", "bf16"] + cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"] + if not cond: + continue # fp32 only, all variations if receipt == 800: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 741ef4062d..0003fce892 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1454,6 +1454,20 @@ def get_product(receipt: int) -> Product: return cond return Product(name="aiter::mha_fwd C++ api integration", rule=fit) + # TransformerEngine integration + elif receipt == 700: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= kernel_ctx.pipeline.F_lse == "t" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_sink == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="TransformerEngine integration", rule=fit) elif receipt == 888: def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index c9bac50da1..f0396ed5eb 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -970,6 +970,8 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_squant == "f" if not cond: continue + elif receipt == 700: + continue # TE does not use this API # fp32 only if receipt == 800 or receipt == 801: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 1ac1f1c38a..7c7bddb345 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -745,6 +745,8 @@ def get_fwd_blobs( cond &= pipeline.F_squant == "f" if not cond: continue + elif receipt == 700: + continue # TE does not use this API # fp32 only if receipt == 800 or receipt == 801: diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index a5a2d08563..434e1cab76 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -139,7 +139,9 @@ if __name__ == "__main__": + " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" - + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration", + + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration\n" + + " 700: Only generate instance for TransformerEngine integration (fwd + bwd, fp16/bf16 only,\n" + + " invariants: row vlayout, has_lse, no skip/sink/logits/qscale)", ) parser.add_argument(