From 291e9b4bbb0cefa6e378e2306f15c5a149ae04f3 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Aug 2024 08:07:03 +0000 Subject: [PATCH] Separate splitkv/non-splitkv args/traits --- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 12 +-- example/ck_tile/01_fmha/fmha_fwd.hpp | 93 ++++++++++--------- 2 files changed, 54 insertions(+), 51 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index c531da1ccf..228bc5f3b7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -87,7 +87,7 @@ using fmha_kernel = fmha_pipeline, fmha_epilogue>; -static void run(const ck_tile::stream_config& s, fmha_fwd_args a) +static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ using k_ = fmha_kernel; auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); @@ -105,7 +105,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F #include template<> -void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if constexpr({F_mode} == false) {{ // batch mode // make sure F_bn0 is divisible by F_bk1 @@ -163,7 +163,7 @@ using fmha_kernel = fmha_pipeline, fmha_epilogue>; -static void run(const ck_tile::stream_config& s, fmha_fwd_args a) +static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ using k_ = fmha_kernel; auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); @@ -180,7 +180,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m #include template<> -void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if (a.num_splits <= 16) {{ kernel_runner<4>::run(s, a); @@ -206,7 +206,7 @@ FMHA_FWD_SPLITKV_API=""" #include template -float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a) +float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if(s.log_level_ > 0) std::cout @@ -220,7 +220,7 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a) ); }} -float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ +float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} return r; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 8dbef4491e..36becf1b1a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -156,6 +156,8 @@ struct fmha_fwd_args std::tuple drop_seed_offset; }; +using fmha_fwd_splitkv_args = fmha_fwd_args; + struct fmha_fwd_appendkv_args { void* q_ptr; @@ -301,7 +303,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } template -auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { @@ -417,7 +419,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) } template -auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { @@ -611,53 +613,54 @@ struct fmha_fwd_traits_ template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); -template -struct fmha_fwd_splitkv_traits_ : fmha_fwd_traits_ + bool kPadS_, + bool kPadSK_, + bool kPadD_, + bool kPadDv_> +struct fmha_fwd_splitkv_traits_ { - static constexpr bool kIsPagedKV = kIsPagedKV_; + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsPagedKV = kIsPagedKV_; }; template -void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); template std::string fmha_fwd_splitkv_get_name_(); @@ -685,7 +688,7 @@ struct fmha_fwd_splitkv_combine_traits_ }; template -void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); template std::string fmha_fwd_splitkv_combine_get_name_();