[CK_TILE] Use Persistent Scheduling for FMHA BWD Group Deterministic

Author: Ding, Yi
Date: 2026-04-08 03:38:17 -05:00
parent bedd60a568
commit 1a9404ac96


@@ -39,6 +39,12 @@ struct FmhaBwdWorkspaceManager
// long_index_t dq_acc_offsets[batch]
// — per-batch offset array
[OPTIONAL, only for persistent deterministic group mode]
// index_t prefix_batch[batch+1]
// — prefix sum of nhead * num_chunks[b] * seqlen_q[b] across batches
// index_t cu_start_ibatch[num_cus]
// — first batch index that overlaps each CU's workload interval
// GPU WORKSPACE BELOW (read & written by kernels):
// [OPTIONAL, only for !kUseQrQtrDorPipeline]
@@ -51,7 +57,7 @@ struct FmhaBwdWorkspaceManager
static constexpr size_t ALIGNMENT = 16;
template <bool kUseQrQtrDorPipeline>
CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsSize(const int batch)
CK_TILE_HOST static size_t GetDqAccSplitsSize(const int batch)
{
if constexpr(kUseQrQtrDorPipeline)
return 0;
@@ -59,28 +65,50 @@ struct FmhaBwdWorkspaceManager
(kIsGroupMode && kIsDeterministic) ? static_cast<size_t>(batch) : 1;
return integer_least_multiple(sizeof(index_t) * dqAccSplitsElems, ALIGNMENT);
}
CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsSize(const int batch)
CK_TILE_HOST static size_t GetDqAccOffsetsSize(const int batch)
{
const auto dqAccOffsetsElems =
(kIsGroupMode && kIsDeterministic) ? static_cast<size_t>(batch) : 0;
return integer_least_multiple(sizeof(long_index_t) * dqAccOffsetsElems, ALIGNMENT);
}
CK_TILE_HOST static size_t GetPrefixBatchSize(const int batch)
{
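// e.g. assuming a 4-byte index_t and batch = 2: (2 + 1) * 4 = 12 bytes,
// rounded up to ALIGNMENT (16) so the next table in the workspace stays 16-byte aligned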
if constexpr(kIsGroupMode && kIsDeterministic)
return integer_least_multiple(sizeof(index_t) * (batch + 1), ALIGNMENT);
return 0;
}
CK_TILE_HOST static size_t GetCuStartIbatchSize(const int num_cus)
{
if constexpr(kIsGroupMode && kIsDeterministic)
return integer_least_multiple(sizeof(index_t) * num_cus, ALIGNMENT);
return 0;
}
template <bool kUseQrQtrDorPipeline>
CK_TILE_HOST_DEVICE static size_t GetWorkspaceHostSize(const int batch)
CK_TILE_HOST static size_t GetWorkspaceHostSize(const int batch)
{
if constexpr(kUseQrQtrDorPipeline)
return 0;
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) + GetDqAccOffsetsSize(batch);
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) + GetDqAccOffsetsSize(batch) +
GetPrefixBatchSize(batch) + GetCuStartIbatchSize(get_num_cus());
}
CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsOffset(const int) { return 0; }
CK_TILE_HOST static size_t GetDqAccSplitsOffset(const int) { return 0; }
template <bool kUseQrQtrDorPipeline>
CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsOffset(const int batch)
CK_TILE_HOST static size_t GetDqAccOffsetsOffset(const int batch)
{
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch);
}
CK_TILE_HOST static size_t GetPrefixBatchOffset(const int batch)
{
return GetDqAccSplitsSize<false>(batch) + GetDqAccOffsetsSize(batch);
}
CK_TILE_HOST static size_t GetCuStartIbatchOffset(const int batch)
{
return GetPrefixBatchOffset(batch) + GetPrefixBatchSize(batch);
}
template <bool kUseQrQtrDorPipeline>
CK_TILE_HOST_DEVICE static size_t GetDqAccDataOffset(const int batch)
CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
{
return GetWorkspaceHostSize<kUseQrQtrDorPipeline>(batch);
}
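Taken together, these helpers pack the host-side tables back to back, each padded to ALIGNMENT, with the dq_acc data starting right after the host region. Below is a minimal host-only sketch of that arithmetic under assumed values (deterministic group mode, batch = 2, num_cus = 4, 4-byte index_t, 8-byte long_index_t; none of these values come from the commit):

#include <cstddef>
#include <cstdio>

// Same padding rule as integer_least_multiple(n, ALIGNMENT), re-implemented
// here so the sketch is self-contained.
static std::size_t align_up(std::size_t n, std::size_t a) { return (n + a - 1) / a * a; }

int main()
{
    constexpr std::size_t ALIGNMENT = 16;
    const int batch = 2, num_cus = 4; // hypothetical values
    const std::size_t splits  = align_up(sizeof(int) * batch, ALIGNMENT);       // dq_acc_splits
    const std::size_t offsets = align_up(sizeof(long long) * batch, ALIGNMENT); // dq_acc_offsets
    const std::size_t prefix  = align_up(sizeof(int) * (batch + 1), ALIGNMENT); // prefix_batch
    const std::size_t custart = align_up(sizeof(int) * num_cus, ALIGNMENT);     // cu_start_ibatch
    std::printf("dq_acc_splits   @ %zu\n", std::size_t{0});
    std::printf("dq_acc_offsets  @ %zu\n", splits);
    std::printf("prefix_batch    @ %zu\n", splits + offsets);
    std::printf("cu_start_ibatch @ %zu\n", splits + offsets + prefix);
    std::printf("dq_acc data     @ %zu (= host workspace size)\n",
                splits + offsets + prefix + custart);
    return 0;
}

With these assumptions every table pads to 16 bytes, so the host region is 64 bytes and GetDqAccDataOffset would return 64.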
@@ -121,21 +149,64 @@ struct FmhaBwdWorkspaceManager
seqstart_qs[batch_size] * hdim_q;
}
else if constexpr(kIsGroupMode)
{ // deterministic group mode
{ // deterministic group mode (persistent)
// Step 1: compute prefix_batch and target_w using per-batch seqlens
const index_t num_cus = get_num_cus();
auto* prefix_batch = reinterpret_cast<index_t*>(reinterpret_cast<char*>(cpu_ws) +
GetDqAccSplitsSize<false>(batch_size) +
GetDqAccOffsetsSize(batch_size));
auto* cu_start_ibatch = reinterpret_cast<index_t*>(
reinterpret_cast<char*>(cpu_ws) + GetDqAccSplitsSize<false>(batch_size) +
GetDqAccOffsetsSize(batch_size) + GetPrefixBatchSize(batch_size));
prefix_batch[0] = 0;
for(index_t b = 0; b < batch_size; ++b)
{
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
prefix_batch[b + 1] = prefix_batch[b] + nhead_q * nc * sq;
}
const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);
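// Worked example (hypothetical sizes, not from this commit): nhead_q = 2,
// kN0 = 128, num_cus = 4, two batches with (seqlen_q, seqlen_k) = (256, 256)
// and (128, 384): nc = {2, 3}, per-batch workload nhead_q * nc * sq = {1024, 768},
// prefix_batch = {0, 1024, 1792}, target_w = ceil(1792 / 4) = 448.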
// Step 2: compute nsplits[b] = per_batch_max_cus[b] (for dq_acc split dimension)
for(index_t b = 0; b < batch_size; ++b)
{
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
const index_t rest_workload = (nc > 0) ? (nc - 1) * sq : 0;
nsplits[b] = 1 + (rest_workload > 0 && target_w > 0
? integer_divide_ceil(rest_workload, target_w)
: 0);
}
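// Continuing the worked example: rest_workload = {(2 - 1) * 256, (3 - 1) * 128} = {256, 256},
// so nsplits = {1 + ceil(256 / 448), 1 + ceil(256 / 448)} = {2, 2}; dq_acc reserves
// room for up to 2 CU contributions per (batch, head).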
// Step 3: compute per-batch dq_acc offsets (compact layout, depends on nsplits)
offsets[0] = 0;
index_t i = 0;
for(; i < batch_size - 1; ++i)
{
nsplits[i] = integer_divide_ceil(seqstart_ks[i + 1] - seqstart_ks[i], kN0);
offsets[i + 1] = offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
(seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
}
nsplits[i] = integer_divide_ceil(seqstart_ks[i + 1] - seqstart_ks[i], kN0);
return sizeof(AccDataType) *
(offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
(seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q);
const long_index_t dq_acc_elems =
offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
(seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
// Step 4: fill cu_start_ibatch via two-pointer scan O(batch + num_cus)
index_t cu_lo = 0;
for(index_t b = 0; b < batch_size; ++b)
{
const index_t cu_hi =
min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
for(index_t c = cu_lo; c < cu_hi; ++c)
cu_start_ibatch[c] = b;
cu_lo = cu_hi;
}
for(index_t c = cu_lo; c < num_cus; ++c)
cu_start_ibatch[c] = batch_size; // sentinel: this CU has no work
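// Continuing the worked example: cu_hi = {min(4, ceil(1024/448)), min(4, ceil(1792/448))}
// = {3, 4}, so cu_start_ibatch = {0, 0, 0, 1}: CUs 0..2 start scanning at batch 0,
// CU 3 at batch 1, and no sentinel entries are needed.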
return sizeof(AccDataType) * dq_acc_elems;
}
else // deterministic non-group mode (kUsePersistent)
else // deterministic batch mode (kUsePersistent)
{
const index_t dqdqkdv_workers = get_num_cus();
const index_t jobs_per_head = integer_divide_ceil(seqlen_k, kN0);
@@ -159,13 +230,14 @@ struct FmhaBwdWorkspaceManager
size_t host_ws_size)
{
constexpr bool NeedsZeroDqAcc = []() {
constexpr bool kUsePersistent =
!kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode;
// non-deterministic and persistent kernels use atomic-add to write dq
constexpr bool kUsePersistent = !kUseQrQtrDorPipeline && kIsDeterministic;
// Persistent (batch and group): uses atomic_add → buffer must start at zero
// so that accumulated dq values are correct.
// Non-deterministic: uses atomic_add → buffer must start at zero.
if constexpr(kUsePersistent || !kIsDeterministic)
return true;
// Some block may be skipped with causal mask and dq are not set to zeros
// In these cases we need to zero out it first
// Non-persistent deterministic: uses set, but causal mask may skip some tiles
// leaving dq_acc slots unwritten — zero them out first.
return kHasMask;
}();
if(host_ws_size > 0)
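Since every persistent path (batch and group) and every non-deterministic path writes dq through atomic_add, the dq_acc region must be all zeros before the kernel launches, which is what NeedsZeroDqAcc gates. A minimal sketch of that zeroing step, assuming HIP and caller-supplied ws_dev / dq_acc_offset / dq_acc_bytes (illustrative names, not from this file):

#include <hip/hip_runtime.h>

// Zero the dq_acc region so atomic_add accumulates from a clean slate.
// dq_acc_offset would come from GetDqAccDataOffset<...>(batch) and
// dq_acc_bytes from the workspace-size computation shown above.
inline hipError_t zero_dq_acc(void* ws_dev, size_t dq_acc_offset,
                              size_t dq_acc_bytes, hipStream_t stream)
{
    return hipMemsetAsync(static_cast<char*>(ws_dev) + dq_acc_offset,
                          0, dq_acc_bytes, stream);
}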
@@ -228,8 +300,7 @@ struct FmhaBwdDQDKDVKernel
#else
static constexpr bool kIsAvailable = !kUseTrLoad;
#endif
static constexpr bool kUsePersistent =
kIsDeterministic && !kIsGroupMode && !kUseQrQtrDorPipeline;
static constexpr bool kUsePersistent = kIsDeterministic && !kUseQrQtrDorPipeline;
using WorkspaceManager = FmhaBwdWorkspaceManager<AccDataType, kIsGroupMode, kIsDeterministic>;
// clang-format off
@@ -452,7 +523,10 @@ struct FmhaBwdDQDKDVKernel
struct FmhaBwdDeterministicKargs
{
ck_tile::index_t batch; // used for persistent kernel implementation
const ck_tile::index_t* nsplits_ptr; // points to nsplits[0] in workspace (batch mode)
const ck_tile::index_t* nsplits_ptr; // per-batch nsplits (group) or single scalar (batch)
// group mode persistent scheduling tables (read from CPU workspace by GPU):
const ck_tile::index_t* prefix_batch_ptr; // prefix sum of nhead*hw[b], size [batch+1]
const ck_tile::index_t* cu_start_ibatch_ptr; // first batch for each CU, size [num_cus]
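// Both pointers are set in MakeKargs from the workspace base plus
// GetPrefixBatchOffset / GetCuStartIbatchOffset, i.e. they alias the same buffer.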
};
struct FmhaBwdBatchModeKargs
@@ -872,7 +946,15 @@ struct FmhaBwdDQDKDVKernel
}
if constexpr(kUsePersistent)
kargs.batch = batch;
{
kargs.batch = batch;
kargs.nsplits_ptr = reinterpret_cast<const ck_tile::index_t*>(
ws + WorkspaceManager::GetDqAccSplitsOffset(batch));
kargs.prefix_batch_ptr = reinterpret_cast<const ck_tile::index_t*>(
ws + WorkspaceManager::GetPrefixBatchOffset(batch));
kargs.cu_start_ibatch_ptr = reinterpret_cast<const ck_tile::index_t*>(
ws + WorkspaceManager::GetCuStartIbatchOffset(batch));
}
return kargs;
}
@@ -928,43 +1010,113 @@ struct FmhaBwdDQDKDVKernel
{
static_assert(!kUseQrQtrDorPipeline,
"Persistent kernel is not compatible with QR/QTR/DOR pipeline");
const index_t worker_id = blockIdx.x;
const index_t worker_num = gridDim.x;
if constexpr(!kIsGroupMode)
{
// Batch mode persistent: uniform seqlen_k across all batches
const index_t worker_id = blockIdx.x;
const index_t worker_num = gridDim.x;
const index_t jobs_per_head =
integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
const index_t total_heads = kargs.batch * kargs.nhead_q;
const index_t total_jobs = jobs_per_head * total_heads;
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, worker_num);
const index_t jobs_per_head =
integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
const index_t total_heads = kargs.batch * kargs.nhead_q;
const index_t total_jobs = jobs_per_head * total_heads;
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, worker_num);
const index_t begin_job_id = worker_id * jobs_per_worker;
if(begin_job_id >= total_jobs)
return; // worker_id exceeds total jobs, exit early
const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
const index_t begin_job_id = worker_id * jobs_per_worker;
if(begin_job_id >= total_jobs)
return; // worker_id exceeds total jobs, exit early
const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
// 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
constexpr auto tile_n_interleave = [](index_t x, index_t n) {
if constexpr(kHasMask == false)
return x;
else
return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
};
// 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
constexpr auto tile_n_interleave = [](index_t x, index_t n) {
if constexpr(kHasMask == false)
return x;
else
return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
};
const auto n_splits = kargs.nsplits_ptr[0];
index_t job_id = begin_job_id;
index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
do
{ // loop over jobs assigned to this worker
const index_t i_head_flatten = job_id / jobs_per_head;
const index_t i_tile_n_ = job_id % jobs_per_head;
const index_t i_tile_n = tile_n_interleave(i_tile_n_, jobs_per_head);
const index_t i_batch = i_head_flatten / kargs.nhead_q;
const index_t i_nhead = i_head_flatten % kargs.nhead_q;
const auto n_splits = kargs.nsplits_ptr[0];
index_t job_id = begin_job_id;
index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
do
{ // loop over jobs assigned to this worker
const index_t i_head_flatten = job_id / jobs_per_head;
const index_t i_tile_n_ = job_id % jobs_per_head;
const index_t i_tile_n = tile_n_interleave(i_tile_n_, jobs_per_head);
const index_t i_batch = i_head_flatten / kargs.nhead_q;
const index_t i_nhead = i_head_flatten % kargs.nhead_q;
if(i_tile_n_ == 0) // reset dq_acc writing idx when starting a new head
i_split = 0;
run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split, n_splits);
} while(++job_id < end_job_id);
if(i_tile_n_ == 0) // reset dq_acc writing idx when starting a new head
i_split = 0;
run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split, n_splits);
} while(++job_id < end_job_id);
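// Worked example (hypothetical sizes): seqlen_k = 4096, kN0 = 128 -> jobs_per_head = 32;
// batch = 4, nhead_q = 16 -> total_jobs = 2048; with 304 CUs, jobs_per_worker = 7,
// so workers 0..291 process 7 jobs each, worker 292 the remaining 4, and
// workers 293..303 exit early.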
}
else
{
// Group mode persistent: variable seqlen per batch, dispatch via gist algo.
// Each CU independently determines its workload interval using prefix_batch.
const index_t cu_id = blockIdx.x;
const index_t num_cu = gridDim.x;
const index_t nbatch = kargs.batch;
// prefix_batch[nbatch] = total workload (nhead * sum(nc[b] * sq[b]))
const index_t total_w = kargs.prefix_batch_ptr[nbatch];
if(total_w == 0)
return;
const index_t target_w = integer_divide_ceil(total_w, num_cu);
const index_t w_lo = amd_wave_read_first_lane(cu_id * target_w);
const index_t w_hi = amd_wave_read_first_lane(
min(static_cast<index_t>((cu_id + 1) * target_w), total_w));
if(w_lo >= total_w)
return; // this CU has no work
for(index_t ibatch = kargs.cu_start_ibatch_ptr[cu_id]; ibatch < nbatch;
++ibatch)
{
const index_t pb = kargs.prefix_batch_ptr[ibatch];
if(pb >= w_hi)
break; // all remaining batches are past this CU's interval
// per-batch seqlen: prefer seqlen_ptr if available, else diff seqstart
const index_t sq = amd_wave_read_first_lane(
kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[ibatch]
: (kargs.seqstart_q_ptr[ibatch + 1] -
kargs.seqstart_q_ptr[ibatch]));
const index_t sk = amd_wave_read_first_lane(
kargs.seqlen_k_ptr ? kargs.seqlen_k_ptr[ibatch]
: (kargs.seqstart_k_ptr[ibatch + 1] -
kargs.seqstart_k_ptr[ibatch]));
const index_t nc = integer_divide_ceil(sk, FmhaPipeline::kN0);
const index_t hw = nc * sq; // workload per (batch, head) pair
if(hw == 0)
continue;
const index_t nsplits_b =
amd_wave_read_first_lane(kargs.nsplits_ptr[ibatch]);
// first head whose interval overlaps [w_lo, w_hi)
const index_t head_start = max(static_cast<index_t>((w_lo - pb) / hw), 0);
for(index_t head_idx = head_start; head_idx < kargs.nhead_q; ++head_idx)
{
const index_t w_head = pb + head_idx * hw;
if(w_head >= w_hi)
return; // remaining heads are past the interval
// wc_start: workload offset of this CU's start relative to head start.
// Used for both isplit and c_start.
const index_t wc_start = max(static_cast<index_t>(w_lo - w_head), 0);
// isplit = rank of this CU among all CUs touching this head
const index_t isplit = integer_divide_ceil(wc_start, target_w);
const index_t c_start =
wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
const index_t c_end = integer_divide_ceil(min(hw, w_hi - w_head), sq);
for(index_t chunk_idx = c_start; chunk_idx < c_end; ++chunk_idx)
run_(kargs, dim3(chunk_idx, head_idx, ibatch), isplit, nsplits_b);
}
}
}
}
}
}
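The per-CU loop above can be replayed on the host to inspect or unit-test the schedule. The sketch below reuses the hypothetical sizes from the worked-example comments (nhead_q = 2, kN0 = 128, 4 CUs, seqlens 256/256 and 128/384), rebuilds prefix_batch, target_w and cu_start_ibatch the way the workspace routine does, then walks each CU's interval the way the device code does; it is an illustration, not part of the kernel.

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    using index_t = int;
    auto ceil_div = [](index_t a, index_t b) { return (a + b - 1) / b; };

    // Hypothetical problem sizes (same toy numbers as the worked examples above).
    const index_t kN0 = 128, nhead_q = 2, num_cu = 4;
    const std::vector<index_t> seqlen_q{256, 128};
    const std::vector<index_t> seqlen_k{256, 384};
    const index_t nbatch = static_cast<index_t>(seqlen_q.size());

    // Host precompute: Steps 1 and 4 of the workspace routine.
    std::vector<index_t> prefix_batch(nbatch + 1, 0);
    for(index_t b = 0; b < nbatch; ++b)
        prefix_batch[b + 1] =
            prefix_batch[b] + nhead_q * ceil_div(seqlen_k[b], kN0) * seqlen_q[b];
    const index_t total_w  = prefix_batch[nbatch];
    const index_t target_w = ceil_div(total_w, num_cu);
    std::vector<index_t> cu_start_ibatch(num_cu, nbatch); // sentinel = nbatch
    for(index_t b = 0, cu_lo = 0; b < nbatch; ++b)
    {
        const index_t cu_hi = std::min(num_cu, ceil_div(prefix_batch[b + 1], target_w));
        for(index_t c = cu_lo; c < cu_hi; ++c)
            cu_start_ibatch[c] = b;
        cu_lo = cu_hi;
    }

    // Replay of the per-CU device loop.
    for(index_t cu_id = 0; cu_id < num_cu; ++cu_id)
    {
        const index_t w_lo = cu_id * target_w;
        const index_t w_hi = std::min((cu_id + 1) * target_w, total_w);
        if(w_lo >= total_w)
            continue;
        bool past_interval = false;
        for(index_t ibatch = cu_start_ibatch[cu_id]; ibatch < nbatch && !past_interval;
            ++ibatch)
        {
            const index_t pb = prefix_batch[ibatch];
            if(pb >= w_hi)
                break;
            const index_t sq = seqlen_q[ibatch];
            const index_t nc = ceil_div(seqlen_k[ibatch], kN0);
            const index_t hw = nc * sq;
            if(hw == 0)
                continue;
            for(index_t head = std::max((w_lo - pb) / hw, 0); head < nhead_q; ++head)
            {
                const index_t w_head = pb + head * hw;
                if(w_head >= w_hi)
                {
                    past_interval = true;
                    break;
                }
                const index_t wc_start = std::max(w_lo - w_head, 0);
                const index_t isplit   = ceil_div(wc_start, target_w);
                const index_t c_start  = wc_start > 0 ? ceil_div(wc_start, sq) : 0;
                const index_t c_end    = ceil_div(std::min(hw, w_hi - w_head), sq);
                for(index_t chunk = c_start; chunk < c_end; ++chunk)
                    std::printf("cu %d: batch %d head %d chunk %d isplit %d\n",
                                cu_id, ibatch, head, chunk, isplit);
            }
        }
    }
    return 0;
}

With these toy sizes each (batch, head) pair is fully owned by one CU, so all 10 chunks are printed exactly once and every isplit is 0; nonzero isplit values appear when a single head's workload spans more than one CU interval.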
@@ -1229,14 +1381,23 @@ struct FmhaBwdDQDKDVKernel
const long_index_t split_stride = kargs.seqlen_q * kargs.hdim_q;
const auto nsplits = [&]() {
if constexpr(!kIsGroupMode)
return n_splits;
return n_splits; // batch persistent: passed from nsplits_ptr[0]
else if constexpr(kUsePersistent)
return n_splits; // group persistent: passed from nsplits_ptr[ibatch]
else
return integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
return integer_divide_ceil(kargs.seqlen_k,
FmhaPipeline::kN0); // group non-persistent
}();
return batch_offset_dq_acc + (i_nhead_ * nsplits + i_split) * split_stride;
}
}();
// kUseKSplit && !kUsePersistent is true only for QrQtrDor+deterministic,
// which writes dq directly (not through dq_acc splits) — use 'set'.
// All other deterministic paths are persistent and use 'atomic_add':
// a single CU may process multiple chunks of the same (batch, head, isplit)
// sequentially, so contributions must accumulate rather than overwrite.
// Non-deterministic paths also use 'atomic_add' (kUseKSplit=false).
constexpr auto DstInMemOp = conditional_expr<(kUseKSplit && !kUsePersistent)>(
memory_operation_enum::set, memory_operation_enum::atomic_add);
const index_t stride_dq_acc = [&]() {
@@ -1879,7 +2040,7 @@ struct FmhaBwdConvertQGradKernel
static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
static constexpr bool kUsePersistent = kIsDeterministic && !kIsGroupMode;
static constexpr bool kUsePersistent = kIsDeterministic;
using WorkspaceManager = FmhaBwdWorkspaceManager<AccDataType, kIsGroupMode, kIsDeterministic>;
// clang-format off
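For context, the convert-dQ pass is the consumer of dq_acc: in deterministic mode it reduces the per-split partials and casts them to the output dtype. A rough host-side reference of that reduction, assuming a float dq_acc laid out as [nsplits, seqlen_q, hdim_q] per (batch, head); the real FmhaBwdConvertQGradKernel tiles and parallelizes this on the GPU, and in persistent group mode nsplits varies per batch:

#include <cstddef>

// Reference reduction: dq[i] = sum over splits of dq_acc[split][i], cast to OutT.
template <typename OutT>
void convert_dq_ref(const float* dq_acc, OutT* dq,
                    int nsplits, int seqlen_q, int hdim_q)
{
    const std::size_t split_stride = static_cast<std::size_t>(seqlen_q) * hdim_q;
    for(std::size_t i = 0; i < split_stride; ++i)
    {
        float acc = 0.f;
        for(int s = 0; s < nsplits; ++s)
            acc += dq_acc[static_cast<std::size_t>(s) * split_stride + i];
        dq[i] = static_cast<OutT>(acc);
    }
}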