[CK_TILE] FMHA BWD Optimization For GFX950 (#2628)

* simplify fmha_bwd_kernel MakeKargs & dq_dram_window

* simply duplicate

* trload pipeline

* Try two-stage

* add prefetch

* optimize & iglp
Author: Yi DING
Date: 2025-08-12 11:11:55 +08:00 (committed by GitHub)
Parent: a7badc6ec5
Commit: 4fde1646e5
16 changed files with 2216 additions and 586 deletions
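The commit-message bullets reference gfx950-specific techniques: transpose loads ("trload"), software prefetch, and IGLP-guided instruction scheduling. As a hedged illustration only (this is not code from the commit), the AMDGPU scheduling builtins that the "iglp" bullet alludes to are used roughly like this, assuming a HIP compiler targeting AMDGPU:

```cpp
// Hedged sketch only -- not the commit's pipeline code. Clang for AMDGPU
// exposes builtins that constrain or guide instruction scheduling around
// MFMA/memory clusters.
__device__ void pipeline_step_sketch()
{
    // ... issue global/LDS loads that prefetch the next tile ...
    __builtin_amdgcn_sched_barrier(0); // hard fence: no instruction may be
                                       // scheduled across this point
    // ... MFMA math on the current tile ...
    __builtin_amdgcn_iglp_opt(0);      // opt into an interleaved MFMA/VMEM
                                       // scheduling strategy (IGLP)
}
```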


@@ -62,6 +62,12 @@ struct FmhaBwdDQDKDVKernel
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
#if defined(__gfx950__)
static constexpr bool kIsAvialable = true;
#else
static constexpr bool kIsAvialable = !kUseTrLoad;
#endif
// clang-format off
template <typename T> struct t2s;
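This hunk wires the new transpose-load switch into the kernel: kUseTrLoad comes from the pipeline, and kIsAvialable is pinned to true on gfx950 but falls back to !kUseTrLoad on every other target. A minimal standalone sketch of that gating pattern, with simplified names:

```cpp
// Minimal sketch of the gating pattern (simplified names, not the CK code).
// __gfx950__ is defined only when compiling device code for that target, so
// host passes and other GPUs fall through to the !kUseTrLoad branch.
template <bool kUseTrLoad>
struct AvailabilityGateSketch
{
#if defined(__gfx950__)
    static constexpr bool value = true;        // gfx950 supports every variant
#else
    static constexpr bool value = !kUseTrLoad; // elsewhere: trload variants off
#endif
};

static_assert(AvailabilityGateSketch<false>::value,
              "non-trload variants stay available everywhere");
```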
@@ -99,7 +105,7 @@ struct FmhaBwdDQDKDVKernel
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" );
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
#undef _SS_
#undef _TS_
// clang-format on
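The appended suffix makes the switch visible in generated kernel names. Purely for illustration (the prefix below is hypothetical; only the trailing tokens are taken from the expression above), a name tail could read:

```cpp
// Hypothetical tail of a generated kernel name; only the suffix tokens
// (_npad, _nbias, _ndbias, _nmask, _ndropout, _ndeterministic, _trload)
// come from the expression above:
//   ..._o1_npad_nbias_ndbias_nmask_ndropout_ndeterministic_trload
```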
@@ -298,6 +304,24 @@ struct FmhaBwdDQDKDVKernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <typename... Ts>
CK_TILE_HOST static constexpr Kargs
MakeKargs(Ts... args, const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <typename... Ts>
CK_TILE_HOST static constexpr Kargs
MakeKargs(Ts... args, const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
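These two thin overloads exist only to normalize the two historical drop_seed_offset spellings, std::tuple<uint64_t, uint64_t> and std::tuple<const void*, const void*>, into a std::pair before delegating to a single MakeKargsImpl; the long duplicated overloads removed below become redundant. A self-contained sketch of that normalization idea, with hypothetical names and the argument list reduced to the one field that matters:

```cpp
// Self-contained sketch (hypothetical names): a std::variant member cannot
// be list-initialized from a std::tuple, so the host-side helpers accept
// either tuple flavor and convert it to the matching std::pair up front.
#include <cstdint>
#include <tuple>
#include <utility>
#include <variant>

struct DropSeedOffsetSketch
{
    std::variant<std::pair<uint64_t, uint64_t>,       // immediate seed/offset
                 std::pair<const void*, const void*>> // device-pointer flavor
        value;
};

template <typename T>
constexpr DropSeedOffsetSketch MakeDropArgsSketch(const std::tuple<T, T>& t)
{
    // tuple -> pair, so the variant alternative is selected unambiguously
    return DropSeedOffsetSketch{std::make_pair(std::get<0>(t), std::get<1>(t))};
}
```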
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
@@ -466,248 +490,6 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
@@ -854,208 +636,6 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
const void* lse_ptr,
const void* do_ptr,
const void* d_ptr,
void* rand_val_ptr,
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
do_ptr,
d_ptr,
rand_val_ptr,
dk_ptr,
dv_ptr,
dbias_ptr,
dq_acc_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_randval,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
split_stride_dq_acc,
window_size_left,
window_size_right,
mask_type,
p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
@@ -1082,6 +662,12 @@ struct FmhaBwdDQDKDVKernel
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if constexpr(kIsAvialable)
run_(std::move(kargs));
}
CK_TILE_DEVICE void run_(Kargs kargs) const
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
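With the availability flag in place, operator() forwards to the real body only when the kernel can run on the compile target; otherwise it compiles to an empty stub, so trload instantiations still build for other GPUs. A reduced sketch of the pattern, with simplified names:

```cpp
// Reduced sketch (simplified names): gate the kernel body behind the
// compile-time availability flag so unsupported targets get an empty stub.
template <bool kIsAvailable>
struct KernelGateSketch
{
    void operator()() const
    {
        if constexpr(kIsAvailable)
            run_(); // the real body is only instantiated when available
    }
    void run_() const { /* ... kernel work ... */ }
};
```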
@@ -1282,62 +868,33 @@ struct FmhaBwdDQDKDVKernel
{0, 0});
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
if constexpr(kIsDeterministic)
{
AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
batch_offset_dq_acc;
AccDataType* dq_acc_ptr = reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + [&]() {
if constexpr(kIsDeterministic)
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
batch_offset_dq_acc;
else
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
batch_offset_dq_acc;
}();
auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
}
else
{
AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
batch_offset_dq_acc;
auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
}
constexpr auto DstInMemOp = conditional_expr<kIsDeterministic>(
memory_operation_enum::set, memory_operation_enum::atomic_add);
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
const auto dq_acc_dram = pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
}();
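The rewritten dq_dram_window collapses the two near-duplicate branches shown above: an immediately-invoked lambda computes the pointer offset for either path, and conditional_expr picks set versus atomic_add as the destination memory operation, so the view and window construction is written only once. A freestanding sketch of the same shape, with plain C++ stand-ins for the ck_tile helpers:

```cpp
// Freestanding sketch of the deduplication (plain C++ stand-ins for the
// ck_tile helpers; names simplified). The two paths differ only in the
// pointer offset and the destination memory operation.
#include <utility>

enum class memory_op { set, atomic_add };

template <bool kIsDeterministic>
std::pair<float*, memory_op> make_dq_window_sketch(float* base, long nhead_off,
                                                   long split_off, long batch_off)
{
    // the offset differs between the two paths only by the split term
    const long offset = [&]() -> long {
        if constexpr(kIsDeterministic)
            return nhead_off + split_off + batch_off; // per-split slice
        else
            return nhead_off + batch_off;             // shared accumulator
    }();
    // deterministic writes overwrite a private slice; the non-deterministic
    // path accumulates partial dQ tiles with atomics
    constexpr memory_op dst_op =
        kIsDeterministic ? memory_op::set : memory_op::atomic_add;
    return {base + offset, dst_op};
}
```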
auto lse_dram_window =