mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
[rocm-libraries] ROCm/rocm-libraries#6867 (commit 3cb0219)
Added custom FMHA codegen receipt for TransformerEngine (#6867) ## Motivation TE uses AITER to build static MHA libraries, which ultimately rely on CK kernels. We use the `600` receipt which generates more kernels than TE truly needs. This bespoke receipt allows us to minimize the kernel count, compile time, and memory footprint of our MHA library. ## Technical Details Extended the receipt mechanism to include a custom `700` receipt for TE's needs ## Test Plan Test by building TE using the same receipt profile ## Test Result Build validated in TE using a custom feature branches of AITER/CK to temporarily apply the patch ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
83566edb0f
commit
d931e8703d
@@ -878,6 +878,8 @@ def get_fwd_blobs(
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 700:
|
||||
continue # TE does not use this API
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
|
||||
@@ -1149,6 +1149,12 @@ def get_bwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
if not cond:
|
||||
continue
|
||||
# TransformerEngine integration
|
||||
elif receipt == 700:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"]
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only, all variations
|
||||
if receipt == 800:
|
||||
|
||||
@@ -1454,6 +1454,20 @@ def get_product(receipt: int) -> Product:
|
||||
return cond
|
||||
|
||||
return Product(name="aiter::mha_fwd C++ api integration", rule=fit)
|
||||
# TransformerEngine integration
|
||||
elif receipt == 700:
|
||||
|
||||
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
cond = problem_ctx.dtype in ["fp16", "bf16"]
|
||||
cond &= kernel_ctx.pipeline.F_vlayout == "row"
|
||||
cond &= kernel_ctx.pipeline.F_qscale == "no"
|
||||
cond &= kernel_ctx.pipeline.F_lse == "t"
|
||||
cond &= kernel_ctx.pipeline.F_skip == "f"
|
||||
cond &= kernel_ctx.pipeline.F_sink == "f"
|
||||
cond &= kernel_ctx.pipeline.F_logits == "f"
|
||||
return cond
|
||||
|
||||
return Product(name="TransformerEngine integration", rule=fit)
|
||||
elif receipt == 888:
|
||||
|
||||
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
|
||||
@@ -970,6 +970,8 @@ def get_fwd_splitkv_blobs(
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 700:
|
||||
continue # TE does not use this API
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
|
||||
@@ -745,6 +745,8 @@ def get_fwd_blobs(
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 700:
|
||||
continue # TE does not use this API
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
|
||||
@@ -139,7 +139,9 @@ if __name__ == "__main__":
|
||||
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
|
||||
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
|
||||
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
|
||||
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
|
||||
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration\n"
|
||||
+ " 700: Only generate instance for TransformerEngine integration (fwd + bwd, fp16/bf16 only,\n"
|
||||
+ " invariants: row vlayout, has_lse, no skip/sink/logits/qscale)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user