From eac0f3cc47e08a7ea0d9baa8a29f749c712ae2b9 Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Sun, 2 Jun 2024 20:11:49 -0400 Subject: [PATCH] Fix mismatched return type --- example/ck_tile/01_fmha/fmha_fwd.hpp | 4 ++-- example/ck_tile/01_fmha/generate.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 71a80c4d61..1cfb2f0e15 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -431,13 +431,13 @@ template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); template -float fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); template std::string fmha_fwd_splitkv_get_name_(); template -float fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); template std::string fmha_fwd_splitkv_combine_get_name_(); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 2d103558ea..8453028b3c 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -232,13 +232,13 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F #include template<> -float fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}; auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} template<> @@ -312,13 +312,13 @@ using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F #include template<> -float fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}; auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} template<>