mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] FMHA BWD Decode Pipeline (#2643)
* Fix distr * Duplicate block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr * decode 16x16 o2
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
@@ -26,14 +27,22 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
|
||||
template <typename FmhaPipeline_,
|
||||
typename KGradEpiloguePipeline_,
|
||||
typename VGradEpiloguePipeline_,
|
||||
typename QGradEpiloguePipeline_ = void>
|
||||
struct FmhaBwdDQDKDVKernel
|
||||
{
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
|
||||
using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
|
||||
using QGradEpiloguePipeline = ck_tile::remove_cvref_t<QGradEpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static constexpr bool kUseQrQtrDorPipeline =
|
||||
ck_tile::fmha_bwd_qr_qtr_dor_pipeline_c<FmhaPipeline>;
|
||||
static_assert(!kUseQrQtrDorPipeline || !std::is_same_v<QGradEpiloguePipeline_, void>,
|
||||
"QrQtrDorPipeline needs QGradEpiloguePipeline");
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
@@ -63,6 +72,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
|
||||
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
|
||||
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
|
||||
static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
|
||||
static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool kIsAvialable = true;
|
||||
#else
|
||||
@@ -128,7 +139,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const void* lse_ptr;
|
||||
const void* do_ptr;
|
||||
const void* d_ptr;
|
||||
void* dq_acc_ptr;
|
||||
void* dq_acc_ptr; // can be dq_ptr for qrqtrdor pipeline
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
|
||||
@@ -335,7 +346,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
void* dq_acc_ptr, // can be dq_acc_ptr for qrqtrdor pipeline
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
@@ -482,7 +493,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
|
||||
{
|
||||
kargs.split_stride_dq_acc = split_stride_dq_acc;
|
||||
}
|
||||
@@ -640,7 +651,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
return dim3(
|
||||
ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
|
||||
kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex()
|
||||
@@ -735,10 +748,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_k <= i_n0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if constexpr(!kUseQrQtrDorPipeline)
|
||||
if(kargs.seqlen_k <= i_n0)
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -786,12 +798,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
|
||||
batch_offset_do;
|
||||
KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
|
||||
batch_offset_dk;
|
||||
VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
|
||||
batch_offset_dv;
|
||||
auto dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk;
|
||||
auto dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv;
|
||||
|
||||
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -868,8 +878,11 @@ struct FmhaBwdDQDKDVKernel
|
||||
{0, 0});
|
||||
|
||||
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
|
||||
AccDataType* dq_acc_ptr = reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + [&]() {
|
||||
if constexpr(kIsDeterministic)
|
||||
constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
|
||||
using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
|
||||
|
||||
auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
|
||||
if constexpr(kUseKSplit)
|
||||
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
@@ -878,7 +891,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_offset_dq_acc;
|
||||
}();
|
||||
|
||||
constexpr auto DstInMemOp = conditional_expr<kIsDeterministic>(
|
||||
constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
|
||||
memory_operation_enum::set, memory_operation_enum::atomic_add);
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
@@ -1063,25 +1076,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
|
||||
}();
|
||||
|
||||
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
dq_dram_window,
|
||||
dbias_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.raw_scale,
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
|
||||
auto dk_dram = [&]() {
|
||||
const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
dk_ptr,
|
||||
@@ -1119,9 +1113,56 @@ struct FmhaBwdDQDKDVKernel
|
||||
dv_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
{i_n0, 0});
|
||||
if constexpr(!kUseQrQtrDorPipeline)
|
||||
{
|
||||
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
dq_dram_window,
|
||||
dbias_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.raw_scale,
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
|
||||
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
|
||||
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
|
||||
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
|
||||
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
dq_dram_window,
|
||||
dk_dram_window,
|
||||
dv_dram_window,
|
||||
dbias_dram_window,
|
||||
QGradEpiloguePipeline{},
|
||||
KGradEpiloguePipeline{},
|
||||
VGradEpiloguePipeline{},
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.raw_scale,
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user