Enabled bwd support for hdim_qk != hdim_v for TE integration. (#1965)

This commit is contained in:
wenchenvincent
2025-03-10 22:35:27 -05:00
committed by GitHub
parent de3c6bf585
commit 489602f9a8

View File

@@ -1980,6 +1980,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= dpad == dvpad
if not cond:
continue
# CK tile example
elif receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
@@ -2000,7 +2001,6 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif receipt == 5:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'bias', 'alibi']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond:
continue