mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
add fmha fwd splitkv receipt for aiter c++ api (#2068)
* add s_randval for c++ api * Fix bug of bias in splitkv --------- Co-authored-by: rocking <ChunYu.Lai@amd.com>
This commit is contained in:
@@ -545,10 +545,9 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode in ["batch", "group"]
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
@@ -689,6 +688,11 @@ def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaB
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
@@ -841,6 +845,11 @@ def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[Fmh
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
|
||||
@@ -536,10 +536,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter aiter::mha_fwd integration
|
||||
elif receipt == 500:
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode in ['batch', 'group']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
|
||||
@@ -738,6 +738,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd_splikv C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
@@ -796,6 +803,11 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd_splikv C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
|
||||
@@ -109,8 +109,8 @@ if __name__ == "__main__":
|
||||
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
|
||||
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
|
||||
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
|
||||
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration"
|
||||
|
||||
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \
|
||||
" 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user