[CK_TILE] FMHA BWD Decode Pipeline (#2643)

* Fix distr

* Duplicate block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr

* decode 16x16 o2
Author: Yi DING (committed via GitHub)
Date: 2025-08-12 17:02:52 +08:00
Parent: 352f87e684 · Commit: 8e1eb0c1ee
11 changed files with 1051 additions and 165 deletions


@@ -6,6 +6,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
 
 #include <string>
 #include <type_traits>
@@ -26,14 +27,22 @@
 namespace ck_tile {
 
-template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
+template <typename FmhaPipeline_,
+          typename KGradEpiloguePipeline_,
+          typename VGradEpiloguePipeline_,
+          typename QGradEpiloguePipeline_ = void>
 struct FmhaBwdDQDKDVKernel
 {
     using FmhaPipeline          = ck_tile::remove_cvref_t<FmhaPipeline_>;
     using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
     using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
+    using QGradEpiloguePipeline = ck_tile::remove_cvref_t<QGradEpiloguePipeline_>;
 
     static constexpr ck_tile::index_t kBlockSize  = FmhaPipeline::kBlockSize;
     static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
+    static constexpr bool kUseQrQtrDorPipeline =
+        ck_tile::fmha_bwd_qr_qtr_dor_pipeline_c<FmhaPipeline>;
+    static_assert(!kUseQrQtrDorPipeline || !std::is_same_v<QGradEpiloguePipeline_, void>,
+                  "QrQtrDorPipeline needs QGradEpiloguePipeline");
 
     using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
     using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
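
The new kUseQrQtrDorPipeline flag comes from fmha_bwd_qr_qtr_dor_pipeline_c in the newly included pipeline-selector header. As a sketch only (the real selector's internals are not shown in this diff), one plausible shape for such a detection trait is a variable template keyed on a tag the pipeline exposes; QrQtrDorTag and PipelineTag below are hypothetical stand-ins:

```cpp
#include <type_traits>

struct QrQtrDorTag {}; // hypothetical tag a qr/qtr/dor pipeline might expose

// primary: a pipeline without a PipelineTag member is not a qr/qtr/dor pipeline
template <typename Pipeline, typename = void>
inline constexpr bool fmha_bwd_qr_qtr_dor_pipeline_c = false;

// specialization: detected when the pipeline tags itself as qr/qtr/dor
template <typename Pipeline>
inline constexpr bool
    fmha_bwd_qr_qtr_dor_pipeline_c<Pipeline, std::void_t<typename Pipeline::PipelineTag>> =
        std::is_same_v<typename Pipeline::PipelineTag, QrQtrDorTag>;
```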
@@ -63,6 +72,8 @@ struct FmhaBwdDQDKDVKernel
     static constexpr bool kIsStoreRandval  = FmhaDropout::IsStoreRandval;
     static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
     static constexpr bool kUseTrLoad       = FmhaPipeline::kUseTrLoad;
+    static constexpr index_t kMaxSeqLenQ   = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
+    static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
 #if defined(__gfx950__)
     static constexpr bool kIsAvialable = true;
 #else
@@ -128,7 +139,7 @@ struct FmhaBwdDQDKDVKernel
         const void* lse_ptr;
         const void* do_ptr;
         const void* d_ptr;
-        void* dq_acc_ptr;
+        void* dq_acc_ptr; // can be dq_ptr for qrqtrdor pipeline
         void* dk_ptr;
         void* dv_ptr;
@@ -335,7 +346,7 @@
               void* dk_ptr,
               void* dv_ptr,
               void* dbias_ptr,
-              void* dq_acc_ptr,
+              void* dq_acc_ptr, // can be dq_ptr for qrqtrdor pipeline
               ck_tile::index_t seqlen_q,
               ck_tile::index_t seqlen_k,
               ck_tile::index_t hdim_q,
@@ -482,7 +493,7 @@
             }
         }
 
-        if constexpr(kIsDeterministic)
+        if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
         {
             kargs.split_stride_dq_acc = split_stride_dq_acc;
         }
@@ -640,7 +651,9 @@
     GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
     {
         return dim3(
-            ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
+            kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
+            nhead_,
+            batch_size_);
     }
 
     CK_TILE_DEVICE static constexpr auto GetTileIndex()
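
For intuition, a worked example (with illustrative sizes) of what this change does to the launch grid: the legacy path launches one workgroup per kN0-wide tile of seqlen_k, while the qr/qtr/dor decode path walks every K tile from a single workgroup per (head, batch).

```cpp
#include <cstdio>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int batch = 2, nhead = 8, seqlen_k = 4096, kN0 = 128;

    // legacy path: grid.x = ceil(seqlen_k / kN0) workgroups, one per K tile
    const int gx_split = integer_divide_ceil(seqlen_k, kN0); // 32

    // qr/qtr/dor decode path: grid.x collapses to 1; the single workgroup
    // iterates over all K tiles internally
    const int gx_decode = 1;

    std::printf("split grid : (%d, %d, %d)\n", gx_split, nhead, batch);
    std::printf("decode grid: (%d, %d, %d)\n", gx_decode, nhead, batch);
}
```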
@@ -735,10 +748,9 @@
             // # of required blocks is different in each groups, terminate unnecessary blocks
             // earlier
-            if(kargs.seqlen_k <= i_n0)
-            {
-                return;
-            }
+            if constexpr(!kUseQrQtrDorPipeline)
+                if(kargs.seqlen_k <= i_n0)
+                    return;
         }
         else
         {
@@ -786,12 +798,10 @@
         const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
                                       static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
                                       batch_offset_do;
-        KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
-                                batch_offset_dk;
-        VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
-                                batch_offset_dv;
+        auto dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
+                      static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk;
+        auto dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
+                      static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv;
 
         // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
         const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -868,8 +878,11 @@
                 {0, 0});
 
         auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
-            AccDataType* dq_acc_ptr = reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + [&]() {
-                if constexpr(kIsDeterministic)
+            constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
+            using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
+            auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
+                if constexpr(kUseKSplit)
                     return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
                            static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
                            batch_offset_dq_acc;
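
The element type of the dq buffer now depends on the pipeline: the decode path stores final dQ values (QGradDataType), while the other paths still write fp32 partials (AccDataType). A self-contained sketch of that type switch, with stand-in types:

```cpp
#include <cstdint>
#include <type_traits>

using QGradDataType = std::uint16_t; // stand-in for a 16-bit dQ storage type
using AccDataType   = float;         // fp32 partial accumulator type

template <bool kUseQrQtrDorPipeline>
using DqElem = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;

static_assert(std::is_same_v<DqElem<true>, QGradDataType>);  // decode: final dQ
static_assert(std::is_same_v<DqElem<false>, AccDataType>);   // otherwise: fp32 partials
```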
@@ -878,7 +891,7 @@
                            batch_offset_dq_acc;
             }();
 
-            constexpr auto DstInMemOp = conditional_expr<kIsDeterministic>(
+            constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
                 memory_operation_enum::set, memory_operation_enum::atomic_add);
             const auto dq_acc_dram_naive =
                 make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
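
The kUseKSplit flag captures why the memory operation changes: the deterministic split path gives every K tile its own dq_acc slice and uses plain stores (set), reduced later in a fixed order, while the non-split paths let all K tiles accumulate into one buffer with atomic adds. A plain-C++ illustration of the two strategies (stand-ins, assuming nothing about the real tensor views):

```cpp
#include <atomic>
#include <vector>

// kUseKSplit == true (deterministic, split over K): each K-tile index writes
// its partial dQ to a private slice -- memory_operation_enum::set
void store_split(std::vector<float>& dq_acc, int i_tile, long split_stride,
                 long offset, float partial)
{
    dq_acc[i_tile * split_stride + offset] = partial; // plain store, no races
}

// kUseKSplit == false: every K tile adds into the same dQ buffer --
// memory_operation_enum::atomic_add; summation order is nondeterministic
void store_atomic(std::vector<std::atomic<float>>& dq_acc, long offset, float partial)
{
    dq_acc[offset].fetch_add(partial, std::memory_order_relaxed); // C++20 for float
}
```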
@@ -1063,25 +1076,6 @@
             return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
         }();
 
-        auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
-                                                         k_dram_window,
-                                                         v_dram_window,
-                                                         bias_dram_window,
-                                                         randval_dram_window,
-                                                         do_dram_window,
-                                                         lse_dram_window,
-                                                         d_dram_window,
-                                                         dq_dram_window,
-                                                         dbias_dram_window,
-                                                         mask,
-                                                         position_encoding,
-                                                         kargs.raw_scale,
-                                                         kargs.scale,
-                                                         rp_undrop,
-                                                         scale_rp_undrop,
-                                                         smem_ptr,
-                                                         dropout);
-
         auto dk_dram = [&]() {
             const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                 dk_ptr,
@@ -1119,9 +1113,56 @@
             dv_dram,
             make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
             {i_n0, 0});
 
+        if constexpr(!kUseQrQtrDorPipeline)
+        {
+            auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
+                                                             k_dram_window,
+                                                             v_dram_window,
+                                                             bias_dram_window,
+                                                             randval_dram_window,
+                                                             do_dram_window,
+                                                             lse_dram_window,
+                                                             d_dram_window,
+                                                             dq_dram_window,
+                                                             dbias_dram_window,
+                                                             mask,
+                                                             position_encoding,
+                                                             kargs.raw_scale,
+                                                             kargs.scale,
+                                                             rp_undrop,
+                                                             scale_rp_undrop,
+                                                             smem_ptr,
+                                                             dropout);
-        KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
-        VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
+            KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
+            VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
+        }
+        else
+        {
+            FmhaPipeline{}(q_dram_window,
+                           k_dram_window,
+                           v_dram_window,
+                           bias_dram_window,
+                           randval_dram_window,
+                           do_dram_window,
+                           lse_dram_window,
+                           d_dram_window,
+                           dq_dram_window,
+                           dk_dram_window,
+                           dv_dram_window,
+                           dbias_dram_window,
+                           QGradEpiloguePipeline{},
+                           KGradEpiloguePipeline{},
+                           VGradEpiloguePipeline{},
+                           mask,
+                           position_encoding,
+                           kargs.raw_scale,
+                           kargs.scale,
+                           rp_undrop,
+                           scale_rp_undrop,
+                           smem_ptr,
+                           dropout);
+        }
     }
 };
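
Taken together, the kernel now holds one operator() with two pipeline contracts: the legacy path receives dK/dV accumulator tiles back and runs the epilogues itself, while the decode path hands the dK/dV/dQ windows plus all three epilogues to the pipeline. A self-contained sketch of that if constexpr dispatch pattern (all names below are illustrative stand-ins, not the real ck_tile types):

```cpp
#include <cstdio>

struct SplitKPipeline   { static constexpr bool kIsQrQtrDor = false; };
struct QrQtrDorPipeline { static constexpr bool kIsQrQtrDor = true;  };

template <typename Pipeline>
void run_bwd()
{
    if constexpr(!Pipeline::kIsQrQtrDor)
    {
        // legacy contract: the pipeline returns dK/dV accumulator tiles and
        // the kernel applies the K/V gradient epilogues afterwards
        std::puts("split path: epilogues run in the kernel");
    }
    else
    {
        // decode contract: dK/dV/dQ windows and the Q/K/V gradient epilogues
        // are passed in; the pipeline stores all results itself
        std::puts("decode path: epilogues run inside the pipeline");
    }
}

int main()
{
    run_bwd<SplitKPipeline>();
    run_bwd<QrQtrDorPipeline>();
}
```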