[CK_TILE] Use Unified Workspace for FMHA BWD

This commit is contained in:
Ding, Yi
2026-04-03 02:10:43 -05:00
parent 2510e7b238
commit 28afc8fee3
5 changed files with 527 additions and 324 deletions

View File

@@ -169,17 +169,29 @@ int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
}}
template <>
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const fmha_bwd_traits& t)
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(int batch_size)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::GetDqAccSplits(t.batch, t.nhead_q, t.max_seqlen_k);
return k_::GetWorkspaceHostSize(batch_size);
}}
template <>
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q,
ck_tile::index_t nhead_q, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k,
const ck_tile::index_t* seqstart_qs, const ck_tile::index_t* seqstart_ks)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::NeedsZeroDqAcc();
return k_::PrepareWorkspaceHost(
cpu_ws, batch_size, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_qs, seqstart_ks);
}}
template <>
void fmha_bwd_dq_dk_dv_dq_prepare_ws_device_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
void* device_ws, const void* host_ws, size_t device_ws_size, size_t host_ws_size)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
k_::PrepareWorkspaceDevice(device_ws, host_ws, device_ws_size, host_ws_size);
}}
template <>
@@ -197,9 +209,6 @@ FMHA_BWD_API = """
fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_launcher}
run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }};
dq_acc_splits = 1;
needs_zero_dq_acc = false;
}}
@@ -228,7 +237,7 @@ FMHA_BWD_API_INNER_DISPATCH_COMMON = """{F_if}((t.is_group_mode == {F_mode}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}>;
"""
FMHA_BWD_API_INNER_DISPATCH_RUN = """
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
@@ -236,11 +245,7 @@ FMHA_BWD_API_INNER_DISPATCH_RUN = """
}}
"""
FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{
return fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
}};
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t);
needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_, {F_arch.tag}>();
this->init<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(t);
return;
}}
"""
@@ -649,7 +654,6 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
/* BlockSize = */ 256,
{F_bm0},
{F_bn0},
{F_hdim},
{F_mode},
{F_deterministic},
@@ -666,8 +670,7 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic},
{F_bn0}>;
{F_deterministic}>;
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
@@ -711,7 +714,6 @@ class FmhaBwdConvertQGradKernel:
F_hdim: int # hdim
F_dtype: str # data type
F_bm0: int # tile size along q seqlen (block size)
F_bn0: int # tile size along k seqlen
F_spad: str # true/false
F_dpad: str #
F_mode: str # value from MODE_MAP
@@ -727,7 +729,6 @@ class FmhaBwdConvertQGradKernel:
F_hdim=self.F_hdim,
F_dtype=BWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_bm0,
F_bn0=self.F_bn0,
F_spad=BOOL_MAP[self.F_spad],
F_dpad=BOOL_MAP[self.F_dpad],
F_mode=MODE_MAP[self.F_mode],
@@ -748,7 +749,7 @@ class FmhaBwdConvertQGradKernel:
return n
pn = pad_name()
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}_{self.F_mode}_o{self.F_occupancy}"
if pn != "":
n += f"_{pn}"
else:
@@ -837,10 +838,6 @@ class FmhaBwdApiTrait:
else:
return ""
@property
def convert_dq_bn0(self) -> int:
return self.tile.F_bn0 if self.deterministic == "t" else 0
@property
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
@@ -895,7 +892,6 @@ class FmhaBwdApiTrait:
F_hdim=self.hdim,
F_dtype=self.dtype,
F_bm0=M0_1D,
F_bn0=self.convert_dq_bn0,
F_spad=self.spad1d,
F_dpad=F_dpad,
F_mode=self.mode,
@@ -948,7 +944,6 @@ class FmhaBwdApiPool:
F_max_seq_q_cond=trait.max_seq_q_cond,
F_cond_extra=trait.extra_cond,
F_bn0=trait.tile.F_bn0,
F_convert_dq_bn0=trait.convert_dq_bn0,
)
inners += inners_common + FMHA_BWD_API_INNER_DISPATCH_RUN.format(
F_arch=trait.arch,

View File

@@ -11,11 +11,12 @@
#include "mask.hpp"
#include "bias.hpp"
#include <functional>
#include <iostream>
#include <memory>
#include <type_traits>
#include <utility>
#include <variant>
#include <iostream>
#include <functional>
struct FmhaBwdFp32
{
@@ -115,7 +116,7 @@ struct fmha_bwd_args
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
void* workspace_ptr;
// Usage notes for sequence length pointer parameters:
//
@@ -125,13 +126,13 @@ struct fmha_bwd_args
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
@@ -178,7 +179,6 @@ struct fmha_bwd_args
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
@@ -191,7 +191,6 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::long_index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
@@ -204,12 +203,10 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::long_index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
@@ -224,12 +221,6 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr;
const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq;
const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
@@ -241,10 +232,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
dq_ptr,
args.workspace_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
@@ -263,7 +255,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
stride_dq,
args.stride_dq,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
@@ -274,11 +266,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
nhead_stride_dq,
args.nhead_stride_dq,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
@@ -295,10 +286,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
dq_ptr,
args.workspace_ptr,
args.seqlen_q,
args.seqlen_k,
args.batch,
@@ -313,7 +305,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
stride_dq,
args.stride_dq,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
@@ -324,7 +316,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
nhead_stride_dq,
args.nhead_stride_dq,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
@@ -335,11 +327,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
batch_stride_dq,
args.batch_stride_dq,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
@@ -403,8 +394,10 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
// create group mode kernel arguments
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
{
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
return FmhaBwdConvertQGradKernel::MakeKargs(args.workspace_ptr,
args.dq_ptr,
args.batch,
args.nhead_q,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
@@ -413,27 +406,20 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.cu_seqlen_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc);
args.nhead_stride_dq);
}
else
{ // create batch mode kernel arguments
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
return FmhaBwdConvertQGradKernel::MakeKargs(args.workspace_ptr,
args.dq_ptr,
args.batch,
args.nhead_q,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc,
args.batch,
args.nhead_q);
args.batch_stride_dq);
}
}();
@@ -471,9 +457,21 @@ template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_maxq_();
struct fmha_bwd_traits;
template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t);
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size);
template <typename Traits_, typename Arch = void>
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_();
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
ck_tile::index_t batch_size,
ck_tile::index_t hdim_q,
ck_tile::index_t nhead_q,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
const ck_tile::index_t* seqstart_qs,
const ck_tile::index_t* seqstart_ks);
template <typename Traits_, typename Arch = void>
void fmha_bwd_dq_dk_dv_dq_prepare_ws_device_(void* device_ws,
const void* host_ws,
size_t device_ws_size,
size_t host_ws_size);
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
@@ -499,8 +497,7 @@ template <ck_tile::index_t HDim_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_,
ck_tile::index_t kN0>
bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
};
@@ -534,6 +531,8 @@ struct fmha_bwd_traits
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
const int* seqstart_qs = nullptr;
const int* seqstart_ks = nullptr;
// TODO: padding check is inside this api
};
@@ -574,11 +573,55 @@ float fmha_bwd(const fmha_bwd_traits&, fmha_bwd_args, const ck_tile::stream_conf
struct fmha_bwd_launcher
{
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
ck_tile::index_t dq_acc_splits{0};
bool needs_zero_dq_acc{true};
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{
[](fmha_bwd_args, const ck_tile::stream_config&) {
std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n";
return -1.0f;
}};
size_t workspace_size = 0;
std::function<void(void*)> prepare_workspace{[](void*) {
std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n";
}};
fmha_bwd_launcher(const fmha_bwd_traits&);
fmha_bwd_launcher(fmha_bwd_launcher&&) = delete;
fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete;
private:
size_t host_ws_size = 0;
size_t device_ws_size = 0;
std::unique_ptr<char[]> ws_host;
public:
template <typename T0 /*dot_do_o_trait*/,
typename T1 /*dq_dk_dv_trait*/,
typename T2 /*convert_dq_trait*/,
typename Arch>
void init(const fmha_bwd_traits& traits)
{
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
return fmha_bwd_<T0, T1, T2, Arch>(s, a);
};
host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(traits.batch);
if(host_ws_size > 0)
{
ws_host = std::make_unique<char[]>(host_ws_size); // TODO: support host mem allocator
device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>( //
ws_host.get(),
traits.batch,
traits.hdim_q,
traits.nhead_q,
traits.seqlen_q,
traits.seqlen_k,
traits.seqstart_qs,
traits.seqstart_ks);
}
workspace_size = host_ws_size + device_ws_size;
prepare_workspace = [this](void* device_ws) {
fmha_bwd_dq_dk_dv_dq_prepare_ws_device_<T1, Arch>(
device_ws, ws_host.get(), device_ws_size, host_ws_size);
};
}
template <typename... Args>
float operator()(Args&&... args) const

View File

@@ -260,10 +260,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
p_drop > 0.0f,
s_randval,
deterministic,
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
};
fmha_bwd_launcher launcher(fmha_traits);
const ck_tile::index_t nsplits = launcher.dq_acc_splits;
const size_t ws_size = launcher.workspace_size;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
@@ -301,8 +302,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
if(init_method == "ui" || init_method == "0")
{
@@ -377,7 +376,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem ws_buf(ws_size);
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
@@ -409,17 +409,14 @@ bwd_result fmha_bwd_run(mode_enum mode,
// clang-format on
const std::size_t workspace_size_in_megabytes =
ck_tile::integer_divide_ceil(dq_acc_host.get_element_space_size_in_bytes(), 1024 * 1024);
ck_tile::integer_divide_ceil(ws_size, 1024 * 1024);
std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm)
<< "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
<< "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
<< ", s_randval:" << s_randval << ", deterministic:" << deterministic
<< (deterministic
? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) +
"MiB|" + std::to_string(nsplits) + "splits"
: "")
<< ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
<< ", mask:" << mask << std::flush;
auto fmha_args = [&]() {
@@ -437,7 +434,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
const auto split_stride_dq_acc = (shape_seqlen_q * hdim_q);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
@@ -449,8 +445,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
const auto nhead_stride_dq_acc =
static_cast<ck_tile::long_index_t>(split_stride_dq_acc) * nsplits;
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
@@ -463,7 +457,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const auto batch_stride_dq_acc = nhead * nhead_stride_dq_acc;
void* ws_ptr = ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr;
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
if(drop_prefs)
@@ -494,7 +489,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
dk_buf.GetDeviceBuffer(),
dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
ws_ptr,
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
seqlen_q_ptr_dev,
@@ -519,8 +514,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
stride_o,
stride_randval,
stride_do,
hdim_q, // stride_dq_acc
stride_q, // stride_dq
stride_q, // stride_dq (same layout as q for dq output)
stride_dk,
stride_dv,
stride_dbias,
@@ -532,7 +526,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
@@ -545,12 +538,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_q, // batch_stride_dq
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
@@ -833,19 +824,16 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
ck_tile::FillConstant<AccDataType>{ck_tile::numeric<AccDataType>::infinity()}(dq_acc_host);
dq_buf.ToDevice(dq_host.data());
dk_buf.ToDevice(dk_host.data());
dv_buf.ToDevice(dv_host.data());
dq_acc_buf.ToDevice(dq_acc_host.data());
// re-initialize workspace for validation run
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
dbias_buf.SetZero();
if(launcher.needs_zero_dq_acc)
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
launcher(fmha_args, stream_config_v);

File diff suppressed because it is too large Load Diff

View File

@@ -95,7 +95,6 @@ template <typename AccDataType_,
typename QGradDataType_,
index_t kBlockSize_,
index_t kM0_,
index_t kN0_,
index_t kQKHeaddim_,
bool kIsGroupMode_,
bool kIsDeterministic_,
@@ -111,7 +110,6 @@ struct BlockFmhaBwdConvertQGradPipelineProblem
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kQKHeaddim = kQKHeaddim_;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;