From 489602f9a88fb7e979ffbb5dda5c32289d92be22 Mon Sep 17 00:00:00 2001 From: wenchenvincent <32376000+wenchenvincent@users.noreply.github.com> Date: Mon, 10 Mar 2025 22:35:27 -0500 Subject: [PATCH] Enabled bwd support for hdim_qk != hdim_v for TE integration. (#1965) --- example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 220e275115..a41c7a9069 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -1980,6 +1980,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= dpad == dvpad if not cond: continue + # CK tile example elif receipt == 3: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'alibi'] @@ -2000,7 +2001,6 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> elif receipt == 5: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'bias', 'alibi'] - cond &= dpad == dvpad cond &= deterministic == "f" if not cond: continue