Clearn up generate.py

This commit is contained in:
PoYen, Chen
2024-06-11 14:15:07 +00:00
parent bb6804e315
commit c9bbb7b142

View File

@@ -369,13 +369,7 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_tuple_ =
std::tuple<
fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, 128, 32, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>
, fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, 128, 64, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>
>;
using traits2_ = std::tuple_element_t<(64 <= {F_hdim}),
traits2_tuple_>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
@@ -876,7 +870,6 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
if direction == 'fwd':
if dtype == 'fp16' or dtype == 'bf16':
return {
# original
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
@@ -899,8 +892,8 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(direction : str, dtype : s
return {
'32' : FmhaFwdSplitKVCombineTileSize(128, 32, -1),
'64' : FmhaFwdSplitKVCombineTileSize(128, 64, -1),
'128' : FmhaFwdSplitKVCombineTileSize(128, 64, -1),
'256' : FmhaFwdSplitKVCombineTileSize(128, 64, -1),
'128' : FmhaFwdSplitKVCombineTileSize(128, 128, -1),
'256' : FmhaFwdSplitKVCombineTileSize(128, 256, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {