diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 9be811cacd..02055ffd9e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -29,6 +29,7 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n // auto generated by generate.py #include "fmha_bwd.hpp" + """ FMHA_BWD_DQ_DK_DV_KERNEL_BODY = """ @@ -167,6 +168,13 @@ int fmha_bwd_dq_dk_dv_maxq_() return k_::kMaxSeqLenQ; }} +template <> +int fmha_bwd_dq_dk_dv_dq_acc_splits_(ck_tile::index_t seqlen_k) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetDqAccSplits(seqlen_k); +}} + template <> std::string fmha_bwd_dq_dk_dv_get_name_() {{ @@ -179,34 +187,17 @@ std::string fmha_bwd_dq_dk_dv_get_name_() FMHA_BWD_API_FILENAME = "fmha_bwd_api.cpp" FMHA_BWD_API = """ -#include - -template -float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - if constexpr (!std::is_same_v) - {{ - if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_convert_dq_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} - ); - }} - else - {{ - if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} - ); - }} +fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{ + [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); +{F_launcher} + run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }}; + dq_acc_splits = 1; }} + +// Prefer to use launcher. Leave fmha_bwd here for backward compatibility. 
template <> -float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ +float fmha_bwd<2>(const fmha_bwd_traits& t, fmha_bwd_args a, const ck_tile::stream_config& s){{ [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); float r = -1; {F_dispatch} @@ -225,15 +216,25 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, if_i=0) -> str: return "\n".join(lines) + "\n" -FMHA_BWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && +FMHA_BWD_API_INNER_DISPATCH_COMMON = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>; +""" +FMHA_BWD_API_INNER_DISPATCH_RUN = """ r = fmha_bwd_, {F_arch.tag}>(s, a); return r; }} """ +FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """ + run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{ + return fmha_bwd_, {F_arch.tag}>(s, a); + }}; + dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_(t.max_seqlen_k); + return; +}} +""" # M0 size for 1d kernels (dot/convert) M0_1D = 64 @@ -795,35 +796,35 @@ class FmhaBwdApiTrait: if self.mode == "group": return "true /*spad1d is always true in group mode*/" elif self.spad1d == "t": - return f"true /*a.seqlen_q % {M0_1D} != 0*/" + return f"true /*t.seqlen_q % {M0_1D} != 0*/" else: # self.spad1d == "f" - return f"a.seqlen_q % {M0_1D} == 0" + return f"t.seqlen_q % {M0_1D} == 0" @property def dcheck(self) -> str: if self.dpad == 0: - return f"a.hdim_q % {self.bhdq} == 0" + return f"t.hdim_q % {self.bhdq} == 0" else: - return f"a.hdim_q % {self.dpad} == 0" + return f"t.hdim_q % {self.dpad} == 0" @property def dvcheck(self) -> str: if self.dvpad == 0: - return f"a.hdim_v % {self.bhdv} == 0" + return f"t.hdim_v % {self.bhdv} == 0" else: - return f"a.hdim_v % {self.dvpad} == 0" + return f"t.hdim_v % {self.dvpad} == 0" @property def max_seq_q_cond(self) -> str: if self.tile.max_seq_q != 0: - return f" && (a.seqlen_q <= {self.tile.max_seq_q})" + return f" && (t.seqlen_q <= {self.tile.max_seq_q})" else: return "" @property def extra_cond(self) -> str: if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128: - return " && (a.seqlen_k <= 256)" + return " && (t.seqlen_k <= 256)" else: return "" @@ -910,12 +911,12 @@ class FmhaBwdApiPool: check_duplicates_and_paddings(ts, trait) ts.append(copy.copy(trait)) - def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> str: + def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> tuple[str, str]: inners = "" + inners_launcher = "" for i_trait, trait in enumerate(traits): - inners += FMHA_BWD_API_INNER_DISPATCH.format( + inners_common = FMHA_BWD_API_INNER_DISPATCH_COMMON.format( F_if=if_(i_trait), - F_arch=trait.arch, F_mode=MODE_MAP[trait.mode], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], 
F_mask=get_mask_map(self.mask_impl)[trait.mask], @@ -935,13 +936,20 @@ class FmhaBwdApiPool: F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, - F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_max_seq_q_cond=trait.max_seq_q_cond, F_cond_extra=trait.extra_cond, F_bn0=trait.tile.F_bn0, F_convert_dq_bn0=trait.convert_dq_bn0, ) - return inners + inners += inners_common + FMHA_BWD_API_INNER_DISPATCH_RUN.format( + F_arch=trait.arch, + F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], + ) + inners_launcher += inners_common + FMHA_BWD_API_INNER_DISPATCH_LAUNCHER.format( + F_arch=trait.arch, + F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], + ) + return inners, inners_launcher @staticmethod def max_seq_q_sort_key(trait): @@ -957,8 +965,7 @@ class FmhaBwdApiPool: def hdim_cond(hdim: int) -> str: return f"t.hdim_q <= {hdim} && t.hdim_v <= {hdim}" - @property - def api(self) -> str: + def _api_per_arch(self, variant) -> str: per_arch = "" for i_arch, (arch, pool_by_arch) in enumerate(self.dq_dk_dv_pool.items()): per_dtypes = "" @@ -968,7 +975,7 @@ class FmhaBwdApiPool: traits = sorted(pool_by_hdim, key=self.max_seq_q_sort_key) inners = self._api_inners(traits) per_hdim_case += FMHA_BWD_API_COND_STATEMENT( - if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners + if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners[variant] ) per_dtypes += FMHA_BWD_API_COND_STATEMENT( if_i=i_dtype, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case @@ -978,9 +985,13 @@ class FmhaBwdApiPool: ) if not per_arch: # empty string we add some ignore to suppress warning in api - per_arch = "(void)t; (void)s; (void)a;" + per_arch = ("(void)t; (void)s; (void)a;", "(void)t;")[variant] + return per_arch + @property + def api(self) -> str: result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format( - F_dispatch=indent(per_arch) + F_dispatch=indent(self._api_per_arch(0)), + F_launcher=indent(self._api_per_arch(1)), ) return result.replace("\n\n", "\n") diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 180d039cd4..983ac50231 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -14,6 +14,8 @@ #include #include #include +#include +#include struct FmhaBwdFp32 { @@ -463,6 +465,8 @@ template std::string fmha_bwd_dq_dk_dv_get_name_(); template int fmha_bwd_dq_dk_dv_maxq_(); +template +int fmha_bwd_dq_dk_dv_dq_acc_splits_(ck_tile::index_t seqlen_k); template struct fmha_bwd_dot_do_o_traits_ @@ -503,11 +507,18 @@ void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); template std::string fmha_bwd_convert_dq_get_name_(); -// This is the public API, will be generated by script +// Traits that are used to dispatch different kernel implementations for fmha backward struct fmha_bwd_traits { + int seqlen_q; + int seqlen_k; + int batch; + int max_seqlen_q; + int max_seqlen_k; int hdim_q; int hdim_v; + int nhead_q; + int nhead_k; std::string data_type; bool is_group_mode; mask_enum mask_type; @@ -518,5 +529,52 @@ struct fmha_bwd_traits bool is_deterministic; // TODO: padding check is inside this api }; + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{ + if constexpr(!std::is_same_v) + { + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" + << fmha_bwd_convert_dq_get_name_() << "@" + << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return 
ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { fmha_bwd_dot_do_o_oneshot_(s_, a); }, + [=](const ck_tile::stream_config& s_) { fmha_bwd_dq_dk_dv_oneshot_(s_, a); }, + [=](const ck_tile::stream_config& s_) { + fmha_bwd_convert_dq_oneshot_(s_, a); + }); + } + else + { + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" + << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { fmha_bwd_dot_do_o_oneshot_(s_, a); }, + [=](const ck_tile::stream_config& s_) { fmha_bwd_dq_dk_dv_oneshot_(s_, a); }); + } +} + template -float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); +float fmha_bwd(const fmha_bwd_traits&, fmha_bwd_args, const ck_tile::stream_config&); + +struct fmha_bwd_launcher +{ + std::function run{}; + ck_tile::index_t dq_acc_splits{0}; + + fmha_bwd_launcher(const fmha_bwd_traits&); + + template + float operator()(Args&&... args) const + { + return run(std::forward(args)...); + } +}; diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index f41f0668e5..92ae94d9b1 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -56,8 +56,6 @@ auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) return ck_tile::make_tuple(rtol, atol); } -extern template float fmha_bwd<2>(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); - template bwd_result fmha_bwd_run(mode_enum mode, ck_tile::index_t batch, @@ -243,12 +241,29 @@ bwd_result fmha_bwd_run(mode_enum mode, (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back()); - // Keep it equal to or smaller than minimal bn0 of all tiles in fmha_bwd.py - // TODO: add API for requesting kN0/nsplits/workspace_size? It is not safe to rely on internal - // implementation details in client code. - const ck_tile::index_t kN0 = 16; - const ck_tile::index_t nsplits = - deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1; + + const fmha_bwd_traits fmha_traits{ + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + max_seqlen_k, + hdim_q, + hdim_v, + nhead, + nhead_k, + data_type, + mode == mode_enum::group, + mask.type, + bias.type, + use_dbias, + p_drop > 0.0f, + s_randval, + deterministic, + }; + fmha_bwd_launcher launcher(fmha_traits); + + const ck_tile::index_t nsplits = launcher.dq_acc_splits; ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); @@ -406,17 +421,7 @@ bwd_result fmha_bwd_run(mode_enum mode, : "") << ", mask:" << mask << std::flush; - auto fmha_traits = fmha_bwd_traits{hdim_q, - hdim_v, - data_type, - mode == mode_enum::group, - mask.type, - bias.type, - use_dbias, - p_drop > 0.0f, - s_randval, - deterministic}; - auto fmha_args = [&]() { + auto fmha_args = [&]() { /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & /// 'nhead_stride_bias' are 0. @@ -478,7 +483,7 @@ bwd_result fmha_bwd_run(mode_enum mode, k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), bias.type == bias_enum::alibi ? 
alibi_slope_buf.GetDeviceBuffer() - : bias_buf.GetDeviceBuffer(), + : bias_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), do_buf.GetDeviceBuffer(), @@ -509,7 +514,7 @@ bwd_result fmha_bwd_run(mode_enum mode, stride_k, stride_v, bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) - : stride_bias, + : stride_bias, stride_o, stride_randval, stride_do, @@ -553,7 +558,7 @@ bwd_result fmha_bwd_run(mode_enum mode, drop_seed_offset}; }(); - const float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); + const float ave_time = launcher(fmha_args, stream_config); if(ave_time < 0) { std::cout << ", not supported yet" << std::flush << std::endl; @@ -844,7 +849,7 @@ bwd_result fmha_bwd_run(mode_enum mode, dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; - fmha_bwd(fmha_traits, fmha_args, stream_config_v); + launcher(fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); dk_buf.FromDevice(dk_host.data()); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index b5d3f490ed..ee9f87c525 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -124,6 +124,13 @@ struct FmhaBwdDQDKDVKernel #undef _TS_ // clang-format on } + CK_TILE_HOST static index_t GetDqAccSplits(index_t seqlen_k) + { + if constexpr(kIsDeterministic) + return integer_divide_ceil(seqlen_k, FmhaPipeline::BlockFmhaShape::kN0); + else + return 1; + } template // to avoid duplicated base class prblem, introduce an template // arg
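
Note on the intended call pattern introduced by this patch: the example runner now builds `fmha_bwd_traits` up front, constructs an `fmha_bwd_launcher` from it (kernel selection happens once, in the constructor), reads `dq_acc_splits` to size the dq accumulator workspace, and then invokes the launcher with the args and stream config. The sketch below restates that flow outside the runner. It is not part of the patch: the shape values and the "fp16" string are placeholders, the trait field order simply mirrors the aggregate initialization in fmha_bwd_runner.hpp, and the members between mask_type and is_deterministic (bias_type, has_dbias, has_dropout, is_store_randval) are assumed unchanged from the existing struct since that part of the header is outside the shown hunks.

// Hypothetical usage sketch; values are placeholders, not taken from the patch.
#include "fmha_bwd.hpp"

float run_bwd_example(const fmha_bwd_args& args, const ck_tile::stream_config& sc)
{
    // Field order mirrors the aggregate init used in fmha_bwd_runner.hpp.
    const fmha_bwd_traits traits{
        /*seqlen_q=*/4096, /*seqlen_k=*/4096, /*batch=*/2,
        /*max_seqlen_q=*/4096, /*max_seqlen_k=*/4096,
        /*hdim_q=*/128, /*hdim_v=*/128,
        /*nhead_q=*/16, /*nhead_k=*/16,
        /*data_type=*/"fp16",
        /*is_group_mode=*/false,
        mask_enum::no_mask,
        bias_enum::no_bias,
        /*has_dbias=*/false,
        /*has_dropout=*/false,
        /*is_store_randval=*/false,
        /*is_deterministic=*/true};

    // Kernel selection is done once here; on no match, run() returns -1.0f
    // and dq_acc_splits falls back to 1.
    fmha_bwd_launcher launcher(traits);

    // dq_acc_splits reports how many dq_acc slices the selected kernel expects
    // (ceil(seqlen_k / kN0) in deterministic mode, 1 otherwise), so the caller
    // can size the dq_acc workspace before filling fmha_bwd_args.
    const ck_tile::index_t nsplits = launcher.dq_acc_splits;
    (void)nsplits;

    // operator() forwards to the stored run(); a negative return means the
    // requested configuration has no generated kernel.
    return launcher(args, sc);
}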
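
On the dq_acc split count: the removed runner logic hard-coded kN0 = 16 and computed nsplits itself, which tied client code to an internal tile detail. With this patch the split count is owned by the kernel (FmhaBwdDQDKDVKernel::GetDqAccSplits uses the selected tile's actual kN0) and surfaced to callers through fmha_bwd_launcher::dq_acc_splits, so it always matches the kernel that was actually dispatched. The snippet below restates the computation on the host for clarity; the kN0 value in the comment is illustrative only and not taken from the patch.

// Host-side restatement of the split-count rule implemented by GetDqAccSplits().
ck_tile::index_t dq_acc_splits(bool deterministic,
                               ck_tile::index_t seqlen_k,
                               ck_tile::index_t kN0 /* selected tile's N0 block size */)
{
    // Deterministic kernels accumulate dq into per-slice buffers instead of using
    // atomics, so they need one slice per kN0-wide chunk of seqlen_k.
    return deterministic ? ck_tile::integer_divide_ceil(seqlen_k, kN0) : 1;
}
// e.g. deterministic, seqlen_k = 1000, kN0 = 128 -> 8 slices; non-deterministic -> 1.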