Remove fmha_fwd_dispatch()

This commit is contained in:
PoYen, Chen
2024-08-08 08:15:04 +00:00
parent 291e9b4bbb
commit 247e135cfc

View File

@@ -253,22 +253,6 @@ int override_num_splits_if_necessary(
return num_splits;
}
float fmha_fwd_dispatch(fmha_fwd_traits traits,
fmha_fwd_args args,
const ck_tile::stream_config& config)
{
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < args.num_splits || args.block_table_ptr != nullptr)
{
return fmha_fwd_splitkv(traits, args, config);
}
else
#endif
{
return fmha_fwd(traits, args, config);
}
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
@@ -1003,7 +987,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
{drop_seed, drop_offset}};
}();
const float fwd_ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
const float fwd_ave_time = [&] {
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < args.num_splits || args.block_table_ptr != nullptr)
{
return fmha_fwd_splitkv(traits, args, config);
}
#endif
return fmha_fwd(traits, args, config);
}();
if(appendkv_ave_time < 0 || fwd_ave_time < 0)
{