This commit is contained in:
slippedJim
2025-07-17 15:24:19 +08:00
committed by GitHub
parent 28072adc3a
commit 05b65d0c7c

View File

@@ -536,7 +536,6 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
if not cond:
continue
# Aiter (mha_varlen_bwd) integration
@@ -544,13 +543,11 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= dpad == dvpad
if not cond:
continue
api_pool.register_dq_dk_dv_traits(k.api_trait())