mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4577 (commit a36922c)
[CK_TILE] FMHA BWD Launcher Interface ## Motivation Reduce memory usage, and prepare for optimizations that reduce nsplits in deterministic cases. ## Technical Details This PR introduces a new launcher interface for the FMHA backward operation, replacing direct function calls with a more structured approach. The launcher encapsulates kernel dispatch logic and provides access to computed metadata such as the number of dQ acc splits. **Changes:** - Added `fmha_bwd_launcher` class that wraps kernel execution and exposes `dq_acc_splits` - Moved `fmha_bwd_traits` construction earlier in the execution flow to support launcher initialization - Refactored code generation to produce both the legacy API and the new launcher constructor ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
08b6de62f8
commit
b09112bbad
@@ -29,6 +29,7 @@ FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "fmha_bwd.hpp"
|
||||
|
||||
"""
|
||||
|
||||
FMHA_BWD_DQ_DK_DV_KERNEL_BODY = """
|
||||
@@ -167,6 +168,13 @@ int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
return k_::kMaxSeqLenQ;
|
||||
}}
|
||||
|
||||
template <>
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(ck_tile::index_t seqlen_k)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::GetDqAccSplits(seqlen_k);
|
||||
}}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
{{
|
||||
@@ -179,34 +187,17 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
|
||||
FMHA_BWD_API_FILENAME = "fmha_bwd_api.cpp"
|
||||
FMHA_BWD_API = """
|
||||
#include <iostream>
|
||||
|
||||
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_, typename Arch>
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if constexpr (!std::is_same_v<convert_dq_trait_, void>)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_, Arch>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_, Arch>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_, Arch>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_, Arch>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_, Arch>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_, Arch>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
else
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_, Arch>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_, Arch>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_, Arch>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_, Arch>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
{F_launcher}
|
||||
run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }};
|
||||
dq_acc_splits = 1;
|
||||
}}
|
||||
|
||||
|
||||
// Prefer to use launcher. Leave fmha_bwd here for backward compatibility.
|
||||
template <>
|
||||
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
|
||||
float fmha_bwd<2>(const fmha_bwd_traits& t, fmha_bwd_args a, const ck_tile::stream_config& s){{
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
@@ -225,15 +216,25 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, if_i=0) -> str:
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
FMHA_BWD_API_INNER_DISPATCH_COMMON = """{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
|
||||
"""
|
||||
FMHA_BWD_API_INNER_DISPATCH_RUN = """
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{
|
||||
return fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
|
||||
}};
|
||||
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t.max_seqlen_k);
|
||||
return;
|
||||
}}
|
||||
"""
|
||||
|
||||
# M0 size for 1d kernels (dot/convert)
|
||||
M0_1D = 64
|
||||
@@ -795,35 +796,35 @@ class FmhaBwdApiTrait:
|
||||
if self.mode == "group":
|
||||
return "true /*spad1d is always true in group mode*/"
|
||||
elif self.spad1d == "t":
|
||||
return f"true /*a.seqlen_q % {M0_1D} != 0*/"
|
||||
return f"true /*t.seqlen_q % {M0_1D} != 0*/"
|
||||
else: # self.spad1d == "f"
|
||||
return f"a.seqlen_q % {M0_1D} == 0"
|
||||
return f"t.seqlen_q % {M0_1D} == 0"
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.dpad == 0:
|
||||
return f"a.hdim_q % {self.bhdq} == 0"
|
||||
return f"t.hdim_q % {self.bhdq} == 0"
|
||||
else:
|
||||
return f"a.hdim_q % {self.dpad} == 0"
|
||||
return f"t.hdim_q % {self.dpad} == 0"
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.dvpad == 0:
|
||||
return f"a.hdim_v % {self.bhdv} == 0"
|
||||
return f"t.hdim_v % {self.bhdv} == 0"
|
||||
else:
|
||||
return f"a.hdim_v % {self.dvpad} == 0"
|
||||
return f"t.hdim_v % {self.dvpad} == 0"
|
||||
|
||||
@property
|
||||
def max_seq_q_cond(self) -> str:
|
||||
if self.tile.max_seq_q != 0:
|
||||
return f" && (a.seqlen_q <= {self.tile.max_seq_q})"
|
||||
return f" && (t.seqlen_q <= {self.tile.max_seq_q})"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def extra_cond(self) -> str:
|
||||
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
|
||||
return " && (a.seqlen_k <= 256)"
|
||||
return " && (t.seqlen_k <= 256)"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@@ -910,12 +911,12 @@ class FmhaBwdApiPool:
|
||||
check_duplicates_and_paddings(ts, trait)
|
||||
ts.append(copy.copy(trait))
|
||||
|
||||
def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> str:
|
||||
def _api_inners(self, traits: List[FmhaBwdApiTrait]) -> tuple[str, str]:
|
||||
inners = ""
|
||||
inners_launcher = ""
|
||||
for i_trait, trait in enumerate(traits):
|
||||
inners += FMHA_BWD_API_INNER_DISPATCH.format(
|
||||
inners_common = FMHA_BWD_API_INNER_DISPATCH_COMMON.format(
|
||||
F_if=if_(i_trait),
|
||||
F_arch=trait.arch,
|
||||
F_mode=MODE_MAP[trait.mode],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
@@ -935,13 +936,20 @@ class FmhaBwdApiPool:
|
||||
F_deterministic=BOOL_MAP[trait.deterministic],
|
||||
F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_maxq=trait.tile.max_seq_q,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
|
||||
F_max_seq_q_cond=trait.max_seq_q_cond,
|
||||
F_cond_extra=trait.extra_cond,
|
||||
F_bn0=trait.tile.F_bn0,
|
||||
F_convert_dq_bn0=trait.convert_dq_bn0,
|
||||
)
|
||||
return inners
|
||||
inners += inners_common + FMHA_BWD_API_INNER_DISPATCH_RUN.format(
|
||||
F_arch=trait.arch,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
|
||||
)
|
||||
inners_launcher += inners_common + FMHA_BWD_API_INNER_DISPATCH_LAUNCHER.format(
|
||||
F_arch=trait.arch,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled],
|
||||
)
|
||||
return inners, inners_launcher
|
||||
|
||||
@staticmethod
|
||||
def max_seq_q_sort_key(trait):
|
||||
@@ -957,8 +965,7 @@ class FmhaBwdApiPool:
|
||||
def hdim_cond(hdim: int) -> str:
|
||||
return f"t.hdim_q <= {hdim} && t.hdim_v <= {hdim}"
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
def _api_per_arch(self, variant) -> str:
|
||||
per_arch = ""
|
||||
for i_arch, (arch, pool_by_arch) in enumerate(self.dq_dk_dv_pool.items()):
|
||||
per_dtypes = ""
|
||||
@@ -968,7 +975,7 @@ class FmhaBwdApiPool:
|
||||
traits = sorted(pool_by_hdim, key=self.max_seq_q_sort_key)
|
||||
inners = self._api_inners(traits)
|
||||
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(
|
||||
if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners
|
||||
if_i=i_hdim, F_cond=self.hdim_cond(hdim), F_body=inners[variant]
|
||||
)
|
||||
per_dtypes += FMHA_BWD_API_COND_STATEMENT(
|
||||
if_i=i_dtype, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case
|
||||
@@ -978,9 +985,13 @@ class FmhaBwdApiPool:
|
||||
)
|
||||
if not per_arch:
|
||||
# per_arch is empty: emit (void) casts so the generated API does not warn about unused parameters
|
||||
per_arch = "(void)t; (void)s; (void)a;"
|
||||
per_arch = ("(void)t; (void)s; (void)a;", "(void)t;")[variant]
|
||||
return per_arch
|
||||
@property
|
||||
def api(self) -> str:
|
||||
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(
|
||||
F_dispatch=indent(per_arch)
|
||||
F_dispatch=indent(self._api_per_arch(0)),
|
||||
F_launcher=indent(self._api_per_arch(1)),
|
||||
)
|
||||
return result.replace("\n\n", "\n")
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
|
||||
struct FmhaBwdFp32
|
||||
{
|
||||
@@ -463,6 +465,8 @@ template <typename Traits_, typename Arch = void>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_();
|
||||
template <typename Traits_, typename Arch = void>
|
||||
int fmha_bwd_dq_dk_dv_maxq_();
|
||||
template <typename Traits_, typename Arch = void>
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_(ck_tile::index_t seqlen_k);
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
@@ -503,11 +507,18 @@ void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
std::string fmha_bwd_convert_dq_get_name_();
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
// Traits that are used to dispatch different kernel implementations for fmha backward
|
||||
struct fmha_bwd_traits
|
||||
{
|
||||
int seqlen_q;
|
||||
int seqlen_k;
|
||||
int batch;
|
||||
int max_seqlen_q;
|
||||
int max_seqlen_k;
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
int nhead_q;
|
||||
int nhead_k;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
mask_enum mask_type;
|
||||
@@ -518,5 +529,52 @@ struct fmha_bwd_traits
|
||||
bool is_deterministic;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
// Launches the FMHA backward kernel sequence for one trait combination.
//
// Template parameters:
//   T0   - dot_do_o trait
//   T1   - dq_dk_dv trait
//   T2   - convert_dq trait, or `void` to skip the convert_dq step entirely
//   Arch - architecture tag forwarded to every per-kernel helper
//
// Parameters:
//   s - stream configuration; `s.log_level_ > 0` enables printing of the
//       selected kernel names to std::cout
//   a - backward arguments, taken by value and copied into each launch lambda
//
// Returns whatever ck_tile::launch_kernel returns for the launched callables.
template <typename T0 /*dot_do_o_trait*/,
|
||||
typename T1 /*dq_dk_dv_trait*/,
|
||||
typename T2 /*convert_dq_trait*/,
|
||||
typename Arch>
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
// Compile-time branch: three-kernel path when a convert_dq trait is supplied.
if constexpr(!std::is_same_v<T2, void>)
|
||||
{
|
||||
// NOTE: the names are logged as dot_do_o @ convert_dq @ dq_dk_dv, which is
// not the order the callables are passed to launch_kernel below
// (dot_do_o, dq_dk_dv, convert_dq).
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<T0, Arch>() << "@"
|
||||
<< fmha_bwd_convert_dq_get_name_<T2, Arch>() << "@"
|
||||
<< fmha_bwd_dq_dk_dv_get_name_<T1, Arch>() << std::flush;
|
||||
// Each lambda captures `a` by copy ([=]) and receives the stream config
// from launch_kernel as `s_`.
return ck_tile::launch_kernel(
|
||||
s,
|
||||
[=](const ck_tile::stream_config& s_) { fmha_bwd_dot_do_o_oneshot_<T0, Arch>(s_, a); },
|
||||
[=](const ck_tile::stream_config& s_) { fmha_bwd_dq_dk_dv_oneshot_<T1, Arch>(s_, a); },
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
fmha_bwd_convert_dq_oneshot_<T2, Arch>(s_, a);
|
||||
});
|
||||
}
|
||||
// T2 is void: only dot_do_o and dq_dk_dv are logged and launched.
else
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<T0, Arch>() << "@"
|
||||
<< fmha_bwd_dq_dk_dv_get_name_<T1, Arch>() << std::flush;
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
[=](const ck_tile::stream_config& s_) { fmha_bwd_dot_do_o_oneshot_<T0, Arch>(s_, a); },
|
||||
[=](const ck_tile::stream_config& s_) { fmha_bwd_dq_dk_dv_oneshot_<T1, Arch>(s_, a); });
|
||||
}
|
||||
}
|
||||
|
||||
template <int Version = 2>
|
||||
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
float fmha_bwd(const fmha_bwd_traits&, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
|
||||
// Callable wrapper for an FMHA backward launch selected from fmha_bwd_traits.
//
// The constructor is only declared here; its definition is not part of this
// struct (NOTE(review): it appears to be emitted by the code generator into
// the generated API source - confirm against the build).
struct fmha_bwd_launcher
|
||||
{
|
||||
// Type-erased launch routine invoked by operator(); default-constructed
// (empty) until something assigns it.
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
|
||||
// Number of dQ accumulator splits; 0 until assigned.
// NOTE(review): presumably filled in by the constructor - confirm.
ck_tile::index_t dq_acc_splits{0};
|
||||
|
||||
// Builds the launcher from the given dispatch traits (defined out of line).
fmha_bwd_launcher(const fmha_bwd_traits&);
|
||||
|
||||
// Perfect-forwards all arguments to `run` and returns its result.
// Calling this while `run` is empty throws std::bad_function_call.
template <typename... Args>
|
||||
float operator()(Args&&... args) const
|
||||
{
|
||||
return run(std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -56,8 +56,6 @@ auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
extern template float fmha_bwd<2>(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
@@ -243,12 +241,29 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back());
|
||||
// Keep it equal to or smaller than minimal bn0 of all tiles in fmha_bwd.py
|
||||
// TODO: add API for requesting kN0/nsplits/workspace_size? It is not safe to rely on internal
|
||||
// implementation details in client code.
|
||||
const ck_tile::index_t kN0 = 16;
|
||||
const ck_tile::index_t nsplits =
|
||||
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
|
||||
|
||||
const fmha_bwd_traits fmha_traits{
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
};
|
||||
fmha_bwd_launcher launcher(fmha_traits);
|
||||
|
||||
const ck_tile::index_t nsplits = launcher.dq_acc_splits;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
@@ -406,17 +421,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
: "")
|
||||
<< ", mask:" << mask << std::flush;
|
||||
|
||||
auto fmha_traits = fmha_bwd_traits{hdim_q,
|
||||
hdim_v,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic};
|
||||
auto fmha_args = [&]() {
|
||||
auto fmha_args = [&]() {
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
|
||||
/// 'nhead_stride_bias' are 0.
|
||||
@@ -478,7 +483,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
do_buf.GetDeviceBuffer(),
|
||||
@@ -509,7 +514,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
stride_k,
|
||||
stride_v,
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
|
||||
: stride_bias,
|
||||
: stride_bias,
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
@@ -553,7 +558,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
drop_seed_offset};
|
||||
}();
|
||||
|
||||
const float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
|
||||
const float ave_time = launcher(fmha_args, stream_config);
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
@@ -844,7 +849,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dq_acc_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
|
||||
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
|
||||
launcher(fmha_args, stream_config_v);
|
||||
|
||||
dq_buf.FromDevice(dq_host.data());
|
||||
dk_buf.FromDevice(dk_host.data());
|
||||
|
||||
Reference in New Issue
Block a user