mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[rocm-libraries] ROCm/rocm-libraries#4577 (commit a36922c)
[CK_TILE] FMHA BWD Launcher Interface ## Motivation Reduce memory usage; Be prepared to implement optimizations of reducing nsplits in deterministic cases. ## Technical Details This PR introduces a new launcher interface for the FMHA backward operation, replacing direct function calls with a more structured approach. The launcher encapsulates kernel dispatch logic and provides access to computed metadata like the number of dQ acc splits. **Changes:** - Added `fmha_bwd_launcher` class that wraps kernel execution and exposes `dq_acc_splits` - Moved `fmha_bwd_traits` construction earlier in the execution flow to support launcher initialization - Refactored code generation to produce both legacy API and new launcher constructor ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
08b6de62f8
commit
b09112bbad
@@ -29,6 +29,7 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "fmha_bwd.hpp"
|
||||
|
||||
"""
|
||||
|
||||
FMHA_BWD_DQ_DK_DV_KERNEL_BODY = """
|
||||
@@ -167,6 +168,13 @@ int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
return k_::kMaxSeqLenQ;
|
||||
}}
|
||||
|
||||
template <>
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(ck_tile::index_t seqlen_k)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::GetDqAccSplits(seqlen_k);
|
||||
}}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
{{
|
||||
@@ -179,34 +187,17 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
|
||||
FMHA_BWD_API_FILENAME = "fmha_bwd_api.cpp"
|
||||
FMHA_BWD_API = """
|
||||
#include <iostream>
|
||||
|
||||
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_, typename Arch>
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if constexpr (!std::is_same_v<convert_dq_trait_, void>)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_, Arch>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_, Arch>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_, Arch>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_, Arch>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_, Arch>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_, Arch>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
else
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_, Arch>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_, Arch>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_, Arch>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_, Arch>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
{F_launcher}
|
||||
run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }};
|
||||
dq_acc_splits = 1;
|
||||
}}
|
||||
|
||||
|
||||
// Prefer to use launcher. Leave fmha_bwd here for backward compatibility.
|
||||
template <>
|
||||
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
|
||||
float fmha_bwd<2>(const fmha_bwd_traits& t, fmha_bwd_args a, const ck_tile::stream_config& s){{
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
@@ -225,15 +216,25 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, if_i=0) -> str:
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
FMHA_BWD_API_INNER_DISPATCH_COMMON = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
|
||||
"""
|
||||
FMHA_BWD_API_INNER_DISPATCH_RUN = """
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{
|
||||
return fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
|
||||
}};
|
||||
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t.max_seqlen_k);
|
||||
return;
|
||||
}}
|
||||
"""
|
||||
|
||||
# M0 size for 1d kernels (dot/convert)
|
||||
M0_1D = 64
|
||||
@@ -795,35 +796,35 @@ class FmhaBwdApiTrait:
|
||||
if self.mode == "group":
|
||||
return "true /*spad1d is always true in group mode*/"
|
||||
elif self.spad1d == "t":
|
||||
return f"true /*a.seqlen_q % {M0_1D} != 0*/"
|
||||
return f"true /*t.seqlen_q % {M0_1D} != 0*/"
|
||||
else: # self.spad1d == "f"
|
||||
return f"a.seqlen_q % {M0_1D} == 0"
|
||||
return f"t.seqlen_q % {M0_1D} == 0"
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.dpad == 0:
|
||||
return f"a.hdim_q % {self.bhdq} == 0"
|
||||
return f"t.hdim_q % {self.bhdq} == 0"
|
||||
else:
|
||||
return f"a.hdim_q % {self.dpad} == 0"
|
||||
return f"t.hdim_q % {self.dpad} == 0"
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.dvpad == 0:
|
||||
return f"a.hdim_v % {self.bhdv} == 0"
|
||||
return f"t.hdim_v % {self.bhdv} == 0"
|
||||
else:
|
||||
return f"a.hdim_v % {self.dvpad} == 0"
|
||||
return f"t.hdim_v % {self.dvpad} == 0"
|
||||
|
||||
@property
|
||||
def max_seq_q_cond(self) -> str:
|
||||
if self.tile.max_seq_q != 0:
|
||||
return f" && (a.seqlen_q <= {self.tile.max_seq_q})"
|
||||
return f" && (t.seqlen_q <= {self.tile.max_seq_q})"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def extra_cond(self) -> str:
|
||||
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
|
||||
return " && (a.seqlen_k <= 256)"
|
||||
return " && (t.seqlen_k <= 256)"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@@ -910,12 +911,12 @@ class FmhaBwdApiPool:
|
||||
check_duplicates_and_paddings(ts, trait)
|
||||
ts.append(copy.copy(trait))
|
||||
|
||||
def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> str:
|
||||
def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> tuple[str, str]:
|
||||
inners = ""
|
||||
inners_launcher = ""
|
||||
for i_trait, trait in enumerate(traits):
|
||||
inners += FMHA_BWD_API_INNER_DISPATCH.format(
|
||||
inners_common = FMHA_BWD_API_INNER_DISPATCH_COMMON.format(
|
||||
F_if=if_(i_trait),
|
||||
F_arch=trait.arch,
|
||||
F_mode=MODE_MAP[trait.mode],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
@@ -935,13 +936,20 @@ class FmhaBwdApiPool:
|
||||
F_deterministic=BOOL_MAP[trait.deterministic],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_maxq=trait.tile.max_seq_q,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
|
||||
F_max_seq_q_cond=trait.max_seq_q_cond,
|
||||
F_cond_extra=trait.extra_cond,
|
||||
F_bn0=trait.tile.F_bn0,
|
||||
F_convert_dq_bn0=trait.convert_dq_bn0,
|
||||
)
|
||||
return inners
|
||||
inners += inners_common + FMHA_BWD_API_INNER_DISPATCH_RUN.format(
|
||||
F_arch=trait.arch,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
|
||||
)
|
||||
inners_launcher += inners_common + FMHA_BWD_API_INNER_DISPATCH_LAUNCHER.format(
|
||||
F_arch=trait.arch,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
|
||||
)
|
||||
return inners, inners_launcher
|
||||
|
||||
@staticmethod
|
||||
def max_seq_q_sort_key(trait):
|
||||
@@ -957,8 +965,7 @@ class FmhaBwdApiPool:
|
||||
def hdim_cond(hdim: int) -> str:
|
||||
return f"t.hdim_q <= {hdim} && t.hdim_v <= {hdim}"
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
def _api_per_arch(self, variant) -> str:
|
||||
per_arch = ""
|
||||
for i_arch, (arch, pool_by_arch) in enumerate(self.dq_dk_dv_pool.items()):
|
||||
per_dtypes = ""
|
||||
@@ -968,7 +975,7 @@ class FmhaBwdApiPool:
|
||||
traits = sorted(pool_by_hdim, key=self.max_seq_q_sort_key)
|
||||
inners = self._api_inners(traits)
|
||||
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(
|
||||
if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners
|
||||
if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners[variant]
|
||||
)
|
||||
per_dtypes += FMHA_BWD_API_COND_STATEMENT(
|
||||
if_i=i_dtype, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case
|
||||
@@ -978,9 +985,13 @@ class FmhaBwdApiPool:
|
||||
)
|
||||
if not per_arch:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_arch = "(void)t; (void)s; (void)a;"
|
||||
per_arch = ("(void)t; (void)s; (void)a;", "(void)t;")[variant]
|
||||
return per_arch
|
||||
@property
|
||||
def api(self) -> str:
|
||||
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(
|
||||
F_dispatch=indent(per_arch)
|
||||
F_dispatch=indent(self._api_per_arch(0)),
|
||||
F_launcher=indent(self._api_per_arch(1)),
|
||||
)
|
||||
return result.replace("\n\n", "\n")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user