mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Clearn up generate.py
This commit is contained in:
@@ -369,13 +369,7 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using traits2_tuple_ =
|
||||
std::tuple<
|
||||
fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, 128, 32, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>
|
||||
, fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, 128, 64, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>
|
||||
>;
|
||||
using traits2_ = std::tuple_element_t<(64 <= {F_hdim}),
|
||||
traits2_tuple_>;
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
@@ -876,7 +870,6 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
|
||||
if direction == 'fwd':
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
# original
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
|
||||
@@ -899,8 +892,8 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(direction : str, dtype : s
|
||||
return {
|
||||
'32' : FmhaFwdSplitKVCombineTileSize(128, 32, -1),
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(128, 64, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(128, 64, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(128, 64, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(128, 128, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(128, 256, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user