diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 475631a885..72d5970bbf 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -878,6 +878,8 @@ def get_fwd_blobs( cond &= pipeline.F_qscale == "no" if not cond: continue + elif receipt == 700: + continue # TE does not use this API # fp32 only if receipt == 800 or receipt == 801: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 8da6eed212..abb84a389b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -1149,6 +1149,12 @@ def get_bwd_blobs( cond = dtype in ["fp16", "bf16"] if not cond: continue + # TransformerEngine integration + elif receipt == 700: + cond = dtype in ["fp16", "bf16"] + cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"] + if not cond: + continue # fp32 only, all variations if receipt == 800: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 741ef4062d..0003fce892 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1454,6 +1454,20 @@ def get_product(receipt: int) -> Product: return cond return Product(name="aiter::mha_fwd C++ api integration", rule=fit) + # TransformerEngine integration + elif receipt == 700: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= kernel_ctx.pipeline.F_lse == "t" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_sink == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="TransformerEngine integration", rule=fit) elif receipt == 888: def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index c9bac50da1..f0396ed5eb 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -970,6 +970,8 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_squant == "f" if not cond: continue + elif receipt == 700: + continue # TE does not use this API # fp32 only if receipt == 800 or receipt == 801: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 1ac1f1c38a..7c7bddb345 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -745,6 +745,8 @@ def get_fwd_blobs( cond &= pipeline.F_squant == "f" if not cond: continue + elif receipt == 700: + continue # TE does not use this API # fp32 only if receipt == 800 or receipt == 801: diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index a5a2d08563..434e1cab76 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -139,7 +139,9 @@ if __name__ == "__main__": + " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" - + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration", + + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration\n" + + " 700: Only generate instance for TransformerEngine integration (fwd + bwd, fp16/bf16 only,\n" + + " invariants: row vlayout, has_lse, no skip/sink/logits/qscale)", ) parser.add_argument(