diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 9b229c7ad4..220e275115 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -1996,6 +1996,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= deterministic == "f" if not cond: continue + # TE integration + elif receipt == 5: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias', 'alibi'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue # Aiter (mha_bwd) integration elif receipt == 300: cond = dtype in ['fp16', 'bf16']