mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Use Unified Workspace for FMHA BWD
This commit is contained in:
@@ -169,17 +169,29 @@ int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
}}
|
||||
|
||||
template <>
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const fmha_bwd_traits& t)
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(int batch_size)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::GetDqAccSplits(t.batch, t.nhead_q, t.max_seqlen_k);
|
||||
return k_::GetWorkspaceHostSize(batch_size);
|
||||
}}
|
||||
|
||||
template <>
|
||||
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
|
||||
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
void* cpu_ws, ck_tile::index_t batch_size, ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t nhead_q, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k,
|
||||
const ck_tile::index_t* seqstart_qs, const ck_tile::index_t* seqstart_ks)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::NeedsZeroDqAcc();
|
||||
return k_::PrepareWorkspaceHost(
|
||||
cpu_ws, batch_size, hdim_q, nhead_q, seqlen_q, seqlen_k, seqstart_qs, seqstart_ks);
|
||||
}}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_dq_prepare_ws_device_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(
|
||||
void* device_ws, const void* host_ws, size_t device_ws_size, size_t host_ws_size)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
k_::PrepareWorkspaceDevice(device_ws, host_ws, device_ws_size, host_ws_size);
|
||||
}}
|
||||
|
||||
template <>
|
||||
@@ -197,9 +209,6 @@ FMHA_BWD_API = """
|
||||
fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{
|
||||
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
|
||||
{F_launcher}
|
||||
run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }};
|
||||
dq_acc_splits = 1;
|
||||
needs_zero_dq_acc = false;
|
||||
}}
|
||||
|
||||
|
||||
@@ -228,7 +237,7 @@ FMHA_BWD_API_INNER_DISPATCH_COMMON = """{F_if}((t.is_group_mode == {F_mode}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_max_seq_q_cond}{F_cond_extra}) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}>;
|
||||
"""
|
||||
FMHA_BWD_API_INNER_DISPATCH_RUN = """
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
|
||||
@@ -236,11 +245,7 @@ FMHA_BWD_API_INNER_DISPATCH_RUN = """
|
||||
}}
|
||||
"""
|
||||
FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{
|
||||
return fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
|
||||
}};
|
||||
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t);
|
||||
needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_, {F_arch.tag}>();
|
||||
this->init<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(t);
|
||||
return;
|
||||
}}
|
||||
"""
|
||||
@@ -649,7 +654,6 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
{F_bm0},
|
||||
{F_bn0},
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
{F_deterministic},
|
||||
@@ -666,8 +670,7 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
|
||||
{F_mode},
|
||||
{F_spad},
|
||||
{F_dpad},
|
||||
{F_deterministic},
|
||||
{F_bn0}>;
|
||||
{F_deterministic}>;
|
||||
|
||||
template <>
|
||||
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
@@ -711,7 +714,6 @@ class FmhaBwdConvertQGradKernel:
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_spad: str # true/false
|
||||
F_dpad: str #
|
||||
F_mode: str # value from MODE_MAP
|
||||
@@ -727,7 +729,6 @@ class FmhaBwdConvertQGradKernel:
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=BWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_bm0,
|
||||
F_bn0=self.F_bn0,
|
||||
F_spad=BOOL_MAP[self.F_spad],
|
||||
F_dpad=BOOL_MAP[self.F_dpad],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
@@ -748,7 +749,7 @@ class FmhaBwdConvertQGradKernel:
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
|
||||
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}_{self.F_mode}_o{self.F_occupancy}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
@@ -837,10 +838,6 @@ class FmhaBwdApiTrait:
|
||||
else:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def convert_dq_bn0(self) -> int:
|
||||
return self.tile.F_bn0 if self.deterministic == "t" else 0
|
||||
|
||||
@property
|
||||
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
@@ -895,7 +892,6 @@ class FmhaBwdApiTrait:
|
||||
F_hdim=self.hdim,
|
||||
F_dtype=self.dtype,
|
||||
F_bm0=M0_1D,
|
||||
F_bn0=self.convert_dq_bn0,
|
||||
F_spad=self.spad1d,
|
||||
F_dpad=F_dpad,
|
||||
F_mode=self.mode,
|
||||
@@ -948,7 +944,6 @@ class FmhaBwdApiPool:
|
||||
F_max_seq_q_cond=trait.max_seq_q_cond,
|
||||
F_cond_extra=trait.extra_cond,
|
||||
F_bn0=trait.tile.F_bn0,
|
||||
F_convert_dq_bn0=trait.convert_dq_bn0,
|
||||
)
|
||||
inners += inners_common + FMHA_BWD_API_INNER_DISPATCH_RUN.format(
|
||||
F_arch=trait.arch,
|
||||
|
||||
@@ -11,11 +11,12 @@
|
||||
#include "mask.hpp"
|
||||
#include "bias.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
|
||||
struct FmhaBwdFp32
|
||||
{
|
||||
@@ -115,7 +116,7 @@ struct fmha_bwd_args
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
void* dq_acc_ptr;
|
||||
void* workspace_ptr;
|
||||
|
||||
// Usage notes for sequence length pointer parameters:
|
||||
//
|
||||
@@ -125,13 +126,13 @@ struct fmha_bwd_args
|
||||
// With padding:
|
||||
// Group mode:
|
||||
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
|
||||
// lengths. [array size: batch + 1]
|
||||
// lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
|
||||
// sequence. [array size: batch]
|
||||
// sequence. [array size: batch]
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// sequence lengths. [array size: batch + 1]
|
||||
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
|
||||
// exclusive. Use one set, not both.
|
||||
// exclusive. Use one set, not both.
|
||||
//
|
||||
// Batch mode:
|
||||
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
|
||||
@@ -178,7 +179,6 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_dq_acc;
|
||||
ck_tile::index_t stride_dq;
|
||||
ck_tile::index_t stride_dk;
|
||||
ck_tile::index_t stride_dv;
|
||||
@@ -191,7 +191,6 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::long_index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
@@ -204,12 +203,10 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::long_index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
ck_tile::index_t batch_stride_dbias;
|
||||
ck_tile::index_t split_stride_dq_acc;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
@@ -224,12 +221,6 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
|
||||
const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr;
|
||||
const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq;
|
||||
const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
|
||||
const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;
|
||||
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
@@ -241,10 +232,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dq_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
dq_ptr,
|
||||
args.workspace_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
@@ -263,7 +255,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
stride_dq,
|
||||
args.stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -274,11 +266,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
@@ -295,10 +286,11 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dq_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
dq_ptr,
|
||||
args.workspace_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.batch,
|
||||
@@ -313,7 +305,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
stride_dq,
|
||||
args.stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -324,7 +316,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
@@ -335,11 +327,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
batch_stride_dq,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
@@ -403,8 +394,10 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.workspace_ptr,
|
||||
args.dq_ptr,
|
||||
args.batch,
|
||||
args.nhead_q,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_q_ptr,
|
||||
@@ -413,27 +406,20 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
args.nhead_stride_dq);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.workspace_ptr,
|
||||
args.dq_ptr,
|
||||
args.batch,
|
||||
args.nhead_q,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dq_acc,
|
||||
args.split_stride_dq_acc,
|
||||
args.batch,
|
||||
args.nhead_q);
|
||||
args.batch_stride_dq);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -471,9 +457,21 @@ template <typename Traits_, typename Arch = void>
|
||||
int fmha_bwd_dq_dk_dv_maxq_();
|
||||
struct fmha_bwd_traits;
|
||||
template <typename Traits_, typename Arch = void>
|
||||
int fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t);
|
||||
size_t fmha_bwd_dq_dk_dv_dq_ws_host_size_(int batch_size);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_();
|
||||
size_t fmha_bwd_dq_dk_dv_dq_prepare_ws_host_(void* cpu_ws,
|
||||
ck_tile::index_t batch_size,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t nhead_q,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
const ck_tile::index_t* seqstart_qs,
|
||||
const ck_tile::index_t* seqstart_ks);
|
||||
template <typename Traits_, typename Arch = void>
|
||||
void fmha_bwd_dq_dk_dv_dq_prepare_ws_device_(void* device_ws,
|
||||
const void* host_ws,
|
||||
size_t device_ws_size,
|
||||
size_t host_ws_size);
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
@@ -499,8 +497,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kIsGroupMode_,
|
||||
bool kPadS_,
|
||||
bool kPadD_,
|
||||
bool kIsDeterministic_,
|
||||
ck_tile::index_t kN0>
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_convert_dq_traits_
|
||||
{
|
||||
};
|
||||
@@ -534,6 +531,8 @@ struct fmha_bwd_traits
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
const int* seqstart_qs = nullptr;
|
||||
const int* seqstart_ks = nullptr;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
|
||||
@@ -574,11 +573,55 @@ float fmha_bwd(const fmha_bwd_traits&, fmha_bwd_args, const ck_tile::stream_conf
|
||||
|
||||
struct fmha_bwd_launcher
|
||||
{
|
||||
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
|
||||
ck_tile::index_t dq_acc_splits{0};
|
||||
bool needs_zero_dq_acc{true};
|
||||
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{
|
||||
[](fmha_bwd_args, const ck_tile::stream_config&) {
|
||||
std::cerr << "fmha_bwd: no kernel found for given traits, skipping run\n";
|
||||
return -1.0f;
|
||||
}};
|
||||
size_t workspace_size = 0;
|
||||
std::function<void(void*)> prepare_workspace{[](void*) {
|
||||
std::cerr << "fmha_bwd: no kernel found for given traits, skipping prepare_workspace\n";
|
||||
}};
|
||||
|
||||
fmha_bwd_launcher(const fmha_bwd_traits&);
|
||||
fmha_bwd_launcher(fmha_bwd_launcher&&) = delete;
|
||||
fmha_bwd_launcher& operator=(fmha_bwd_launcher&&) = delete;
|
||||
|
||||
private:
|
||||
size_t host_ws_size = 0;
|
||||
size_t device_ws_size = 0;
|
||||
std::unique_ptr<char[]> ws_host;
|
||||
|
||||
public:
|
||||
template <typename T0 /*dot_do_o_trait*/,
|
||||
typename T1 /*dq_dk_dv_trait*/,
|
||||
typename T2 /*convert_dq_trait*/,
|
||||
typename Arch>
|
||||
void init(const fmha_bwd_traits& traits)
|
||||
{
|
||||
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {
|
||||
return fmha_bwd_<T0, T1, T2, Arch>(s, a);
|
||||
};
|
||||
host_ws_size = fmha_bwd_dq_dk_dv_dq_ws_host_size_<T1, Arch>(traits.batch);
|
||||
if(host_ws_size > 0)
|
||||
{
|
||||
ws_host = std::make_unique<char[]>(host_ws_size); // TODO: support host mem allocator
|
||||
device_ws_size = fmha_bwd_dq_dk_dv_dq_prepare_ws_host_<T1, Arch>( //
|
||||
ws_host.get(),
|
||||
traits.batch,
|
||||
traits.hdim_q,
|
||||
traits.nhead_q,
|
||||
traits.seqlen_q,
|
||||
traits.seqlen_k,
|
||||
traits.seqstart_qs,
|
||||
traits.seqstart_ks);
|
||||
}
|
||||
workspace_size = host_ws_size + device_ws_size;
|
||||
prepare_workspace = [this](void* device_ws) {
|
||||
fmha_bwd_dq_dk_dv_dq_prepare_ws_device_<T1, Arch>(
|
||||
device_ws, ws_host.get(), device_ws_size, host_ws_size);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
float operator()(Args&&... args) const
|
||||
|
||||
@@ -260,10 +260,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic,
|
||||
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
|
||||
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
|
||||
};
|
||||
fmha_bwd_launcher launcher(fmha_traits);
|
||||
|
||||
const ck_tile::index_t nsplits = launcher.dq_acc_splits;
|
||||
const size_t ws_size = launcher.workspace_size;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
@@ -301,8 +302,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
use_dbias
|
||||
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<AccDataType> dq_acc_host(
|
||||
std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
|
||||
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
@@ -377,7 +376,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem ws_buf(ws_size);
|
||||
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -409,17 +409,14 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
// clang-format on
|
||||
|
||||
const std::size_t workspace_size_in_megabytes =
|
||||
ck_tile::integer_divide_ceil(dq_acc_host.get_element_space_size_in_bytes(), 1024 * 1024);
|
||||
ck_tile::integer_divide_ceil(ws_size, 1024 * 1024);
|
||||
|
||||
std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm)
|
||||
<< "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
|
||||
<< "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
|
||||
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
|
||||
<< ", s_randval:" << s_randval << ", deterministic:" << deterministic
|
||||
<< (deterministic
|
||||
? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) +
|
||||
"MiB|" + std::to_string(nsplits) + "splits"
|
||||
: "")
|
||||
<< ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
|
||||
<< ", mask:" << mask << std::flush;
|
||||
|
||||
auto fmha_args = [&]() {
|
||||
@@ -437,7 +434,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
|
||||
const auto split_stride_dq_acc = (shape_seqlen_q * hdim_q);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
@@ -449,8 +445,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_dbias =
|
||||
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
|
||||
const auto nhead_stride_dq_acc =
|
||||
static_cast<ck_tile::long_index_t>(split_stride_dq_acc) * nsplits;
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
@@ -463,7 +457,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
|
||||
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const auto batch_stride_dq_acc = nhead * nhead_stride_dq_acc;
|
||||
|
||||
void* ws_ptr = ws_size > 0 ? ws_buf.GetDeviceBuffer() : nullptr;
|
||||
|
||||
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
|
||||
if(drop_prefs)
|
||||
@@ -494,7 +489,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dk_buf.GetDeviceBuffer(),
|
||||
dv_buf.GetDeviceBuffer(),
|
||||
dbias_buf.GetDeviceBuffer(),
|
||||
dq_acc_buf.GetDeviceBuffer(),
|
||||
ws_ptr,
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
seqlen_q_ptr_dev,
|
||||
@@ -519,8 +514,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
hdim_q, // stride_dq_acc
|
||||
stride_q, // stride_dq
|
||||
stride_q, // stride_dq (same layout as q for dq output)
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
@@ -532,7 +526,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_q, // nhead_stride_dq
|
||||
nhead_stride_k, // nhead_stride_dk
|
||||
nhead_stride_v, // nhead_stride_dv
|
||||
@@ -545,12 +538,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_dq_acc,
|
||||
batch_stride_q, // batch_stride_dq
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
@@ -833,19 +824,16 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
|
||||
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
|
||||
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
|
||||
ck_tile::FillConstant<AccDataType>{ck_tile::numeric<AccDataType>::infinity()}(dq_acc_host);
|
||||
dq_buf.ToDevice(dq_host.data());
|
||||
dk_buf.ToDevice(dk_host.data());
|
||||
dv_buf.ToDevice(dv_host.data());
|
||||
dq_acc_buf.ToDevice(dq_acc_host.data());
|
||||
// re-initialize workspace for validation run
|
||||
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
|
||||
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dbias_buf.SetZero();
|
||||
|
||||
if(launcher.needs_zero_dq_acc)
|
||||
dq_acc_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};
|
||||
launcher(fmha_args, stream_config_v);
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -95,7 +95,6 @@ template <typename AccDataType_,
|
||||
typename QGradDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kM0_,
|
||||
index_t kN0_,
|
||||
index_t kQKHeaddim_,
|
||||
bool kIsGroupMode_,
|
||||
bool kIsDeterministic_,
|
||||
@@ -111,7 +110,6 @@ struct BlockFmhaBwdConvertQGradPipelineProblem
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kM0 = kM0_;
|
||||
static constexpr index_t kN0 = kN0_;
|
||||
static constexpr index_t kQKHeaddim = kQKHeaddim_;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
|
||||
Reference in New Issue
Block a user