add fmha fwd splitkv receipt for aiter c++ api (#2068)

* add s_randval for c++ api

* Fix bug of bias in splitkv

---------

Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
slippedJim
2025-04-10 23:21:13 +08:00
committed by GitHub
parent f14e648e7c
commit 5f885d2b7a
5 changed files with 30 additions and 10 deletions

View File

@@ -545,10 +545,9 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= dpad == dvpad
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= mode in ["batch", "group"]
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
if not cond:
continue
@@ -689,6 +688,11 @@ def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaB
cond &= mode == "group"
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
gen.append(k)
return gen
@@ -841,6 +845,11 @@ def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[Fmh
cond &= mode == "group"
if not cond:
continue
# aiter::mha_bwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
gen.append(k)
return gen

View File

@@ -536,10 +536,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter aiter::mha_fwd integration
elif receipt == 500:
# aiter::mha_fwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= mode in ['batch', 'group']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:

View File

@@ -738,6 +738,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# aiter::mha_fwd_splikv C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
@@ -796,6 +803,11 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
cond &= mode == "group"
if not cond:
continue
# aiter::mha_fwd_splikv C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
gen.append(k)
return gen