[CK_TILE] FMHA BWD Use Persistent Kernels in Deterministic Mode (#5174)

## Motivation
This PR enables a persistent-kernel execution path for FMHA backward
(dQ/dK/dV) in deterministic mode, adjusting how dQ accumulation is
split, stored, and converted back to final gradients.

## Technical Details
- Introduces a persistent-kernel grid mapping in deterministic mode and
updates split-count calculation accordingly.
- Extends kernel kargs to carry batch-related info needed for persistent
scheduling and dQ conversion.
- Refactors dQ store conditions and adds mask-type traits/utilities and
runner logging updates.

## Test Plan
- Jenkins
[base](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/10/pipeline)
- Jenkins
[AITER](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/12/pipeline)
- Jenkins
[FMHA](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/11/pipeline)
- local FA tests

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Yi DING
2026-03-13 14:13:32 +08:00
committed by GitHub
parent 5d2cbd1117
commit f6bfcad437
7 changed files with 175 additions and 49 deletions

View File

@@ -169,10 +169,17 @@ int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
}}
template <>
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(ck_tile::index_t seqlen_k)
int fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>(const fmha_bwd_traits& t)
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::GetDqAccSplits(seqlen_k);
return k_::GetDqAccSplits(t.batch, t.nhead_q, t.max_seqlen_k);
}}
template <>
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_{F_idx}, {F_arch.tag}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::NeedsZeroDqAcc();
}}
template <>
@@ -192,6 +199,7 @@ fmha_bwd_launcher::fmha_bwd_launcher(const fmha_bwd_traits& t){{
{F_launcher}
run = [](fmha_bwd_args, const ck_tile::stream_config&) {{ return -1.0f; }};
dq_acc_splits = 1;
needs_zero_dq_acc = false;
}}
@@ -231,7 +239,8 @@ FMHA_BWD_API_INNER_DISPATCH_LAUNCHER = """
run = [](fmha_bwd_args a, const ck_tile::stream_config& s) {{
return fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>, {F_arch.tag}>(s, a);
}};
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t.max_seqlen_k);
dq_acc_splits = fmha_bwd_dq_dk_dv_dq_acc_splits_<dq_dk_dv_trait_, {F_arch.tag}>(t);
needs_zero_dq_acc = fmha_bwd_dq_dk_dv_needs_zero_dq_acc_<dq_dk_dv_trait_, {F_arch.tag}>();
return;
}}
"""
@@ -447,7 +456,7 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9):
results = KernelComponentFactoryGfx9.get_dq_dk_dv_tiles(dtype, tr_load)
if dtype in ["fp16", "bf16"] and tr_load == "t":
results.extend([
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 32, 256, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32),
@@ -823,7 +832,7 @@ class FmhaBwdApiTrait:
@property
def extra_cond(self) -> str:
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128 and self.tile.F_bhdq == 128:
return " && (t.seqlen_k <= 256)"
else:
return ""

View File

@@ -251,6 +251,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.seqlen_k_ptr,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.batch,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -300,6 +301,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.batch,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -429,7 +431,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_dq_acc,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc);
args.split_stride_dq_acc,
args.batch,
args.nhead_q);
}
}();
@@ -465,8 +469,11 @@ template <typename Traits_, typename Arch = void>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_maxq_();
struct fmha_bwd_traits;
template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_dq_acc_splits_(ck_tile::index_t seqlen_k);
int fmha_bwd_dq_dk_dv_dq_acc_splits_(const fmha_bwd_traits& t);
template <typename Traits_, typename Arch = void>
bool fmha_bwd_dq_dk_dv_needs_zero_dq_acc_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
@@ -569,6 +576,7 @@ struct fmha_bwd_launcher
{
std::function<float(fmha_bwd_args, const ck_tile::stream_config&)> run{};
ck_tile::index_t dq_acc_splits{0};
bool needs_zero_dq_acc{true};
fmha_bwd_launcher(const fmha_bwd_traits&);

View File

@@ -416,9 +416,10 @@ bwd_result fmha_bwd_run(mode_enum mode,
<< "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
<< ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
<< ", s_randval:" << s_randval << ", deterministic:" << deterministic
<< (deterministic ? std::string(", workspace:") +
std::to_string(workspace_size_in_megabytes) + "MiB"
: "")
<< (deterministic
? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) +
"MiB|" + std::to_string(nsplits) + "splits"
: "")
<< ", mask:" << mask << std::flush;
auto fmha_args = [&]() {
@@ -842,10 +843,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
lse_buf.ToDevice(lse_host.data());
dbias_buf.SetZero();
// non-deterministic kernels use atomic add to write dq
// Some block may be skipped with causal mask and dq are not set to zeros
// In these cases thus we need to zero out it first
if(!deterministic || mask.type != mask_enum::no_mask)
if(launcher.needs_zero_dq_acc)
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};

View File

@@ -79,6 +79,8 @@ struct FmhaBwdDQDKDVKernel
#else
static constexpr bool kIsAvailable = !kUseTrLoad;
#endif
static constexpr bool kUsePersistent =
kIsDeterministic && !kIsGroupMode && !kUseQrQtrDorPipeline;
// clang-format off
template <typename T> struct t2s;
@@ -124,13 +126,43 @@ struct FmhaBwdDQDKDVKernel
#undef _TS_
// clang-format on
}
CK_TILE_HOST static index_t GetDqAccSplits(index_t seqlen_k)
CK_TILE_HOST static index_t
GetDqAccSplits(index_t batch_size_, index_t nhead_, index_t seqlen_k_)
{
if constexpr(kIsDeterministic)
return integer_divide_ceil(seqlen_k, FmhaPipeline::BlockFmhaShape::kN0);
// Be consistent with convert_dq kernel, though qrqtrdor pipeline doesn't use persistent
static constexpr bool kUsePersistent__ = kIsDeterministic && !kIsGroupMode;
if constexpr(kUsePersistent__)
{
const index_t dqdqkdv_workers = get_num_cus();
const index_t jobs_per_head =
integer_divide_ceil(seqlen_k_, FmhaPipeline::BlockFmhaShape::kN0);
const index_t total_jobs = batch_size_ * nhead_ * jobs_per_head;
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, dqdqkdv_workers);
if(jobs_per_head % jobs_per_worker == 0)
return jobs_per_head / jobs_per_worker;
else if(jobs_per_worker % jobs_per_head == 0)
return 1;
else
return 1 + integer_divide_ceil(jobs_per_head - 1, jobs_per_worker);
}
else if constexpr(kIsDeterministic)
return integer_divide_ceil(seqlen_k_, FmhaPipeline::BlockFmhaShape::kN0);
else
return 1;
}
CK_TILE_HOST static constexpr bool NeedsZeroDqAcc()
{
// Be consistent with convert_dq kernel, though qrqtrdor pipeline doesn't use persistent
constexpr bool kUsePersistent__ = kIsDeterministic && !kIsGroupMode;
// non-deterministic and persistent kernels use atomic-add to write dq
if constexpr(kUsePersistent__ || !kIsDeterministic)
return true;
// Some blocks may be skipped with causal mask and dq are not set to zeros
// In these cases we need to zero it out first
return kHasMask;
}
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
// arg
@@ -282,6 +314,7 @@ struct FmhaBwdDQDKDVKernel
struct FmhaBwdDeterministicKargs
{
ck_tile::index_t split_stride_dq_acc = 0;
ck_tile::index_t batch; // used for persistent kernel implementation
};
struct FmhaBwdBatchModeKargs
@@ -362,6 +395,7 @@ struct FmhaBwdDQDKDVKernel
void* dq_acc_ptr, // can be dq_acc_ptr for qrqtrdor pipeline
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t batch,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -507,9 +541,10 @@ struct FmhaBwdDQDKDVKernel
}
if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
if constexpr(kUsePersistent)
kargs.batch = batch;
return kargs;
}
@@ -534,6 +569,7 @@ struct FmhaBwdDQDKDVKernel
const void* seqlen_k_ptr,
const void* cu_seqlen_q_ptr,
const void* cu_seqlen_k_ptr,
ck_tile::index_t batch,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -659,9 +695,9 @@ struct FmhaBwdDQDKDVKernel
}
}
if constexpr(kIsDeterministic)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
if constexpr(kUsePersistent)
kargs.batch = batch;
return kargs;
}
@@ -669,19 +705,12 @@ struct FmhaBwdDQDKDVKernel
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
return dim3(
kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
nhead_,
batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
const index_t jobs_per_head =
kUseQrQtrDorPipeline ? 1 : integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0);
if constexpr(kUsePersistent)
return dim3(get_num_cus(), 1, 1);
else
return dim3(jobs_per_head, nhead_, batch_size_);
}
CK_TILE_HOST static dim3 BlockSize()
@@ -706,16 +735,64 @@ struct FmhaBwdDQDKDVKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if constexpr(kIsAvailable)
run_(std::move(kargs));
{
if constexpr(!kUsePersistent)
{
run_(std::move(kargs), blockIdx, blockIdx.x);
}
else
{
static_assert(!kUseQrQtrDorPipeline,
"Persistent kernel is not compatible with QR/QTR/DOR pipeline");
const index_t worker_id = blockIdx.x;
const index_t worker_num = gridDim.x;
const index_t jobs_per_head =
integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
const index_t total_heads = kargs.batch * kargs.num_head_q;
const index_t total_jobs = jobs_per_head * total_heads;
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, worker_num);
const index_t begin_job_id = worker_id * jobs_per_worker;
if(begin_job_id >= total_jobs)
return; // worker_id exceeds total jobs, exit early
const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
// 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
constexpr auto tile_n_interleave = [](index_t x, index_t n) {
if constexpr(kHasMask == false)
return x;
else
return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
};
index_t job_id = begin_job_id;
index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
do
{ // loop over jobs assigned to this worker
const index_t i_head_flatten = job_id / jobs_per_head;
const index_t i_tile_n_ = job_id % jobs_per_head;
const index_t i_tile_n = tile_n_interleave(i_tile_n_, jobs_per_head);
const index_t i_batch = i_head_flatten / kargs.num_head_q;
const index_t i_nhead = i_head_flatten % kargs.num_head_q;
if(i_tile_n_ == 0) // reset dq_acc writing idx when starting a new head
i_split = 0;
run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split);
} while(++job_id < end_job_id);
}
}
}
CK_TILE_DEVICE void run_(Kargs kargs) const
CK_TILE_DEVICE void run_(Kargs kargs, const dim3& tile_index, const index_t i_split) const
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
const index_t i_tile_n = tile_index.x;
const index_t i_nhead = tile_index.y;
const index_t i_batch = tile_index.z;
const index_t i_n0 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN0);
@@ -931,21 +1008,21 @@ struct FmhaBwdDQDKDVKernel
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
{0, 0});
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
auto dq_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
if constexpr(kUseKSplit)
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
static_cast<long_index_t>(i_split) * kargs.split_stride_dq_acc +
batch_offset_dq_acc;
else
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
batch_offset_dq_acc;
}();
constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
constexpr auto DstInMemOp = conditional_expr<(kUseKSplit && !kUsePersistent)>(
memory_operation_enum::set, memory_operation_enum::atomic_add);
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
@@ -1528,6 +1605,7 @@ struct FmhaBwdConvertQGradKernel
static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
static constexpr bool kUsePersistent = kIsDeterministic && !kIsGroupMode;
// clang-format off
template <typename T> struct t2s;
@@ -1586,7 +1664,10 @@ struct FmhaBwdConvertQGradKernel
struct FmhaBwdConvertQGradDeterministicKargs
{
ck_tile::index_t split_stride_dq_acc = 0;
index_t split_stride_dq_acc = 0;
index_t dqdqkdv_workers = 0; // 0 for not using persistent kernel
index_t batch_size = 0; // for nsplits calc of persistent kernel
index_t nhead = 0; // for nsplits calc of persistent kernel
};
struct FmhaBwdConvertQGradBatchModeKargs
@@ -1630,7 +1711,9 @@ struct FmhaBwdConvertQGradKernel
ck_tile::long_index_t nhead_stride_dq_acc,
ck_tile::index_t batch_stride_dq,
ck_tile::long_index_t batch_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t batch_size,
ck_tile::index_t nhead)
{
Kargs kargs{{dq_acc_ptr,
dq_ptr,
@@ -1648,6 +1731,12 @@ struct FmhaBwdConvertQGradKernel
if constexpr(kIsDeterministic)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
if constexpr(kUsePersistent)
{
kargs.dqdqkdv_workers = get_num_cus();
kargs.batch_size = batch_size;
kargs.nhead = nhead;
}
}
return kargs;
@@ -1783,6 +1872,27 @@ struct FmhaBwdConvertQGradKernel
QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
batch_offset_dq;
const index_t nsplits = [&]() {
const index_t jobs_per_head = integer_divide_ceil(kargs.seqlen_k, kN0);
if constexpr(!kIsDeterministic)
return 1;
else if constexpr(!kUsePersistent)
return jobs_per_head;
else
{
const index_t total_heads = kargs.batch_size * kargs.nhead;
const index_t total_jobs = jobs_per_head * total_heads;
const index_t jobs_per_worker =
integer_divide_ceil(total_jobs, kargs.dqdqkdv_workers);
const index_t i_head_flatten = i_batch * kargs.nhead + i_nhead;
const index_t i_job_start = jobs_per_head * i_head_flatten;
const index_t begin_worker_id = i_job_start / jobs_per_worker;
const index_t end_worker_id = // inclusive
(i_job_start + jobs_per_head - 1) / jobs_per_worker;
return end_worker_id - begin_worker_id + 1;
}
}();
// dQAcc/dQ DRAM and DRAM window
const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
@@ -1793,8 +1903,6 @@ struct FmhaBwdConvertQGradKernel
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
batch_offset_dq_acc;
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
@@ -1856,7 +1964,6 @@ struct FmhaBwdConvertQGradKernel
if constexpr(kIsDeterministic)
{
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
}
else

View File

@@ -753,7 +753,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
if constexpr(kIsDeterministic)
if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(dq_dram_window, dq_acc);
}

View File

@@ -789,7 +789,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
if constexpr(kIsDeterministic)
if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(dq_dram_window, dq_acc);
}
@@ -1034,7 +1035,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
if constexpr(kIsDeterministic)
if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(dq_dram_window, dq_acc);
}

View File

@@ -764,7 +764,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
if constexpr(kIsDeterministic)
if constexpr(decltype(dq_dram_window)::BottomTensorView::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(dq_dram_window, dq_acc);
}