mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_TILE] FMHA BWD Optimization For GFX950 (#2628)
* simplify fmha_bwd_kernel MakeKargs & dq_dram_window * simply duplicate * trload pipeline * Try two-stage * add prefetch * optimize & iglp
This commit is contained in:
@@ -25,8 +25,10 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp"
|
||||
|
||||
@@ -62,6 +62,12 @@ struct FmhaBwdDQDKDVKernel
|
||||
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
|
||||
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
|
||||
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
|
||||
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool kIsAvialable = true;
|
||||
#else
|
||||
static constexpr bool kIsAvialable = !kUseTrLoad;
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
@@ -99,7 +105,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
|
||||
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" );
|
||||
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -298,6 +304,24 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST static constexpr Kargs
|
||||
MakeKargs(Ts... args, const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST static constexpr Kargs
|
||||
MakeKargs(Ts... args, const std::tuple<const void*, const void*>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
@@ -466,248 +490,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_do,
|
||||
ck_tile::index_t batch_stride_lsed,
|
||||
ck_tile::index_t batch_stride_dq_acc,
|
||||
ck_tile::index_t batch_stride_dk,
|
||||
ck_tile::index_t batch_stride_dv,
|
||||
ck_tile::index_t batch_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_dq_acc,
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_do,
|
||||
ck_tile::index_t batch_stride_lsed,
|
||||
ck_tile::index_t batch_stride_dq_acc,
|
||||
ck_tile::index_t batch_stride_dk,
|
||||
ck_tile::index_t batch_stride_dv,
|
||||
ck_tile::index_t batch_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_dq_acc,
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
@@ -854,208 +636,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
const void* lse_ptr,
|
||||
const void* do_ptr,
|
||||
const void* d_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_do,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
ck_tile::index_t stride_dk,
|
||||
ck_tile::index_t stride_dv,
|
||||
ck_tile::index_t stride_dbias,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_do,
|
||||
ck_tile::index_t nhead_stride_lsed,
|
||||
ck_tile::index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t nhead_stride_dk,
|
||||
ck_tile::index_t nhead_stride_dv,
|
||||
ck_tile::index_t nhead_stride_dbias,
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
lse_ptr,
|
||||
do_ptr,
|
||||
d_ptr,
|
||||
rand_val_ptr,
|
||||
dk_ptr,
|
||||
dv_ptr,
|
||||
dbias_ptr,
|
||||
dq_acc_ptr,
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
nhead_ratio_qk,
|
||||
scale,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_dq_acc,
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_dq_acc,
|
||||
nhead_stride_dk,
|
||||
nhead_stride_dv,
|
||||
nhead_stride_dbias,
|
||||
split_stride_dq_acc,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
p_drop,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
@@ -1082,6 +662,12 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if constexpr(kIsAvialable)
|
||||
run_(std::move(kargs));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void run_(Kargs kargs) const
|
||||
{
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
@@ -1282,62 +868,33 @@ struct FmhaBwdDQDKDVKernel
|
||||
{0, 0});
|
||||
|
||||
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
AccDataType* dq_acc_ptr =
|
||||
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
AccDataType* dq_acc_ptr = reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + [&]() {
|
||||
if constexpr(kIsDeterministic)
|
||||
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
else
|
||||
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
}();
|
||||
|
||||
auto dq_acc_dram = [&]() {
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_dq_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentQGrad>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
AccDataType* dq_acc_ptr =
|
||||
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
|
||||
auto dq_acc_dram = [&]() {
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::atomic_add>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_dq_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentQGrad>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
}
|
||||
constexpr auto DstInMemOp = conditional_expr<kIsDeterministic>(
|
||||
memory_operation_enum::set, memory_operation_enum::atomic_add);
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_dq_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentQGrad>{},
|
||||
number<1>{});
|
||||
const auto dq_acc_dram = pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
}();
|
||||
|
||||
auto lse_dram_window =
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -54,6 +54,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
|
||||
static_assert(!kUseTrLoad, "This pipeline does not use trload!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
|
||||
@@ -54,6 +54,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
|
||||
static_assert(!kUseTrLoad, "This pipeline does not use trload!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -654,9 +656,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
}();
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
Policy::template PTFromGemm0CToGemm1A<Problem,
|
||||
decltype(pt_reg_tensor),
|
||||
decltype(p_gemm)>(pt_reg_tensor, p_gemm);
|
||||
Policy::template PTFromGemm0CToGemm1A<Problem>(pt_reg_tensor, p_gemm);
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
auto qt_reg_tensor = load_tile(qt_lds_read_window);
|
||||
@@ -728,9 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
const auto ds_gemm = cast_tile<GemmDataType>(ds);
|
||||
|
||||
Policy::template SGradTFromGemm2CToGemm3A<Problem,
|
||||
decltype(dst_reg_tensor),
|
||||
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
|
||||
Policy::template SGradTFromGemm2CToGemm3A<Problem>(dst_reg_tensor, ds_gemm);
|
||||
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
|
||||
@@ -6,22 +6,30 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, typename Policy>
|
||||
class BlockFmhaBwdDQDKDVPipelineSelector
|
||||
{
|
||||
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
|
||||
|
||||
public:
|
||||
using type = std::conditional_t<has_dpad,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVR<Problem>,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<Problem>>;
|
||||
template <typename... TS>
|
||||
using type_ =
|
||||
std::conditional_t<Problem::kUseTrLoad,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>,
|
||||
std::conditional_t<has_dpad,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVR<TS...>,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<TS...>>>;
|
||||
using type = std::conditional_t<std::is_same_v<Policy, void>, //
|
||||
type_<Problem>,
|
||||
type_<Problem, Policy>>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector<Problem>::type
|
||||
template <typename Problem, typename Policy = void>
|
||||
class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector<Problem, Policy>::type
|
||||
{
|
||||
public:
|
||||
static constexpr const char* name = "auto";
|
||||
|
||||
@@ -0,0 +1,760 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
|
||||
// using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
|
||||
static_assert(kUseTrLoad, "This pipeline uses trload!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "trload_kr_ktr_vr";
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static LSEDataType get_validated_lse(const LSEDataType raw_lse)
|
||||
{
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking)
|
||||
return (raw_lse == -numeric<LSEDataType>::infinity()) //
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
: raw_lse;
|
||||
else
|
||||
return raw_lse;
|
||||
};
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_DEVICE auto operator()( //
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
// init VGrad & KGrad
|
||||
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
|
||||
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
|
||||
|
||||
// K, HBM ->LDS ->Reg
|
||||
auto k_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<KDataType>(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
k_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
const auto k_origin = k_dram_window.get_window_origin();
|
||||
|
||||
// Early termination
|
||||
const auto [seqlen_q_start, seqlen_q_end] =
|
||||
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here dk_acc&dv_acc are all cleard, return it
|
||||
// Note: v loaded but no fence, ignore it.
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType* __restrict__>(smem_ptr_);
|
||||
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr0 = reinterpret_cast<QDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto q_lds_ptr1 = reinterpret_cast<QDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>());
|
||||
const auto ds_lds_ptr = reinterpret_cast<GemmDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
|
||||
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType* __restrict__>(ds_lds_ptr);
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// V, HBM ->LDS ->Reg
|
||||
auto v_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<VDataType>(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
v_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, HBM -> LDS --trload-->Reg
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// Pre-Load KV into Registers
|
||||
auto k_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor<Problem>());
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_read,
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
auto k_reg_tensor = load_tile(k_lds_read_window);
|
||||
|
||||
auto kt_lds_read_window =
|
||||
make_tile_window(k_lds_read,
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKTRegBlockDescriptor<Problem>());
|
||||
|
||||
auto kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
|
||||
|
||||
auto v_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor<Problem>());
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_read,
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM -->LDS
|
||||
auto q_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
|
||||
q_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0},
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr0, Policy::template MakeQLdsWriteBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto q_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr0, Policy::template MakeQLdsReadBlockDescriptor<Problem>());
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kK0>{}),
|
||||
q_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
|
||||
auto qt_lds_read_window =
|
||||
make_tile_window(q_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// dO: HBM ->LDS ---load--> Reg
|
||||
// dOT: \-loadtr-> Reg
|
||||
auto do_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<OGradDataType>(
|
||||
do_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0},
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>());
|
||||
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr0, Policy::template MakeOGradLdsWriteBlockDescriptor<Problem>());
|
||||
auto do_lds_write_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto do_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr0, Policy::template MakeOGradLdsReadBlockDescriptor<Problem>());
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kK2>{}),
|
||||
do_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
|
||||
auto dot_lds_read_window =
|
||||
make_tile_window(do_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kK2>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// dS: Reg -> Reg -> LDS
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// transform it to make it from col-major to row-major; prepared for load_tile_transpose
|
||||
auto ds_lds_t = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem, true>());
|
||||
auto ds_lds_read_window =
|
||||
make_tile_window(ds_lds_t,
|
||||
make_tuple(number<kM0>{}, number<kK4>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// Bias: HBM ->Reg ->Reg ->LDS
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})},
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto bias_lds = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
|
||||
auto bias_lds_write_window =
|
||||
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
|
||||
auto bias_s_lds_read_window =
|
||||
make_tile_window(bias_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
bias_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// LSE: HBM -> LDS ->Reg
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto lse_lds = make_tensor_view<address_space_enum::lds>(
|
||||
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto lse_lds_read_window = make_tile_window(
|
||||
lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// D: HBM ->Reg
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_lds = make_tensor_view<address_space_enum::lds>(
|
||||
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
|
||||
auto d_lds_read_window = make_tile_window(
|
||||
d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// RandVal: HBM ->Reg
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
|
||||
randval_dram_block_window_tmp, seqlen_q_start);
|
||||
|
||||
// BiasGrad
|
||||
// Reg ->LDS ->Reg ->HBM
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto dbias_dram_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
|
||||
|
||||
auto dbias_lds_read_window =
|
||||
make_tile_window(bias_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
// ----------------------------Loop write out------------------------------//
|
||||
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
clear_tile(dv_acc);
|
||||
clear_tile(dk_acc);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
decltype(load_tile(q_lds_read_window)) q_reg_tensor;
|
||||
decltype(load_tile(lse_lds_read_window)) lse;
|
||||
decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor;
|
||||
decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next;
|
||||
decltype(load_tile(do_lds_read_window)) do_reg_tensor;
|
||||
decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor;
|
||||
decltype(load_tile(d_lds_read_window)) d;
|
||||
decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor;
|
||||
decltype(gemm_0.MakeCBlockTile()) s_acc, p;
|
||||
decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
|
||||
decltype(gemm_4.MakeCBlockTile()) dq_acc;
|
||||
|
||||
decltype(load_tile(lse_dram_window)) lse_block_tile;
|
||||
decltype(load_tile(d_dram_window)) d_block_tile;
|
||||
|
||||
index_t i_total_bodys = 0;
|
||||
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
const bool is_even = (i_total_bodys % 2 == 0);
|
||||
QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
|
||||
QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
|
||||
OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
|
||||
OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
|
||||
|
||||
constexpr bool is_prologue = is_prologue_.value;
|
||||
constexpr bool is_epilogue = is_epilogue_.value;
|
||||
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
|
||||
constexpr bool is_main_body = is_prologue && is_epilogue;
|
||||
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
|
||||
async_load_tile(q_lds_write_window, q_dram_window);
|
||||
move_tile_window(q_dram_window, {kM0, 0});
|
||||
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
move_tile_window(lse_dram_window, {kM0});
|
||||
|
||||
do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
|
||||
async_load_tile(do_lds_write_window, do_dram_window);
|
||||
move_tile_window(do_dram_window, {kM0, 0});
|
||||
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 1, Q@K Gemm0
|
||||
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
|
||||
|
||||
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
async_load_tile(bias_lds_write_window, bias_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
},
|
||||
s_acc,
|
||||
bias_s_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto p_spans = decltype(p)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
|
||||
else
|
||||
p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
|
||||
}
|
||||
const auto p_gemm = [&]() { // dropout / type conversion
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) {
|
||||
return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
|
||||
},
|
||||
p);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(p);
|
||||
}
|
||||
}();
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr);
|
||||
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
Policy::template MakePTRegSliceBlockDescriptor<Problem>());
|
||||
pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
}
|
||||
block_sync_lds();
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
store_tile(lse_lds_write_window, lse_block_tile);
|
||||
store_tile(d_lds_write_window, d_block_tile);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
|
||||
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = p[i_j_idx] >= 0;
|
||||
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
|
||||
? (dp_acc[i_j_idx] - d[i_idx])
|
||||
: d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbias = [&]() {
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
ds);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(ds);
|
||||
}
|
||||
}();
|
||||
store_tile(bias_lds_write_window, dbias);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
|
||||
auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbias_tile, shuffled_dbias_tile);
|
||||
store_tile(dbias_dram_window, dbias_tile);
|
||||
move_tile_window(dbias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
const auto ds_gemm = cast_tile<GemmDataType>(ds);
|
||||
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
|
||||
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
|
||||
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
}
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
lse = load_tile(lse_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm3();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE7 SGrad@K^T Gemm4
|
||||
clear_tile(dq_acc);
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile( //
|
||||
kt_reg_tensor,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
|
||||
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
|
||||
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
|
||||
}
|
||||
});
|
||||
move_tile_window(ds_lds_read_window, {-kN0, 0});
|
||||
}
|
||||
block_sync_lds();
|
||||
if constexpr(is_prologue)
|
||||
{
|
||||
do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
d = load_tile(d_lds_read_window);
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// QGrad Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
store_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
move_tile_window(dq_dram_window, {kM0, 0});
|
||||
}
|
||||
i_total_bodys += 1;
|
||||
};
|
||||
|
||||
main_body(std::true_type{}, std::false_type{});
|
||||
// Hot loop
|
||||
if(num_total_loop > 1)
|
||||
{
|
||||
do
|
||||
{
|
||||
main_body(std::true_type{}, std::true_type{});
|
||||
i_total_loops += 1;
|
||||
seqlen_q_step += kM0;
|
||||
} while(i_total_loops < num_total_loop - 1);
|
||||
}
|
||||
main_body(std::false_type{}, std::true_type{});
|
||||
|
||||
// Results Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -64,7 +64,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
|
||||
CK_TILE_DEVICE static constexpr auto GetPTOGradTBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::GemmDataType,
|
||||
@@ -84,7 +84,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
true,
|
||||
false, // SwizzleAccess
|
||||
false, // UseStructuredSparsity
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
|
||||
@@ -151,7 +156,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
true,
|
||||
false, // SwizzleAccess
|
||||
false, // UseStructuredSparsity
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
|
||||
|
||||
@@ -27,6 +27,7 @@ template <typename QDataType_,
|
||||
bool kIsDeterministic_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
bool kUseTrLoad_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBwdPipelineProblem
|
||||
{
|
||||
@@ -53,6 +54,7 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -11,7 +11,9 @@ namespace ck_tile {
|
||||
// A is block distributed tensor
|
||||
// B is block distributed tensor
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy,
|
||||
bool TransposeC_ = false>
|
||||
struct BlockGemmARegBRegCRegV1
|
||||
{
|
||||
private:
|
||||
@@ -44,8 +46,9 @@ struct BlockGemmARegBRegCRegV1
|
||||
};
|
||||
|
||||
public:
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
|
||||
using Traits = GemmTraits_<Problem, Policy>;
|
||||
|
||||
@@ -131,6 +134,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
@@ -138,7 +142,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
c_distr_ys_major,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
@@ -152,7 +156,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
c_distr_ys_major,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
@@ -172,25 +176,19 @@ struct BlockGemmARegBRegCRegV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode();
|
||||
|
||||
constexpr auto b_block_dstr_encode = MakeBBlockDistributionEncode();
|
||||
|
||||
constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode();
|
||||
|
||||
// check ABC-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"A distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"B distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"C distribution is wrong!");
|
||||
@@ -219,7 +217,6 @@ struct BlockGemmARegBRegCRegV1
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
@@ -227,16 +224,16 @@ struct BlockGemmARegBRegCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
using c_iter_idx = std::
|
||||
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
@@ -244,7 +241,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
@@ -254,6 +251,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
@@ -261,7 +259,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
c_distr_ys_major,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -277,7 +275,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
c_distr_ys_major,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
|
||||
Reference in New Issue
Block a user