[rocm-libraries] ROCm/rocm-libraries#4577 (commit a36922c)

[CK_TILE] FMHA BWD Launcher Interface

## Motivation
Reduce memory usage, and prepare for follow-up optimizations that reduce
nsplits in deterministic cases.

## Technical Details
This PR introduces a new launcher interface for the FMHA backward
operation, replacing direct function calls with a more structured
approach. The launcher encapsulates the kernel dispatch logic and exposes
computed metadata such as the number of dQ accumulator splits
(`dq_acc_splits`), which client code previously had to derive from
internal tile sizes.
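For reference, the intended call pattern looks roughly as follows. This is a minimal sketch assembled from the header and example changes in this diff; `run_bwd` and its argument names are illustrative, not part of the PR:

```cpp
#include "fmha_bwd.hpp"

// Hypothetical wrapper showing the new launcher-based call pattern.
float run_bwd(const fmha_bwd_traits& traits,
              fmha_bwd_args args,
              const ck_tile::stream_config& stream)
{
    // The constructor resolves kernel dispatch once from the traits and
    // fills in both the `run` callable and the `dq_acc_splits` metadata.
    fmha_bwd_launcher launcher(traits);

    // Metadata is available before launching, e.g. for sizing the dQ
    // accumulator workspace in deterministic mode.
    const ck_tile::index_t nsplits = launcher.dq_acc_splits;
    (void)nsplits;

    // operator() forwards to `run`; it returns a negative value when no
    // generated kernel instance matches the traits.
    return launcher(args, stream);
}
```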

**Changes:**
- Added an `fmha_bwd_launcher` class that wraps kernel execution and
exposes `dq_acc_splits`
- Moved `fmha_bwd_traits` construction earlier in the execution flow so
the launcher can be initialized before workspace buffers are sized
- Refactored code generation to emit both the legacy API and the new
launcher constructor (see the sketch after this list)
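The main client-facing effect is that the dQ accumulator split count no longer has to be mirrored from internal tile sizes. A before/after sketch based on the example changes in this diff (`dq_acc_nsplits` is a hypothetical helper):

```cpp
#include "fmha_bwd.hpp"

// Hypothetical helper contrasting the old and new ways of obtaining the
// number of dQ accumulator splits used to size the workspace.
ck_tile::index_t dq_acc_nsplits(const fmha_bwd_traits& traits,
                                ck_tile::index_t max_seqlen_k,
                                bool deterministic)
{
    // Before: clients hard-coded kN0, which had to stay equal to or
    // smaller than the minimal bn0 of all tiles in fmha_bwd.py.
    const ck_tile::index_t kN0 = 16;
    const ck_tile::index_t nsplits_old =
        deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
    (void)nsplits_old;

    // After: the launcher reports the value computed by the selected
    // kernel (via GetDqAccSplits), so client code no longer depends on
    // internal implementation details.
    fmha_bwd_launcher launcher(traits);
    return launcher.dq_acc_splits;
}
```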

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Commit b09112bbad (parent 08b6de62f8), authored by Yi DING on 2026-03-04 01:21:07 +00:00, committed by assistant-librarian[bot].
4 changed files with 150 additions and 69 deletions

View File

@@ -29,6 +29,7 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include "fmha_bwd.hpp"
"""
FMHA_BWD_DQ_DK_DV_KERNEL_BODY = """
@@ -167,6 +168,13 @@ int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
return k_::kMaxSeqLenQ;
}}
template <>
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(ck_tile::index_t seqlen_k)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::GetDqAccSplits(seqlen_k);
}}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
{{
@@ -179,34 +187,17 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
FMHA_BWD_API_FILENAME = "fmha_bwd_api.cpp"
FMHA_BWD_API = """
#include <iostream>
-template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_, typename Arch>
-float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
-{{
-if constexpr (!std::is_same_v<convert_dq_trait_, void>)
-{{
-if(s.log_level_ > 0)
-std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_, Arch>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_, Arch>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_, Arch>() << std::flush;
-return ck_tile::launch_kernel(s,
-[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_, Arch>(s_, a); }},
-[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_, Arch>(s_, a); }},
-[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_, Arch>(s_, a); }}
-);
-}}
-else
-{{
-if(s.log_level_ > 0)
-std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_, Arch>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_, Arch>() << std::flush;
-return ck_tile::launch_kernel(s,
-[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_, Arch>(s_, a); }},
-[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_, Arch>(s_, a); }}
-);
-}}
+fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{
+[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
+{F_launcher}
+run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }};
+dq_acc_splits = 1;
+}}
+// Prefer to use launcher. Leave fmha_bwd here for backward compatibility.
template <>
-float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
+float fmha_bwd<2>(const fmha_bwd_traits& t, fmha_bwd_args a, const ck_tile::stream_config& s){{
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
float r = -1;
{F_dispatch}
@@ -225,15 +216,25 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, if_i=0) -> str:
return "\n".join(lines) + "\n"
-FMHA_BWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
+FMHA_BWD_API_INNER_DISPATCH_COMMON = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
"""
FMHA_BWD_API_INNER_DISPATCH_RUN = """
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
return r;
}}
"""
FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{
return fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
}};
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t.max_seqlen_k);
return;
}}
"""
# M0 size for 1d kernels (dot/convert)
M0_1D = 64
@@ -795,35 +796,35 @@ class FmhaBwdApiTrait:
if self.mode == "group":
return "true /*spad1d is always true in group mode*/"
elif self.spad1d == "t":
return f"true /*a.seqlen_q % {M0_1D} != 0*/"
return f"true /*t.seqlen_q % {M0_1D} != 0*/"
else: # self.spad1d == "f"
return f"a.seqlen_q % {M0_1D} == 0"
return f"t.seqlen_q % {M0_1D} == 0"
@property
def dcheck(self) -> str:
if self.dpad == 0:
return f"a.hdim_q % {self.bhdq} == 0"
return f"t.hdim_q % {self.bhdq} == 0"
else:
return f"a.hdim_q % {self.dpad} == 0"
return f"t.hdim_q % {self.dpad} == 0"
@property
def dvcheck(self) -> str:
if self.dvpad == 0:
return f"a.hdim_v % {self.bhdv} == 0"
return f"t.hdim_v % {self.bhdv} == 0"
else:
return f"a.hdim_v % {self.dvpad} == 0"
return f"t.hdim_v % {self.dvpad} == 0"
@property
def max_seq_q_cond(self) -> str:
if self.tile.max_seq_q != 0:
return f" && (a.seqlen_q <= {self.tile.max_seq_q})"
return f" && (t.seqlen_q <= {self.tile.max_seq_q})"
else:
return ""
@property
def extra_cond(self) -> str:
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
return " && (a.seqlen_k <= 256)"
return " && (t.seqlen_k <= 256)"
else:
return ""
@@ -910,12 +911,12 @@ class FmhaBwdApiPool:
check_duplicates_and_paddings(ts, trait)
ts.append(copy.copy(trait))
-def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> str:
+def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> tuple[str, str]:
inners = ""
+inners_launcher = ""
for i_trait, trait in enumerate(traits):
-inners += FMHA_BWD_API_INNER_DISPATCH.format(
+inners_common = FMHA_BWD_API_INNER_DISPATCH_COMMON.format(
F_if=if_(i_trait),
F_arch=trait.arch,
F_mode=MODE_MAP[trait.mode],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
@@ -935,13 +936,20 @@ class FmhaBwdApiPool:
F_deterministic=BOOL_MAP[trait.deterministic],
F_trload=BOOL_MAP[trait.tr_load],
F_maxq=trait.tile.max_seq_q,
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
F_max_seq_q_cond=trait.max_seq_q_cond,
F_cond_extra=trait.extra_cond,
F_bn0=trait.tile.F_bn0,
F_convert_dq_bn0=trait.convert_dq_bn0,
)
-return inners
+inners += inners_common + FMHA_BWD_API_INNER_DISPATCH_RUN.format(
+F_arch=trait.arch,
+F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
+)
+inners_launcher += inners_common + FMHA_BWD_API_INNER_DISPATCH_LAUNCHER.format(
+F_arch=trait.arch,
+F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
+)
+return inners, inners_launcher
@staticmethod
def max_seq_q_sort_key(trait):
@@ -957,8 +965,7 @@ class FmhaBwdApiPool:
def hdim_cond(hdim: int) -> str:
return f"t.hdim_q <= {hdim} && t.hdim_v <= {hdim}"
-@property
-def api(self) -> str:
+def _api_per_arch(self, variant) -> str:
per_arch = ""
for i_arch, (arch, pool_by_arch) in enumerate(self.dq_dk_dv_pool.items()):
per_dtypes = ""
@@ -968,7 +975,7 @@ class FmhaBwdApiPool:
traits = sorted(pool_by_hdim, key=self.max_seq_q_sort_key)
inners = self._api_inners(traits)
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(
-if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners
+if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners[variant]
)
per_dtypes += FMHA_BWD_API_COND_STATEMENT(
if_i=i_dtype, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case
@@ -978,9 +985,13 @@ class FmhaBwdApiPool:
)
if not per_arch:
# empty string we add some ignore to suppress warning in api
per_arch = "(void)t; (void)s; (void)a;"
per_arch = ("(void)t; (void)s; (void)a;", "(void)t;")[variant]
return per_arch
+@property
+def api(self) -> str:
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(
-F_dispatch=indent(per_arch)
+F_dispatch=indent(self._api_per_arch(0)),
+F_launcher=indent(self._api_per_arch(1)),
)
return result.replace("\n\n", "\n")

View File

@@ -14,6 +14,8 @@
#include <type_traits>
#include <utility>
#include <variant>
#include <iostream>
#include <functional>
struct FmhaBwdFp32
{
@@ -463,6 +465,8 @@ template <typename Traits_, typename Arch = void>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_maxq_();
template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_dq_acc_splits_(ck_tile::index_t seqlen_k);
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
@@ -503,11 +507,18 @@ void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_, typename Arch = void>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API, will be generated by script
// Traits that are used to dispatch different kernel implementations for fmha backward
struct fmha_bwd_traits
{
int seqlen_q;
int seqlen_k;
int batch;
int max_seqlen_q;
int max_seqlen_k;
int hdim_q;
int hdim_v;
int nhead_q;
int nhead_k;
std::string data_type;
bool is_group_mode;
mask_enum mask_type;
@@ -518,5 +529,52 @@ struct fmha_bwd_traits
bool is_deterministic;
// TODO: padding check is inside this api
};
template <typename T0 /*dot_do_o_trait*/,
typename T1 /*dq_dk_dv_trait*/,
typename T2 /*convert_dq_trait*/,
typename Arch>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{
if constexpr(!std::is_same_v<T2, void>)
{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<T0, Arch>() << "@"
<< fmha_bwd_convert_dq_get_name_<T2, Arch>() << "@"
<< fmha_bwd_dq_dk_dv_get_name_<T1, Arch>() << std::flush;
return ck_tile::launch_kernel(
s,
[=](const ck_tile::stream_config& s_) { fmha_bwd_dot_do_o_oneshot_<T0, Arch>(s_, a); },
[=](const ck_tile::stream_config& s_) { fmha_bwd_dq_dk_dv_oneshot_<T1, Arch>(s_, a); },
[=](const ck_tile::stream_config& s_) {
fmha_bwd_convert_dq_oneshot_<T2, Arch>(s_, a);
});
}
else
{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<T0, Arch>() << "@"
<< fmha_bwd_dq_dk_dv_get_name_<T1, Arch>() << std::flush;
return ck_tile::launch_kernel(
s,
[=](const ck_tile::stream_config& s_) { fmha_bwd_dot_do_o_oneshot_<T0, Arch>(s_, a); },
[=](const ck_tile::stream_config& s_) { fmha_bwd_dq_dk_dv_oneshot_<T1, Arch>(s_, a); });
}
}
template <int Version = 2>
-float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
+float fmha_bwd(const fmha_bwd_traits&, fmha_bwd_args, const ck_tile::stream_config&);
struct fmha_bwd_launcher
{
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
ck_tile::index_t dq_acc_splits{0};
fmha_bwd_launcher(const fmha_bwd_traits&);
template <typename... Args>
float operator()(Args&&... args) const
{
return run(std::forward<Args>(args)...);
}
};

View File

@@ -56,8 +56,6 @@ auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
return ck_tile::make_tuple(rtol, atol);
}
-extern template float fmha_bwd<2>(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
template <typename DataTypeConfig>
bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::index_t batch,
@@ -243,12 +241,29 @@ bwd_result fmha_bwd_run(mode_enum mode,
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back());
-// Keep it equal to or smaller than minimal bn0 of all tiles in fmha_bwd.py
-// TODO: add API for requesting kN0/nsplits/workspace_size? It is not safe to rely on internal
-// implementation details in client code.
-const ck_tile::index_t kN0 = 16;
-const ck_tile::index_t nsplits =
-deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
+const fmha_bwd_traits fmha_traits{
+shape_seqlen_q,
+shape_seqlen_k,
+batch,
+max_seqlen_q,
+max_seqlen_k,
+hdim_q,
+hdim_v,
+nhead,
+nhead_k,
+data_type,
+mode == mode_enum::group,
+mask.type,
+bias.type,
+use_dbias,
+p_drop > 0.0f,
+s_randval,
+deterministic,
+};
+fmha_bwd_launcher launcher(fmha_traits);
+const ck_tile::index_t nsplits = launcher.dq_acc_splits;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
@@ -406,17 +421,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
: "")
<< ", mask:" << mask << std::flush;
-auto fmha_traits = fmha_bwd_traits{hdim_q,
-hdim_v,
-data_type,
-mode == mode_enum::group,
-mask.type,
-bias.type,
-use_dbias,
-p_drop > 0.0f,
-s_randval,
-deterministic};
auto fmha_args = [&]() {
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
/// 'nhead_stride_bias' are 0.
@@ -478,7 +483,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
do_buf.GetDeviceBuffer(),
@@ -509,7 +514,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
stride_k,
stride_v,
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias,
stride_o,
stride_randval,
stride_do,
@@ -553,7 +558,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
drop_seed_offset};
}();
-const float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
+const float ave_time = launcher(fmha_args, stream_config);
if(ave_time < 0)
{
std::cout << ", not supported yet" << std::flush << std::endl;
@@ -844,7 +849,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
-fmha_bwd(fmha_traits, fmha_args, stream_config_v);
+launcher(fmha_args, stream_config_v);
dq_buf.FromDevice(dq_host.data());
dk_buf.FromDevice(dk_host.data());