mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Use Persistent Scheduling for FMHA BWD Group Deterministic
This commit is contained in:
@@ -39,6 +39,12 @@ struct FmhaBwdWorkspaceManager
|
||||
// long_index_t dq_acc_offsets[batch]
|
||||
// — per-batch offset array
|
||||
|
||||
// [OPTIONAL, only for deterministic group mode persistent]
|
||||
// index_t prefix_batch[batch+1]
|
||||
// — prefix sum of nhead * num_chunks[b] * seqlen_q[b] across batches
|
||||
// index_t cu_start_ibatch[num_cus]
|
||||
// — first batch index that overlaps each CU's workload interval
|
||||
|
||||
// GPU WORKSPACE BELOW (read & written by kernels):
|
||||
|
||||
// [OPTIONAL, only for !kUseQrQtrDorPipeline]
|
||||
@@ -51,7 +57,7 @@ struct FmhaBwdWorkspaceManager
|
||||
static constexpr size_t ALIGNMENT = 16;
|
||||
|
||||
template <bool kUseQrQtrDorPipeline>
|
||||
CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsSize(const int batch)
|
||||
CK_TILE_HOST static size_t GetDqAccSplitsSize(const int batch)
|
||||
{
|
||||
if constexpr(kUseQrQtrDorPipeline)
|
||||
return 0;
|
||||
@@ -59,28 +65,50 @@ struct FmhaBwdWorkspaceManager
|
||||
(kIsGroupMode && kIsDeterministic) ? static_cast<size_t>(batch) : 1;
|
||||
return integer_least_multiple(sizeof(index_t) * dqAccSplitsElems, ALIGNMENT);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsSize(const int batch)
|
||||
CK_TILE_HOST static size_t GetDqAccOffsetsSize(const int batch)
|
||||
{
|
||||
const auto dqAccOffsetsElems =
|
||||
(kIsGroupMode && kIsDeterministic) ? static_cast<size_t>(batch) : 0;
|
||||
return integer_least_multiple(sizeof(long_index_t) * dqAccOffsetsElems, ALIGNMENT);
|
||||
}
|
||||
CK_TILE_HOST static size_t GetPrefixBatchSize(const int batch)
|
||||
{
|
||||
if constexpr(kIsGroupMode && kIsDeterministic)
|
||||
return integer_least_multiple(sizeof(index_t) * (batch + 1), ALIGNMENT);
|
||||
return 0;
|
||||
}
|
||||
CK_TILE_HOST static size_t GetCuStartIbatchSize(const int num_cus)
|
||||
{
|
||||
if constexpr(kIsGroupMode && kIsDeterministic)
|
||||
return integer_least_multiple(sizeof(index_t) * num_cus, ALIGNMENT);
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <bool kUseQrQtrDorPipeline>
|
||||
CK_TILE_HOST_DEVICE static size_t GetWorkspaceHostSize(const int batch)
|
||||
CK_TILE_HOST static size_t GetWorkspaceHostSize(const int batch)
|
||||
{
|
||||
if constexpr(kUseQrQtrDorPipeline)
|
||||
return 0;
|
||||
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) + GetDqAccOffsetsSize(batch);
|
||||
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) + GetDqAccOffsetsSize(batch) +
|
||||
GetPrefixBatchSize(batch) + GetCuStartIbatchSize(get_num_cus());
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsOffset(const int) { return 0; }
|
||||
CK_TILE_HOST static size_t GetDqAccSplitsOffset(const int) { return 0; }
|
||||
template <bool kUseQrQtrDorPipeline>
|
||||
CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsOffset(const int batch)
|
||||
CK_TILE_HOST static size_t GetDqAccOffsetsOffset(const int batch)
|
||||
{
|
||||
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch);
|
||||
}
|
||||
CK_TILE_HOST static size_t GetPrefixBatchOffset(const int batch)
|
||||
{
|
||||
return GetDqAccSplitsSize<false>(batch) + GetDqAccOffsetsSize(batch);
|
||||
}
|
||||
CK_TILE_HOST static size_t GetCuStartIbatchOffset(const int batch)
|
||||
{
|
||||
return GetPrefixBatchOffset(batch) + GetPrefixBatchSize(batch);
|
||||
}
|
||||
template <bool kUseQrQtrDorPipeline>
|
||||
CK_TILE_HOST_DEVICE static size_t GetDqAccDataOffset(const int batch)
|
||||
CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
|
||||
{
|
||||
return GetWorkspaceHostSize<kUseQrQtrDorPipeline>(batch);
|
||||
}
|
||||
@@ -121,21 +149,64 @@ struct FmhaBwdWorkspaceManager
|
||||
seqstart_qs[batch_size] * hdim_q;
|
||||
}
|
||||
else if constexpr(kIsGroupMode)
|
||||
{ // deterministic group mode
|
||||
{ // deterministic group mode (persistent)
|
||||
// Step 1: compute prefix_batch and target_w using per-batch seqlens
|
||||
const index_t num_cus = get_num_cus();
|
||||
auto* prefix_batch = reinterpret_cast<index_t*>(reinterpret_cast<char*>(cpu_ws) +
|
||||
GetDqAccSplitsSize<false>(batch_size) +
|
||||
GetDqAccOffsetsSize(batch_size));
|
||||
auto* cu_start_ibatch = reinterpret_cast<index_t*>(
|
||||
reinterpret_cast<char*>(cpu_ws) + GetDqAccSplitsSize<false>(batch_size) +
|
||||
GetDqAccOffsetsSize(batch_size) + GetPrefixBatchSize(batch_size));
|
||||
|
||||
prefix_batch[0] = 0;
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
{
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
prefix_batch[b + 1] = prefix_batch[b] + nhead_q * nc * sq;
|
||||
}
|
||||
const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);
|
||||
|
||||
// Step 2: compute nsplits[b] = per_batch_max_cus[b] (for dq_acc split dimension)
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
{
|
||||
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
|
||||
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
|
||||
const index_t rest_workload = (nc > 0) ? (nc - 1) * sq : 0;
|
||||
nsplits[b] = 1 + (rest_workload > 0 && target_w > 0
|
||||
? integer_divide_ceil(rest_workload, target_w)
|
||||
: 0);
|
||||
}
|
||||
|
||||
// Step 3: compute per-batch dq_acc offsets (compact layout, depends on nsplits)
|
||||
offsets[0] = 0;
|
||||
index_t i = 0;
|
||||
for(; i < batch_size - 1; ++i)
|
||||
{
|
||||
nsplits[i] = integer_divide_ceil(seqstart_ks[i + 1] - seqstart_ks[i], kN0);
|
||||
offsets[i + 1] = offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
|
||||
(seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
|
||||
}
|
||||
nsplits[i] = integer_divide_ceil(seqstart_ks[i + 1] - seqstart_ks[i], kN0);
|
||||
return sizeof(AccDataType) *
|
||||
(offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
|
||||
(seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q);
|
||||
const long_index_t dq_acc_elems =
|
||||
offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
|
||||
(seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
|
||||
|
||||
// Step 4: fill cu_start_ibatch via two-pointer scan O(batch + num_cus)
|
||||
index_t cu_lo = 0;
|
||||
for(index_t b = 0; b < batch_size; ++b)
|
||||
{
|
||||
const index_t cu_hi =
|
||||
min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
|
||||
for(index_t c = cu_lo; c < cu_hi; ++c)
|
||||
cu_start_ibatch[c] = b;
|
||||
cu_lo = cu_hi;
|
||||
}
|
||||
for(index_t c = cu_lo; c < num_cus; ++c)
|
||||
cu_start_ibatch[c] = batch_size; // sentinel: this CU has no work
|
||||
|
||||
return sizeof(AccDataType) * dq_acc_elems;
|
||||
}
|
||||
else // deterministic non-group mode (kUsePersistent)
|
||||
else // deterministic batch mode (kUsePersistent)
|
||||
{
|
||||
const index_t dqdqkdv_workers = get_num_cus();
|
||||
const index_t jobs_per_head = integer_divide_ceil(seqlen_k, kN0);
|
||||
@@ -159,13 +230,14 @@ struct FmhaBwdWorkspaceManager
|
||||
size_t host_ws_size)
|
||||
{
|
||||
constexpr bool NeedsZeroDqAcc = []() {
|
||||
constexpr bool kUsePersistent =
|
||||
!kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode;
|
||||
// non-deterministic and persistent kernels use atomic-add to write dq
|
||||
constexpr bool kUsePersistent = !kUseQrQtrDorPipeline && kIsDeterministic;
|
||||
// Persistent (batch and group): uses atomic_add → buffer must start at zero
|
||||
// so that accumulated dq values are correct.
|
||||
// Non-deterministic: uses atomic_add → buffer must start at zero.
|
||||
if constexpr(kUsePersistent || !kIsDeterministic)
|
||||
return true;
|
||||
// Some block may be skipped with causal mask and dq are not set to zeros
|
||||
// In these cases we need to zero out it first
|
||||
// Non-persistent deterministic: uses set, but causal mask may skip some tiles
|
||||
// leaving dq_acc slots unwritten — zero them out first.
|
||||
return kHasMask;
|
||||
}();
|
||||
if(host_ws_size > 0)
|
||||
@@ -228,8 +300,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
#else
|
||||
static constexpr bool kIsAvailable = !kUseTrLoad;
|
||||
#endif
|
||||
static constexpr bool kUsePersistent =
|
||||
kIsDeterministic && !kIsGroupMode && !kUseQrQtrDorPipeline;
|
||||
static constexpr bool kUsePersistent = kIsDeterministic && !kUseQrQtrDorPipeline;
|
||||
using WorkspaceManager = FmhaBwdWorkspaceManager<AccDataType, kIsGroupMode, kIsDeterministic>;
|
||||
|
||||
// clang-format off
|
||||
@@ -452,7 +523,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
struct FmhaBwdDeterministicKargs
|
||||
{
|
||||
ck_tile::index_t batch; // used for persistent kernel implementation
|
||||
const ck_tile::index_t* nsplits_ptr; // points to nsplits[0] in workspace (batch mode)
|
||||
const ck_tile::index_t* nsplits_ptr; // per-batch nsplits (group) or single scalar (batch)
|
||||
// group mode persistent scheduling tables (read from CPU workspace by GPU):
|
||||
const ck_tile::index_t* prefix_batch_ptr; // prefix sum of nhead*hw[b], size [batch+1]
|
||||
const ck_tile::index_t* cu_start_ibatch_ptr; // first batch for each CU, size [num_cus]
|
||||
};
|
||||
|
||||
struct FmhaBwdBatchModeKargs
|
||||
@@ -872,7 +946,15 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
|
||||
if constexpr(kUsePersistent)
|
||||
kargs.batch = batch;
|
||||
{
|
||||
kargs.batch = batch;
|
||||
kargs.nsplits_ptr = reinterpret_cast<const ck_tile::index_t*>(
|
||||
ws + WorkspaceManager::GetDqAccSplitsOffset(batch));
|
||||
kargs.prefix_batch_ptr = reinterpret_cast<const ck_tile::index_t*>(
|
||||
ws + WorkspaceManager::GetPrefixBatchOffset(batch));
|
||||
kargs.cu_start_ibatch_ptr = reinterpret_cast<const ck_tile::index_t*>(
|
||||
ws + WorkspaceManager::GetCuStartIbatchOffset(batch));
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -928,43 +1010,113 @@ struct FmhaBwdDQDKDVKernel
|
||||
{
|
||||
static_assert(!kUseQrQtrDorPipeline,
|
||||
"Persistent kernel is not compatible with QR/QTR/DOR pipeline");
|
||||
const index_t worker_id = blockIdx.x;
|
||||
const index_t worker_num = gridDim.x;
|
||||
if constexpr(!kIsGroupMode)
|
||||
{
|
||||
// Batch mode persistent: uniform seqlen_k across all batches
|
||||
const index_t worker_id = blockIdx.x;
|
||||
const index_t worker_num = gridDim.x;
|
||||
|
||||
const index_t jobs_per_head =
|
||||
integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
|
||||
const index_t total_heads = kargs.batch * kargs.nhead_q;
|
||||
const index_t total_jobs = jobs_per_head * total_heads;
|
||||
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, worker_num);
|
||||
const index_t jobs_per_head =
|
||||
integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
|
||||
const index_t total_heads = kargs.batch * kargs.nhead_q;
|
||||
const index_t total_jobs = jobs_per_head * total_heads;
|
||||
const index_t jobs_per_worker = integer_divide_ceil(total_jobs, worker_num);
|
||||
|
||||
const index_t begin_job_id = worker_id * jobs_per_worker;
|
||||
if(begin_job_id >= total_jobs)
|
||||
return; // worker_id exceeds total jobs, exit early
|
||||
const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
|
||||
const index_t begin_job_id = worker_id * jobs_per_worker;
|
||||
if(begin_job_id >= total_jobs)
|
||||
return; // worker_id exceeds total jobs, exit early
|
||||
const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
|
||||
|
||||
// 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
|
||||
constexpr auto tile_n_interleave = [](index_t x, index_t n) {
|
||||
if constexpr(kHasMask == false)
|
||||
return x;
|
||||
else
|
||||
return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
|
||||
};
|
||||
// 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
|
||||
constexpr auto tile_n_interleave = [](index_t x, index_t n) {
|
||||
if constexpr(kHasMask == false)
|
||||
return x;
|
||||
else
|
||||
return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
|
||||
};
|
||||
|
||||
const auto n_splits = kargs.nsplits_ptr[0];
|
||||
index_t job_id = begin_job_id;
|
||||
index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
|
||||
do
|
||||
{ // loop over jobs assigned to this worker
|
||||
const index_t i_head_flatten = job_id / jobs_per_head;
|
||||
const index_t i_tile_n_ = job_id % jobs_per_head;
|
||||
const index_t i_tile_n = tile_n_interleave(i_tile_n_, jobs_per_head);
|
||||
const index_t i_batch = i_head_flatten / kargs.nhead_q;
|
||||
const index_t i_nhead = i_head_flatten % kargs.nhead_q;
|
||||
const auto n_splits = kargs.nsplits_ptr[0];
|
||||
index_t job_id = begin_job_id;
|
||||
index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
|
||||
do
|
||||
{ // loop over jobs assigned to this worker
|
||||
const index_t i_head_flatten = job_id / jobs_per_head;
|
||||
const index_t i_tile_n_ = job_id % jobs_per_head;
|
||||
const index_t i_tile_n = tile_n_interleave(i_tile_n_, jobs_per_head);
|
||||
const index_t i_batch = i_head_flatten / kargs.nhead_q;
|
||||
const index_t i_nhead = i_head_flatten % kargs.nhead_q;
|
||||
|
||||
if(i_tile_n_ == 0) // reset dq_acc writing idx when starting a new head
|
||||
i_split = 0;
|
||||
run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split, n_splits);
|
||||
} while(++job_id < end_job_id);
|
||||
if(i_tile_n_ == 0) // reset dq_acc writing idx when starting a new head
|
||||
i_split = 0;
|
||||
run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split, n_splits);
|
||||
} while(++job_id < end_job_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Group mode persistent: variable seqlen per batch, dispatch via gist algo.
|
||||
// Each CU independently determines its workload interval using prefix_batch.
|
||||
const index_t cu_id = blockIdx.x;
|
||||
const index_t num_cu = gridDim.x;
|
||||
const index_t nbatch = kargs.batch;
|
||||
|
||||
// prefix_batch[nbatch] = total workload (nhead * sum(nc[b] * sq[b]))
|
||||
const index_t total_w = kargs.prefix_batch_ptr[nbatch];
|
||||
if(total_w == 0)
|
||||
return;
|
||||
const index_t target_w = integer_divide_ceil(total_w, num_cu);
|
||||
|
||||
const index_t w_lo = amd_wave_read_first_lane(cu_id * target_w);
|
||||
const index_t w_hi = amd_wave_read_first_lane(
|
||||
min(static_cast<index_t>((cu_id + 1) * target_w), total_w));
|
||||
if(w_lo >= total_w)
|
||||
return; // this CU has no work
|
||||
|
||||
for(index_t ibatch = kargs.cu_start_ibatch_ptr[cu_id]; ibatch < nbatch;
|
||||
++ibatch)
|
||||
{
|
||||
const index_t pb = kargs.prefix_batch_ptr[ibatch];
|
||||
if(pb >= w_hi)
|
||||
break; // all remaining batches are past this CU's interval
|
||||
|
||||
// per-batch seqlen: prefer seqlen_ptr if available, else diff seqstart
|
||||
const index_t sq = amd_wave_read_first_lane(
|
||||
kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[ibatch]
|
||||
: (kargs.seqstart_q_ptr[ibatch + 1] -
|
||||
kargs.seqstart_q_ptr[ibatch]));
|
||||
const index_t sk = amd_wave_read_first_lane(
|
||||
kargs.seqlen_k_ptr ? kargs.seqlen_k_ptr[ibatch]
|
||||
: (kargs.seqstart_k_ptr[ibatch + 1] -
|
||||
kargs.seqstart_k_ptr[ibatch]));
|
||||
const index_t nc = integer_divide_ceil(sk, FmhaPipeline::kN0);
|
||||
const index_t hw = nc * sq; // workload per (batch, head) pair
|
||||
if(hw == 0)
|
||||
continue;
|
||||
const index_t nsplits_b =
|
||||
amd_wave_read_first_lane(kargs.nsplits_ptr[ibatch]);
|
||||
|
||||
// first head whose interval overlaps [w_lo, w_hi)
|
||||
const index_t head_start = max(static_cast<index_t>((w_lo - pb) / hw), 0);
|
||||
|
||||
for(index_t head_idx = head_start; head_idx < kargs.nhead_q; ++head_idx)
|
||||
{
|
||||
const index_t w_head = pb + head_idx * hw;
|
||||
if(w_head >= w_hi)
|
||||
return; // remaining heads are past the interval
|
||||
|
||||
// wc_start: workload offset of this CU's start relative to head start.
|
||||
// Used for both isplit and c_start.
|
||||
const index_t wc_start = max(static_cast<index_t>(w_lo - w_head), 0);
|
||||
// isplit = rank of this CU among all CUs touching this head
|
||||
const index_t isplit = integer_divide_ceil(wc_start, target_w);
|
||||
const index_t c_start =
|
||||
wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
|
||||
const index_t c_end = integer_divide_ceil(min(hw, w_hi - w_head), sq);
|
||||
|
||||
for(index_t chunk_idx = c_start; chunk_idx < c_end; ++chunk_idx)
|
||||
run_(kargs, dim3(chunk_idx, head_idx, ibatch), isplit, nsplits_b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1229,14 +1381,23 @@ struct FmhaBwdDQDKDVKernel
|
||||
const long_index_t split_stride = kargs.seqlen_q * kargs.hdim_q;
|
||||
const auto nsplits = [&]() {
|
||||
if constexpr(!kIsGroupMode)
|
||||
return n_splits;
|
||||
return n_splits; // batch persistent: passed from nsplits_ptr[0]
|
||||
else if constexpr(kUsePersistent)
|
||||
return n_splits; // group persistent: passed from nsplits_ptr[ibatch]
|
||||
else
|
||||
return integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
|
||||
return integer_divide_ceil(kargs.seqlen_k,
|
||||
FmhaPipeline::kN0); // group non-persistent
|
||||
}();
|
||||
return batch_offset_dq_acc + (i_nhead_ * nsplits + i_split) * split_stride;
|
||||
}
|
||||
}();
|
||||
|
||||
// kUseKSplit && !kUsePersistent is true only for QrQtrDor+deterministic,
|
||||
// which writes dq directly (not through dq_acc splits) — use 'set'.
|
||||
// All other deterministic paths are persistent and use 'atomic_add':
|
||||
// a single CU may process multiple chunks of the same (batch, head, isplit)
|
||||
// sequentially, so contributions must accumulate rather than overwrite.
|
||||
// Non-deterministic paths also use 'atomic_add' (kUseKSplit=false).
|
||||
constexpr auto DstInMemOp = conditional_expr<(kUseKSplit && !kUsePersistent)>(
|
||||
memory_operation_enum::set, memory_operation_enum::atomic_add);
|
||||
const index_t stride_dq_acc = [&]() {
|
||||
@@ -1879,7 +2040,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
|
||||
static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
|
||||
static constexpr bool kUsePersistent = kIsDeterministic && !kIsGroupMode;
|
||||
static constexpr bool kUsePersistent = kIsDeterministic;
|
||||
using WorkspaceManager = FmhaBwdWorkspaceManager<AccDataType, kIsGroupMode, kIsDeterministic>;
|
||||
|
||||
// clang-format off
|
||||
|
||||
Reference in New Issue
Block a user