Added TE-specialized receipt

This commit is contained in:
Meekail Zain
2026-04-21 15:24:25 +00:00
parent fdf4bb7fcc
commit e09e6a81f3
3 changed files with 23 additions and 1 deletions

View File

@@ -1145,6 +1145,12 @@ def get_bwd_blobs(
cond = dtype in ["fp16", "bf16"]
if not cond:
continue
# TransformerEngine integration
elif receipt == 700:
cond = dtype in ["fp16", "bf16"]
cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"]
if not cond:
continue
# fp32 only, all variations
if receipt == 800:

View File

@@ -1455,6 +1455,20 @@ def get_product(receipt: int) -> Product:
return cond
return Product(name="aiter::mha_fwd C++ api integration", rule=fit)
# TransformerEngine integration
elif receipt == 700:
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
cond = problem_ctx.dtype in ["fp16", "bf16"]
cond &= kernel_ctx.pipeline.F_vlayout == "row"
cond &= kernel_ctx.pipeline.F_qscale == "no"
cond &= kernel_ctx.pipeline.F_lse == "t"
cond &= kernel_ctx.pipeline.F_skip == "f"
cond &= kernel_ctx.pipeline.F_sink == "f"
cond &= kernel_ctx.pipeline.F_logits == "f"
return cond
return Product(name="TransformerEngine integration", rule=fit)
elif receipt == 888:
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:

View File

@@ -139,7 +139,9 @@ if __name__ == "__main__":
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration\n"
+ " 700: Only generate instance for TransformerEngine integration (fwd + bwd, fp16/bf16 only,\n"
+ " invariants: row vlayout, has_lse, no skip/sink/logits/qscale)",
)
parser.add_argument(