[rocm-libraries] ROCm/rocm-libraries#6867 (commit 3cb0219)

Added custom FMHA codegen receipt for TransformerEngine
 (#6867)

## Motivation

TE uses AITER to build static MHA libraries, which ultimately rely on CK
kernels. We use the `600` receipt which generates more kernels than TE
truly needs. This bespoke receipt allows us to minimize the kernel
count, compile time, and memory footprint of our MHA library.

## Technical Details

Extended the receipt mechanism with a custom `700` receipt tailored to
TE's needs.

## Test Plan

Test by building TE using the same receipt profile

## Test Result

Build validated in TE using custom feature branches of AITER/CK that
temporarily apply the patch.

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Meekail Zain
2026-05-14 14:34:03 +00:00
committed by assistant-librarian[bot]
parent 83566edb0f
commit d931e8703d
6 changed files with 29 additions and 1 deletions

View File

@@ -878,6 +878,8 @@ def get_fwd_blobs(
cond &= pipeline.F_qscale == "no"
if not cond:
continue
elif receipt == 700:
continue # TE does not use this API
# fp32 only
if receipt == 800 or receipt == 801:

View File

@@ -1149,6 +1149,12 @@ def get_bwd_blobs(
cond = dtype in ["fp16", "bf16"]
if not cond:
continue
# TransformerEngine integration
elif receipt == 700:
cond = dtype in ["fp16", "bf16"]
cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"]
if not cond:
continue
# fp32 only, all variations
if receipt == 800:

View File

@@ -1454,6 +1454,20 @@ def get_product(receipt: int) -> Product:
return cond
return Product(name="aiter::mha_fwd C++ api integration", rule=fit)
# TransformerEngine integration
elif receipt == 700:
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
    """Decide whether a kernel is kept for the TransformerEngine receipt.

    Returns True only when the problem dtype is "fp16" or "bf16" and the
    kernel's pipeline has F_vlayout == "row", F_qscale == "no",
    F_lse == "t", and F_skip, F_sink and F_logits all equal to "f".
    """
    # Every check is evaluated eagerly (no short-circuit), in the same
    # order as an accumulating `&=` chain would evaluate them.
    dtype_ok = problem_ctx.dtype in ["fp16", "bf16"]
    pipe = kernel_ctx.pipeline
    checks = (
        dtype_ok,
        pipe.F_vlayout == "row",
        pipe.F_qscale == "no",
        pipe.F_lse == "t",
        pipe.F_skip == "f",
        pipe.F_sink == "f",
        pipe.F_logits == "f",
    )
    return all(checks)
return Product(name="TransformerEngine integration", rule=fit)
elif receipt == 888:
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:

View File

@@ -970,6 +970,8 @@ def get_fwd_splitkv_blobs(
cond &= pipeline.F_squant == "f"
if not cond:
continue
elif receipt == 700:
continue # TE does not use this API
# fp32 only
if receipt == 800 or receipt == 801:

View File

@@ -745,6 +745,8 @@ def get_fwd_blobs(
cond &= pipeline.F_squant == "f"
if not cond:
continue
elif receipt == 700:
continue # TE does not use this API
# fp32 only
if receipt == 800 or receipt == 801: