diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 7105f1aa5c..9ce7f543d0 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -1145,6 +1145,12 @@ def get_bwd_blobs( cond = dtype in ["fp16", "bf16"] if not cond: continue + # TransformerEngine integration + elif receipt == 700: + cond = dtype in ["fp16", "bf16"] + cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"] + if not cond: + continue # fp32 only, all variations if receipt == 800: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c64a19104e..ffa3403e85 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1455,6 +1455,20 @@ def get_product(receipt: int) -> Product: return cond return Product(name="aiter::mha_fwd C++ api integration", rule=fit) + # TransformerEngine integration + elif receipt == 700: + + def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: + cond = problem_ctx.dtype in ["fp16", "bf16"] + cond &= kernel_ctx.pipeline.F_vlayout == "row" + cond &= kernel_ctx.pipeline.F_qscale == "no" + cond &= kernel_ctx.pipeline.F_lse == "t" + cond &= kernel_ctx.pipeline.F_skip == "f" + cond &= kernel_ctx.pipeline.F_sink == "f" + cond &= kernel_ctx.pipeline.F_logits == "f" + return cond + + return Product(name="TransformerEngine integration", rule=fit) elif receipt == 888: def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index a5a2d08563..434e1cab76 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -139,7 +139,9 @@ if __name__ == "__main__": + " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" - + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration", + + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration\n" + + " 700: Only generate instance for TransformerEngine integration (fwd + bwd, fp16/bf16 only,\n" + + " invariants: row vlayout, has_lse, no skip/sink/logits/qscale)", ) parser.add_argument(