diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 01616650c3..9627c2bbf5 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -85,6 +85,7 @@ auto create_args(int argc, char* argv[]) .insert("p_drop", "0", "0~1 probability of dropout") .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -180,8 +181,12 @@ bool run(const ck_tile::ArgParser& arg_parser) int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); - ck_tile::stream_config stream_config{ - nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); @@ -734,7 +739,8 @@ bool run(const ck_tile::ArgParser& arg_parser) dq_buf.SetZero(); dbias_buf.SetZero(); - ck_tile::stream_config stream_config_v{nullptr, true, /* log_level = */ (kname ? 1 : 0), 0, 1}; + ck_tile::stream_config stream_config_v{ + nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; fmha_bwd(fmha_traits, fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 9aaa3e3f23..0c6b468951 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -318,6 +318,12 @@ struct fmha_bwd_dq_dk_dv_traits_ template float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); +template +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dq_dk_dv_get_name_(); + template struct fmha_bwd_dot_do_o_traits_ { @@ -331,6 +337,12 @@ struct fmha_bwd_dot_do_o_traits_ template float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); +template +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dot_do_o_get_name_(); + // This is the public API, will be generated by script struct fmha_bwd_traits { diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index a265a23c8d..2f7898d206 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -671,12 +671,42 @@ float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dq_dk_dv_get_name_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetName(); }} """ FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" FMHA_BWD_API=""" +#include + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + ); +}} + float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} @@ -697,8 +727,7 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; - r = fmha_bwd_dot_do_o_(s, a); - r += fmha_bwd_dq_dk_dv_(s, a); + r = fmha_bwd_(s, a); return r; }} """ @@ -1008,14 +1037,35 @@ using fmha_bwd_dot_do_o_kernel_{F_idx} = using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; +#include + template<> float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dot_do_o_get_name_() +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + return k_::GetName(); }} """ diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index e0fff41c69..5160c89f29 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1187,6 +1187,32 @@ struct FmhaBwdOGradDotOKernel static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ; static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV; + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { +// sync with generate.py +// clang-format off + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn); + #undef _SS_ + #undef _TS_ + // clang-format on + } + // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs.