mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] FMHA BWD Optimization For GFX950 (#2628)
* simplify fmha_bwd_kernel MakeKargs & dq_dram_window * simply duplicate * trload pipeline * Try two-stage * add prefetch * optimize & iglp
This commit is contained in:
@@ -62,6 +62,12 @@ struct FmhaBwdDQDKDVKernel
|
||||
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
|
||||
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
|
||||
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
|
||||
// Mirrors FmhaPipeline::kUseTrLoad; also selects the "_trload"/"_ntrload" suffix of the kernel name.
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
|
||||
// kIsAvialable gates the kernel body in operator(): when it is false, operator() does nothing.
// Compiling for gfx950 always enables the kernel; any other target enables it only when
// the pipeline does not use tr-load.
// NOTE(review): "Avialable" is a misspelling of "Available"; the name is kept because it is
// referenced elsewhere (e.g. operator()), so a rename has to be done file-wide.
#if defined(__gfx950__)
|
||||
static constexpr bool kIsAvialable = true;
|
||||
#else
|
||||
static constexpr bool kIsAvialable = !kUseTrLoad;
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
@@ -99,7 +105,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
|
||||
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" );
|
||||
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -298,6 +304,24 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST static constexpr Kargs
|
||||
MakeKargs(Ts... args, const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST static constexpr Kargs
|
||||
MakeKargs(Ts... args, const std::tuple<const void*, const void*>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
@@ -466,248 +490,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
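// Backward-compatibility overload (batch mode): std::variant<> can't take in a list initializer, so the seed/offset is accepted as std::tuple<uint64_t, uint64_t> and forwarded to MakeKargsImpl as a std::pair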
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_do,
|
||||
ck_tile::index_t batch_stride_lsed,
|
||||
ck_tile::index_t batch_stride_dq_acc,
|
||||
ck_tile::index_t batch_stride_dk,
|
||||
ck_tile::index_t batch_stride_dv,
|
||||
ck_tile::index_t batch_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_dq_acc,
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
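// Backward-compatibility overload (batch mode): std::variant<> can't take in a list initializer, so the seed/offset is accepted as std::tuple<const void*, const void*> and forwarded to MakeKargsImpl as a std::pair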
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_do,
|
||||
ck_tile::index_t batch_stride_lsed,
|
||||
ck_tile::index_t batch_stride_dq_acc,
|
||||
ck_tile::index_t batch_stride_dk,
|
||||
ck_tile::index_t batch_stride_dv,
|
||||
ck_tile::index_t batch_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_dq_acc,
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
@@ -854,208 +636,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
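// Backward-compatibility overload (group mode): std::variant<> can't take in a list initializer, so the seed/offset is accepted as std::tuple<uint64_t, uint64_t> and forwarded to MakeKargsImpl as a std::pair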
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
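// Backward-compatibility overload (group mode): std::variant<> can't take in a list initializer, so the seed/offset is accepted as std::tuple<const void*, const void*> and forwarded to MakeKargsImpl as a std::pair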
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
@@ -1082,6 +662,12 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
|
||||
// Device-side call operator. Hands the kernel arguments to run_() when kIsAvialable is
// true; otherwise the call is discarded at compile time and the operator does nothing.
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
    if constexpr(!kIsAvialable)
    {
        // Kernel not enabled for this build target: intentionally empty.
        return;
    }
    else
    {
        run_(std::move(kargs));
    }
}
|
||||
|
||||
CK_TILE_DEVICE void run_(Kargs kargs) const
|
||||
{
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
@@ -1282,62 +868,33 @@ struct FmhaBwdDQDKDVKernel
|
||||
{0, 0});
|
||||
|
||||
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
AccDataType* dq_acc_ptr =
|
||||
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
AccDataType* dq_acc_ptr = reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + [&]() {
|
||||
if constexpr(kIsDeterministic)
|
||||
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
else
|
||||
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
}();
|
||||
|
||||
auto dq_acc_dram = [&]() {
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_dq_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentQGrad>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
AccDataType* dq_acc_ptr =
|
||||
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
|
||||
auto dq_acc_dram = [&]() {
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::atomic_add>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_dq_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentQGrad>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
}
|
||||
constexpr auto DstInMemOp = conditional_expr<kIsDeterministic>(
|
||||
memory_operation_enum::set, memory_operation_enum::atomic_add);
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_dq_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentQGrad>{},
|
||||
number<1>{});
|
||||
const auto dq_acc_dram = pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
}();
|
||||
|
||||
auto lse_dram_window =
|
||||
|
||||
Reference in New Issue
Block a user