Only output the deterministic bwd kernel for aiter (#1903)

* Only output the deterministic kernel

* Add comment
This commit is contained in:
rocking
2025-02-20 04:27:01 +08:00
committed by GitHub
parent f0d49d14fc
commit e4358c01d9
4 changed files with 15 additions and 2 deletions

View File

@@ -492,6 +492,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# Flash attention integration
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'alibi']
@@ -506,6 +507,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= deterministic == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= bias in ['no', 'bias']
@@ -514,22 +516,24 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= deterministic == "f"
if not cond:
continue
# Aiter (mha_bwd) integration
elif receipt == 10:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
cond &= deterministic == "f"
cond &= deterministic == "t"
if not cond:
continue
# Aiter (mha_varlen_bwd) integration
elif receipt == 11:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
cond &= bias in ['no', 'alibi']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
cond &= deterministic == "f"
cond &= deterministic == "t"
if not cond:
continue
api_pool.register_dq_dk_dv_traits(k.api_trait())

View File

@@ -487,6 +487,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
@@ -494,6 +495,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
@@ -501,6 +503,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 10:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"
@@ -509,6 +512,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 11:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"

View File

@@ -326,6 +326,8 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# 2 - Flash attention integration
# 12 - Aiter(mha_fwd_kvcache) integration
if receipt in (2, 12):
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'

View File

@@ -705,6 +705,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
# Flash attention integration
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
@@ -712,6 +713,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 11:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "group"
@@ -720,6 +722,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_fwd_kvcache) integration
elif receipt == 12:
cond = dtype in ['fp16', 'bf16']
cond &= mode == "batch"