revert incorrect operations in bwd generation

This commit is contained in:
amd-ruitang3
2025-06-10 08:25:07 +00:00
parent 160788cdf4
commit efdf31e26f

View File

@@ -2775,8 +2775,8 @@ class FmhaBwdDQDKDVKernel:
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
# '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
# "kr_ktr_vr_iglp", "kr_ktr_vr"],
'32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
@@ -2827,9 +2827,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
# CK tile example
elif receipt == 3:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no']
cond &= dbias in ['no']
cond &= dropout in ['no']
cond &= bias in ['no', 'alibi']
cond &= dpad == dvpad
cond &= deterministic == "f"
if not cond: