mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#5174 (commit a358a21)
[CK_TILE] FMHA BWD Use Persistent Kernels in Deterministic Mode (#5174) ## Motivation This PR enables a persistent-kernel execution path for FMHA backward (dQ/dK/dV) in deterministic mode, adjusting how dQ accumulation is split, stored, and converted back to final gradients. ## Technical Details - Introduces a persistent-kernel grid mapping in deterministic mode and updates split-count calculation accordingly. - Extends kernel kargs to carry batch-related info needed for persistent scheduling and dQ conversion. - Refactors dQ store conditions and adds mask-type traits/utilities and runner logging updates. ## Test Plan - Jenkins [base](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/10/pipeline) - Jenkins [AITER](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/12/pipeline) - Jenkins [FMHA](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/11/pipeline) - local FA tests ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
e2f5ab8000
commit
574c1c121a
@@ -79,6 +79,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
#else
|
||||
static constexpr bool kIsAvailable = !kUseTrLoad;
|
||||
#endif
|
||||
static constexpr bool kUsePersistent =
|
||||
kIsDeterministic && !kIsGroupMode && !kUseQrQtrDorPipeline;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
@@ -124,13 +126,43 @@ struct FmhaBwdDQDKDVKernel
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
}
|
||||
CK_TILE_HOST static index_t GetDqAccSplits(index_t seqlen_k)
|
||||
CK_TILE_HOST static index_t
|
||||
GetDqAccSplits(index_t batch_size_, index_t nhead_, index_t seqlen_k_)
|
||||
{
|
||||
if constexpr(kIsDeterministic)
|
||||
return integer_divide_ceil(seqlen_k, FmhaPipeline::BlockFmhaShape::kN0);
|
||||
// Be consistent with convert_dq kernel, though qrqtrdor pipeline doesn't use persistent
|
||||
static constexpr bool kUsePersistent__ = kIsDeterministic && !kIsGroupMode;
|
||||
if constexpr(kUsePersistent__)
|
||||
{
|
||||
const index_t dqdqkdv_workers = get_num_cus();
|
||||
const index_t jobs_per_head =
|
||||
integer_divide_ceil(seqlen_k_, FmhaPipeline::BlockFmhaShape::kN0);
|
||||
const index_t total_jobs = batch_size_ * nhead_ * jobs_per_head;
|
||||
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
|
||||
if(jobs_per_head % jobs_per_worker == 0)
|
||||
return jobs_per_head / jobs_per_worker;
|
||||
else if(jobs_per_worker % jobs_per_head == 0)
|
||||
return 1;
|
||||
else
|
||||
return 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker);
|
||||
}
|
||||
else if constexpr(kIsDeterministic)
|
||||
return integer_divide_ceil(seqlen_k_, FmhaPipeline::BlockFmhaShape::kN0);
|
||||
else
|
||||
return 1;
|
||||
}
|
||||
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
|
||||
{
|
||||
// Be consistent with convert_dq kernel, though qrqtrdor pipeline doesn't use persistent
|
||||
constexpr bool kUsePersistent__ = kIsDeterministic && !kIsGroupMode;
|
||||
|
||||
// non-deterministic adn persistent kernels use atomic-add to write dq
|
||||
if constexpr(kUsePersistent__ || !kIsDeterministic)
|
||||
return true;
|
||||
|
||||
// Some block may be skipped with causal mask and dq are not set to zeros
|
||||
// In these cases we need to zero out it first
|
||||
return kHasMask;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
|
||||
// arg
|
||||
@@ -282,6 +314,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
struct FmhaBwdDeterministicKargs
|
||||
{
|
||||
ck_tile::index_t split_stride_dq_acc = 0;
|
||||
ck_tile::index_t batch; // used for persistent kernel implementation
|
||||
};
|
||||
|
||||
struct FmhaBwdBatchModeKargs
|
||||
@@ -362,6 +395,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
void* dq_acc_ptr, // can be dq_acc_ptr for qrqtrdor pipeline
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -507,9 +541,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
|
||||
if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
|
||||
{
|
||||
kargs.split_stride_dq_acc = split_stride_dq_acc;
|
||||
}
|
||||
|
||||
if constexpr(kUsePersistent)
|
||||
kargs.batch = batch;
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -534,6 +569,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const void* seqlen_k_ptr,
|
||||
const void* cu_seqlen_q_ptr,
|
||||
const void* cu_seqlen_k_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -659,9 +695,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
}
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
kargs.split_stride_dq_acc = split_stride_dq_acc;
|
||||
}
|
||||
if constexpr(kUsePersistent)
|
||||
kargs.batch = batch;
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -669,19 +705,12 @@ struct FmhaBwdDQDKDVKernel
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
return dim3(
|
||||
kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex()
|
||||
{
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
const index_t jobs_per_head =
|
||||
kUseQrQtrDorPipeline ? 1 : integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0);
|
||||
if constexpr(kUsePersistent)
|
||||
return dim3(get_num_cus(), 1, 1);
|
||||
else
|
||||
return dim3(jobs_per_head, nhead_, batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static dim3 BlockSize()
|
||||
@@ -706,16 +735,64 @@ struct FmhaBwdDQDKDVKernel
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if constexpr(kIsAvailable)
|
||||
run_(std::move(kargs));
|
||||
{
|
||||
if constexpr(!kUsePersistent)
|
||||
{
|
||||
run_(std::move(kargs), blockIdx, blockIdx.x);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(!kUseQrQtrDorPipeline,
|
||||
"Persistent kernel is not compatible with QR/QTR/DOR pipeline");
|
||||
const index_t worker_id = blockIdx.x;
|
||||
const index_t worker_num = gridDim.x;
|
||||
|
||||
const index_t jobs_per_head =
|
||||
integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
|
||||
const index_t total_heads = kargs.batch * kargs.num_head_q;
|
||||
const index_t total_jobs = jobs_per_head * total_heads;
|
||||
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, worker_num);
|
||||
|
||||
const index_t begin_job_id = worker_id * jobs_per_worker;
|
||||
if(begin_job_id >= total_jobs)
|
||||
return; // worker_id exceeds total jobs, exit early
|
||||
const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
|
||||
|
||||
// 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
|
||||
constexpr auto tile_n_interleave = [](index_t x, index_t n) {
|
||||
if constexpr(kHasMask == false)
|
||||
return x;
|
||||
else
|
||||
return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
|
||||
};
|
||||
|
||||
index_t job_id = begin_job_id;
|
||||
index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
|
||||
do
|
||||
{ // loop over jobs assigned to this worker
|
||||
const index_t i_head_flatten = job_id / jobs_per_head;
|
||||
const index_t i_tile_n_ = job_id % jobs_per_head;
|
||||
const index_t i_tile_n = tile_n_interleave(i_tile_n_, jobs_per_head);
|
||||
const index_t i_batch = i_head_flatten / kargs.num_head_q;
|
||||
const index_t i_nhead = i_head_flatten % kargs.num_head_q;
|
||||
|
||||
if(i_tile_n_ == 0) // reset dq_acc writing idx when starting a new head
|
||||
i_split = 0;
|
||||
run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split);
|
||||
} while(++job_id < end_job_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void run_(Kargs kargs) const
|
||||
CK_TILE_DEVICE void run_(Kargs kargs, const dim3& tile_index, const index_t i_split) const
|
||||
{
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
|
||||
const index_t i_tile_n = tile_index.x;
|
||||
const index_t i_nhead = tile_index.y;
|
||||
const index_t i_batch = tile_index.z;
|
||||
|
||||
const index_t i_n0 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN0);
|
||||
|
||||
@@ -931,21 +1008,21 @@ struct FmhaBwdDQDKDVKernel
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
{0, 0});
|
||||
|
||||
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
|
||||
auto dq_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
|
||||
using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
|
||||
|
||||
auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
|
||||
if constexpr(kUseKSplit)
|
||||
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
|
||||
static_cast<long_index_t>(i_split) * kargs.split_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
else
|
||||
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
}();
|
||||
|
||||
constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
|
||||
constexpr auto DstInMemOp = conditional_expr<(kUseKSplit && !kUsePersistent)>(
|
||||
memory_operation_enum::set, memory_operation_enum::atomic_add);
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
@@ -1528,6 +1605,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
|
||||
static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
|
||||
static constexpr bool kUsePersistent = kIsDeterministic && !kIsGroupMode;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
@@ -1586,7 +1664,10 @@ struct FmhaBwdConvertQGradKernel
|
||||
|
||||
struct FmhaBwdConvertQGradDeterministicKargs
|
||||
{
|
||||
ck_tile::index_t split_stride_dq_acc = 0;
|
||||
index_t split_stride_dq_acc = 0;
|
||||
index_t dqdqkdv_workers = 0; // 0 for not using persistent kernel
|
||||
index_t batch_size = 0; // for nsplits calc of persistent kernel
|
||||
index_t nhead = 0; // for nsplits calc of persistent kernel
|
||||
};
|
||||
|
||||
struct FmhaBwdConvertQGradBatchModeKargs
|
||||
@@ -1630,7 +1711,9 @@ struct FmhaBwdConvertQGradKernel
|
||||
ck_tile::long_index_t nhead_stride_dq_acc,
|
||||
ck_tile::index_t batch_stride_dq,
|
||||
ck_tile::long_index_t batch_stride_dq_acc,
|
||||
ck_tile::index_t split_stride_dq_acc)
|
||||
ck_tile::index_t split_stride_dq_acc,
|
||||
ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead)
|
||||
{
|
||||
Kargs kargs{{dq_acc_ptr,
|
||||
dq_ptr,
|
||||
@@ -1648,6 +1731,12 @@ struct FmhaBwdConvertQGradKernel
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
kargs.split_stride_dq_acc = split_stride_dq_acc;
|
||||
if constexpr(kUsePersistent)
|
||||
{
|
||||
kargs.dqdqkdv_workers = get_num_cus();
|
||||
kargs.batch_size = batch_size;
|
||||
kargs.nhead = nhead;
|
||||
}
|
||||
}
|
||||
|
||||
return kargs;
|
||||
@@ -1783,6 +1872,27 @@ struct FmhaBwdConvertQGradKernel
|
||||
QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
|
||||
batch_offset_dq;
|
||||
const index_t nsplits = [&]() {
|
||||
const index_t jobs_per_head = integer_divide_ceil(kargs.seqlen_k, kN0);
|
||||
if constexpr(!kIsDeterministic)
|
||||
return 1;
|
||||
else if constexpr(!kUsePersistent)
|
||||
return jobs_per_head;
|
||||
else
|
||||
{
|
||||
const index_t total_heads = kargs.batch_size * kargs.nhead;
|
||||
const index_t total_jobs = jobs_per_head * total_heads;
|
||||
const index_t jobs_per_worker =
|
||||
integer_divide_ceil(total_jobs, kargs.dqdqkdv_workers);
|
||||
const index_t i_head_flatten = i_batch * kargs.nhead + i_nhead;
|
||||
|
||||
const index_t i_job_start = jobs_per_head * i_head_flatten;
|
||||
const index_t begin_worker_id = i_job_start / jobs_per_worker;
|
||||
const index_t end_worker_id = // inclusive
|
||||
(i_job_start + jobs_per_head - 1) / jobs_per_worker;
|
||||
return end_worker_id - begin_worker_id + 1;
|
||||
}
|
||||
}();
|
||||
|
||||
// dQAcc/dQ DRAM and DRAM window
|
||||
const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
|
||||
@@ -1793,8 +1903,6 @@ struct FmhaBwdConvertQGradKernel
|
||||
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
|
||||
batch_offset_dq_acc;
|
||||
|
||||
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
|
||||
|
||||
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
|
||||
@@ -1856,7 +1964,6 @@ struct FmhaBwdConvertQGradKernel
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
|
||||
FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
|
||||
}
|
||||
else
|
||||
|
||||
@@ -753,7 +753,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
if constexpr(kIsDeterministic)
|
||||
if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp ==
|
||||
memory_operation_enum::set)
|
||||
{
|
||||
store_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
|
||||
@@ -789,7 +789,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
if constexpr(kIsDeterministic)
|
||||
if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp ==
|
||||
memory_operation_enum::set)
|
||||
{
|
||||
store_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
@@ -1034,7 +1035,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp ==
|
||||
memory_operation_enum::set)
|
||||
{
|
||||
store_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
|
||||
@@ -764,7 +764,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
}
|
||||
if constexpr(kIsDeterministic)
|
||||
if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp ==
|
||||
memory_operation_enum::set)
|
||||
{
|
||||
store_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user