mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Added TE-specialized receipt
This commit is contained in:
@@ -1145,6 +1145,12 @@ def get_bwd_blobs(
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
if not cond:
|
||||
continue
|
||||
# TransformerEngine integration
|
||||
elif receipt == 700:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= dropout in ["no", "dropout_wg32", "dropout_wg16"]
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only, all variations
|
||||
if receipt == 800:
|
||||
|
||||
@@ -1455,6 +1455,20 @@ def get_product(receipt: int) -> Product:
|
||||
return cond
|
||||
|
||||
return Product(name="aiter::mha_fwd C++ api integration", rule=fit)
|
||||
# TransformerEngine integration
|
||||
elif receipt == 700:
|
||||
|
||||
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
    """Return whether a kernel is kept for the receipt-700 product.

    A kernel is accepted only when the problem dtype is "fp16" or "bf16"
    and its pipeline uses the "row" vlayout, has F_qscale == "no",
    F_lse == "t", and F_skip / F_sink / F_logits all equal to "f".
    """
    # Every predicate is evaluated eagerly (no short-circuit), in the same
    # order as a chain of `cond &= ...` statements would evaluate them.
    dtype_ok = problem_ctx.dtype in ("fp16", "bf16")
    checks = (
        dtype_ok,
        kernel_ctx.pipeline.F_vlayout == "row",
        kernel_ctx.pipeline.F_qscale == "no",
        kernel_ctx.pipeline.F_lse == "t",
        kernel_ctx.pipeline.F_skip == "f",
        kernel_ctx.pipeline.F_sink == "f",
        kernel_ctx.pipeline.F_logits == "f",
    )
    return all(checks)
|
||||
|
||||
return Product(name="TransformerEngine integration", rule=fit)
|
||||
elif receipt == 888:
|
||||
|
||||
def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
|
||||
@@ -139,7 +139,9 @@ if __name__ == "__main__":
|
||||
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
|
||||
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
|
||||
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
|
||||
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
|
||||
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration\n"
|
||||
+ " 700: Only generate instance for TransformerEngine integration (fwd + bwd, fp16/bf16 only,\n"
|
||||
+ " invariants: row vlayout, has_lse, no skip/sink/logits/qscale)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user