diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 20589ada57..e94d40bc72 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -369,13 +369,7 @@ float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - using traits2_tuple_ = - std::tuple< - fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, 128, 32, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}> - , fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, 128, 64, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}> - >; - using traits2_ = std::tuple_element_t<(64 <= {F_hdim}), - traits2_tuple_>; + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; return fmha_fwd_splitkv_(s, a); }} @@ -876,7 +870,6 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ if direction == 'fwd': if dtype == 'fp16' or dtype == 'bf16': return { - # original '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), @@ -899,8 +892,8 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(direction : str, dtype : s return { '32' : FmhaFwdSplitKVCombineTileSize(128, 32, -1), '64' : FmhaFwdSplitKVCombineTileSize(128, 64, -1), - '128' : FmhaFwdSplitKVCombineTileSize(128, 64, -1), - '256' : FmhaFwdSplitKVCombineTileSize(128, 64, -1), + '128' : FmhaFwdSplitKVCombineTileSize(128, 128, -1), + '256' : FmhaFwdSplitKVCombineTileSize(128, 256, -1), } elif dtype == 'fp8' or dtype == 'bf8': return {