[fmha-bwd] Implement group-mode persistent scheduling with optimized state management

Ding, Yi
2026-04-16 22:27:25 -05:00
parent 1a9404ac96
commit 92f2ed758e
2 changed files with 206 additions and 106 deletions
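The scheduling idea, briefly: the total group-mode workload is measured as nhead * sum_b(nc[b] * sq[b]) "K-chunk rows", each CU is handed a contiguous slice of roughly total_w / num_cus of it, and the CPU precomputes where every CU's slice begins so the GPU's persistent loop only walks forward. A minimal host-side sketch of that split, using a hypothetical helper (not the actual ck_tile API):

// Sketch only: mirrors prefix_batch / target_w in the diff below.
inline long long group_mode_target_w(const int* seqstart_q, const int* seqstart_k,
                                     int batch, int nhead, int kN0, int num_cus)
{
    long long total_w = 0; // global workload in "K-chunk x q-row" units
    for(int b = 0; b < batch; ++b)
    {
        const int sq = seqstart_q[b + 1] - seqstart_q[b];
        const int nc = (seqstart_k[b + 1] - seqstart_k[b] + kN0 - 1) / kN0;
        total_w += static_cast<long long>(nhead) * nc * sq;
    }
    // Each CU gets roughly this much work; the CPU then precomputes, per CU,
    // where that slice starts (ibatch, head_start, c_start, isplit) and its
    // [w_lo, w_hi) bounds, which is what the new structs below carry.
    return (total_w + num_cus - 1) / num_cus;
}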

View File

@@ -9,6 +9,7 @@
#include "ck_tile/utility/json_dump.hpp"
#include <array>
#include <chrono>
#include <cstring>
#include <functional>
#include <numeric>
@@ -243,30 +244,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back());
const fmha_bwd_traits fmha_traits{
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
max_seqlen_k,
hdim_q,
hdim_v,
nhead,
nhead_k,
data_type,
mode == mode_enum::group,
mask.type,
bias.type,
use_dbias,
p_drop > 0.0f,
s_randval,
deterministic,
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
};
fmha_bwd_launcher launcher(fmha_traits);
const size_t ws_size = launcher.workspace_size;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
@@ -395,8 +372,37 @@ bwd_result fmha_bwd_run(mode_enum mode,
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
const auto t0_launcher = std::chrono::high_resolution_clock::now();
fmha_bwd_launcher launcher(fmha_bwd_traits{
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
max_seqlen_k,
hdim_q,
hdim_v,
nhead,
nhead_k,
data_type,
mode == mode_enum::group,
mask.type,
bias.type,
use_dbias,
p_drop > 0.0f,
s_randval,
deterministic,
(mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
(mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
});
const auto t1_launcher = std::chrono::high_resolution_clock::now();
const double launcher_ctor_ms =
std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
const size_t ws_size = launcher.workspace_size;
ck_tile::DeviceMem ws_buf(ws_size);
ck_tile::gpu_timer prepare_ws_timer;
prepare_ws_timer.start(nullptr);
launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
prepare_ws_timer.stop(nullptr);
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
@@ -442,7 +448,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
<< (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
<< ", deterministic:" << deterministic
<< ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
<< ", mask:" << mask << std::flush;
<< ", mask:" << mask
<< ", init:" << launcher_ctor_ms << "ms"
<< ", prws:" << prepare_ws_timer.duration() << "ms" << std::flush;
auto fmha_args = [&]() {
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,

View File

@@ -27,6 +27,31 @@
namespace ck_tile {
// Per-CU state for group-mode deterministic persistent scheduling.
// Packed into a single array to reduce kargs pointer count (5 pointers → 1).
// alignas(16): enables aligned 128-bit loads; sizeof == 32 (6×4 + 8 pad).
struct alignas(16) FmhaBwdGroupPersistentCuState
{
index_t w_lo; // global position of this CU's first K-chunk (= pb + head*hw + c*sq)
index_t w_hi; // global position of next CU's first K-chunk (exclusive upper bound)
index_t ibatch; // first batch this CU touches (batch_size = no work sentinel)
index_t isplit; // isplit for the first (batch, head) this CU touches
index_t head_start; // head index for the first batch this CU touches
index_t c_start; // chunk index for the first (batch, head) this CU touches
// 8 bytes implicit padding
};
// Per-batch precomputed values used in the group-mode persistent dispatch loop.
// Avoids per-iteration reads from seqstart_q/k_ptr and nsplits_ptr.
// alignas(16): sizeof == 16 (3×4 + 4 pad), fits in a single 128-bit load.
struct alignas(16) FmhaBwdBatchState
{
index_t sq; // seqlen_q for this batch (seqstart_q[b+1] - seqstart_q[b])
index_t nc; // number of K-chunks: ceil(seqlen_k / kN0)
index_t nsplits; // dq_acc split count for this batch
// 4 bytes implicit padding
};
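// (Illustrative check, not part of the commit: the size/alignment claims above
// can be pinned down with static_asserts; assumes ck_tile::index_t is a 32-bit
// integer.)
static_assert(alignof(FmhaBwdGroupPersistentCuState) == 16 &&
                  sizeof(FmhaBwdGroupPersistentCuState) == 32,
              "6 x 4-byte fields + 8 bytes tail padding");
static_assert(alignof(FmhaBwdBatchState) == 16 && sizeof(FmhaBwdBatchState) == 16,
              "3 x 4-byte fields + 4 bytes tail padding");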
template <typename AccDataType, bool kIsGroupMode, bool kIsDeterministic>
struct FmhaBwdWorkspaceManager
{
@@ -77,10 +102,20 @@ struct FmhaBwdWorkspaceManager
return integer_least_multiple(sizeof(index_t) * (batch + 1), ALIGNMENT);
return 0;
}
CK_TILE_HOST static size_t GetCuStartIbatchSize(const int num_cus)
// cu_state[num_cus]: per-CU persistent state packed into one array (group det only).
// Replaces separate cu_start_ibatch / cu_wlo / cu_isplit / cu_head_start / cu_c_start arrays.
CK_TILE_HOST static size_t GetCuStateSize(const int num_cus)
{
if constexpr(kIsGroupMode && kIsDeterministic)
return integer_least_multiple(sizeof(index_t) * num_cus, ALIGNMENT);
return integer_least_multiple(sizeof(FmhaBwdGroupPersistentCuState) * num_cus,
ALIGNMENT);
return 0;
}
// batch_state[batch]: per-batch sq/nc/nsplits for group det dispatch loop.
CK_TILE_HOST static size_t GetBatchStateSize(const int batch)
{
if constexpr(kIsGroupMode && kIsDeterministic)
return integer_least_multiple(sizeof(FmhaBwdBatchState) * batch, ALIGNMENT);
return 0;
}
@@ -89,8 +124,11 @@ struct FmhaBwdWorkspaceManager
{
if constexpr(kUseQrQtrDorPipeline)
return 0;
return GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) + GetDqAccOffsetsSize(batch) +
GetPrefixBatchSize(batch) + GetCuStartIbatchSize(get_num_cus());
const size_t raw = GetDqAccSplitsSize<kUseQrQtrDorPipeline>(batch) +
GetDqAccOffsetsSize(batch) + GetPrefixBatchSize(batch) +
GetCuStateSize(get_num_cus()) + GetBatchStateSize(batch);
// Pad to 4K so dq_acc buffer always starts on a page-aligned boundary.
return integer_least_multiple(raw, static_cast<size_t>(4096));
}
CK_TILE_HOST static size_t GetDqAccSplitsOffset(const int) { return 0; }
@@ -103,10 +141,14 @@ struct FmhaBwdWorkspaceManager
{
return GetDqAccSplitsSize<false>(batch) + GetDqAccOffsetsSize(batch);
}
CK_TILE_HOST static size_t GetCuStartIbatchOffset(const int batch)
CK_TILE_HOST static size_t GetCuStateOffset(const int batch)
{
return GetPrefixBatchOffset(batch) + GetPrefixBatchSize(batch);
}
CK_TILE_HOST static size_t GetBatchStateOffset(const int batch)
{
return GetCuStateOffset(batch) + GetCuStateSize(get_num_cus());
}
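// Layout recap (explanatory comment, not in the commit): the offsets above
// describe a header packed back-to-back and padded to a 4 KiB multiple, so the
// dq_acc accumulation data that follows always starts page aligned:
//   [dq_acc_splits][dq_acc_offsets][prefix_batch][cu_state][batch_state][pad to 4K]
//   [dq_acc data ...]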
template <bool kUseQrQtrDorPipeline>
CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
{
@@ -155,9 +197,10 @@ struct FmhaBwdWorkspaceManager
auto* prefix_batch = reinterpret_cast<index_t*>(reinterpret_cast<char*>(cpu_ws) +
GetDqAccSplitsSize<false>(batch_size) +
GetDqAccOffsetsSize(batch_size));
auto* cu_start_ibatch = reinterpret_cast<index_t*>(
reinterpret_cast<char*>(cpu_ws) + GetDqAccSplitsSize<false>(batch_size) +
GetDqAccOffsetsSize(batch_size) + GetPrefixBatchSize(batch_size));
auto* cu_states = reinterpret_cast<FmhaBwdGroupPersistentCuState*>(
reinterpret_cast<char*>(cpu_ws) + GetCuStateOffset(batch_size));
auto* batch_states = reinterpret_cast<FmhaBwdBatchState*>(
reinterpret_cast<char*>(cpu_ws) + GetBatchStateOffset(batch_size));
prefix_batch[0] = 0;
for(index_t b = 0; b < batch_size; ++b)
@@ -168,15 +211,19 @@ struct FmhaBwdWorkspaceManager
}
const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);
// Step 2: compute nsplits[b] = per_batch_max_cus[b] (for dq_acc split dimension)
// Step 2: compute nsplits[b] and fill batch_states[b] (sq, nc, nsplits per batch)
for(index_t b = 0; b < batch_size; ++b)
{
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
const index_t rest_workload = (nc > 0) ? (nc - 1) * sq : 0;
nsplits[b] = 1 + (rest_workload > 0 && target_w > 0
const index_t ns = 1 + (rest_workload > 0 && target_w > 0
? integer_divide_ceil(rest_workload, target_w)
: 0);
nsplits[b] = ns;
batch_states[b].sq = sq;
batch_states[b].nc = nc;
batch_states[b].nsplits = ns;
}
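// Worked example (hypothetical numbers): with kN0 = 128, sq = 384, sk = 512 and
// target_w = 1000:
//   nc            = ceil(512 / 128)       = 4
//   rest_workload = (4 - 1) * 384         = 1152
//   nsplits       = 1 + ceil(1152 / 1000) = 3
// i.e. up to three CUs may produce partial dq_acc slices for this batch.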
// Step 3: compute per-batch dq_acc offsets (compact layout, depends on nsplits)
@@ -191,18 +238,65 @@ struct FmhaBwdWorkspaceManager
offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
(seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
// Step 4: fill cu_start_ibatch via two-pointer scan O(batch + num_cus)
// Step 4: fill cu_states via two-pointer scan.
// w_lo = global position of the first K-chunk: pb + head_start*hw + c_start*sq.
// This makes w_chunk track true global K-chunk positions on GPU, so w_chunk < w_hi
// correctly identifies boundaries without off-by-one overlap between adjacent CUs.
// w_hi is set in a post-pass to cu_states[c+1].w_lo.
index_t cu_lo = 0;
for(index_t b = 0; b < batch_size; ++b)
{
const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
const index_t hw = nc * sq;
const index_t pb = prefix_batch[b];
const index_t cu_hi =
min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
for(index_t c = cu_lo; c < cu_hi; ++c)
cu_start_ibatch[c] = b;
{
const index_t w_lo = c * target_w;
cu_states[c].ibatch = b;
if(hw > 0)
{
const index_t head_start =
max(static_cast<index_t>((w_lo - pb) / hw), index_t(0));
const index_t w_head = pb + head_start * hw;
const index_t wc_start = max(w_lo - w_head, index_t(0));
const index_t c_start =
wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
cu_states[c].isplit =
wc_start > 0 ? integer_divide_ceil(wc_start, target_w) : 0;
cu_states[c].head_start = head_start;
cu_states[c].c_start = c_start;
// w_lo = true global start of first K-chunk for this CU
cu_states[c].w_lo = pb + head_start * hw + c_start * sq;
}
else
{
cu_states[c].isplit = 0;
cu_states[c].head_start = 0;
cu_states[c].c_start = 0;
cu_states[c].w_lo = pb; // hw==0: degenerate, w_lo=batch start
}
}
cu_lo = cu_hi;
}
// Inactive CUs: use total_w as w_lo sentinel so the post-pass sets
// the last active CU's w_hi = total_w correctly.
const index_t total_w = prefix_batch[batch_size];
for(index_t c = cu_lo; c < num_cus; ++c)
cu_start_ibatch[c] = batch_size; // sentinel: this CU has no work
{
cu_states[c].w_lo = total_w;
cu_states[c].w_hi = total_w;
cu_states[c].ibatch = batch_size; // sentinel → early return on GPU
cu_states[c].isplit = 0;
cu_states[c].head_start = 0;
cu_states[c].c_start = 0;
}
// Post-pass: set w_hi[c] = w_lo[c+1] (global start of next CU's first K-chunk).
for(index_t c = 0; c < num_cus - 1; ++c)
cu_states[c].w_hi = cu_states[c + 1].w_lo;
cu_states[num_cus - 1].w_hi = total_w;
return sizeof(AccDataType) * dq_acc_elems;
}
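// Worked example of the scan (hypothetical numbers): batch_size = 1, nhead = 2,
// sq = 4, nc = 3 (hw = 12, total_w = 24), num_cus = 4, target_w = 6:
//   CU 0: head_start 0, c_start 0,              isplit 0, w_lo = 0
//   CU 1: head_start 0, c_start ceil(6/4) = 2,  isplit 1, w_lo = 8
//   CU 2: head_start 1, c_start 0,              isplit 0, w_lo = 12
//   CU 3: head_start 1, c_start 2,              isplit 1, w_lo = 20
// Post-pass: w_hi = {8, 12, 20, 24}. Every K-chunk is owned by exactly one CU
// (CU 0 runs chunks 0-1 of head 0, CU 1 runs chunk 2 of head 0, and so on).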
@@ -525,8 +619,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch; // used for persistent kernel implementation
const ck_tile::index_t* nsplits_ptr; // per-batch nsplits (group) or single scalar (batch)
// group mode persistent scheduling tables (read from CPU workspace by GPU):
const ck_tile::index_t* prefix_batch_ptr; // prefix sum of nhead*hw[b], size [batch+1]
const ck_tile::index_t* cu_start_ibatch_ptr; // first batch for each CU, size [num_cus]
const ck_tile::index_t* prefix_batch_ptr; // prefix sum of nhead*hw[b], size [batch+1]
const FmhaBwdGroupPersistentCuState* cu_state_ptr; // per-CU packed state, size [num_cus]
const FmhaBwdBatchState* batch_state_ptr; // per-batch sq/nc/nsplits, size [batch]
};
struct FmhaBwdBatchModeKargs
@@ -952,8 +1047,13 @@ struct FmhaBwdDQDKDVKernel
ws + WorkspaceManager::GetDqAccSplitsOffset(batch));
kargs.prefix_batch_ptr = reinterpret_cast<const ck_tile::index_t*>(
ws + WorkspaceManager::GetPrefixBatchOffset(batch));
kargs.cu_start_ibatch_ptr = reinterpret_cast<const ck_tile::index_t*>(
ws + WorkspaceManager::GetCuStartIbatchOffset(batch));
if constexpr(kIsGroupMode)
{
kargs.cu_state_ptr = reinterpret_cast<const FmhaBwdGroupPersistentCuState*>(
ws + WorkspaceManager::GetCuStateOffset(batch));
kargs.batch_state_ptr = reinterpret_cast<const FmhaBwdBatchState*>(
ws + WorkspaceManager::GetBatchStateOffset(batch));
}
}
return kargs;
@@ -1010,6 +1110,14 @@ struct FmhaBwdDQDKDVKernel
{
static_assert(!kUseQrQtrDorPipeline,
"Persistent kernel is not compatible with QR/QTR/DOR pipeline");
// 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
constexpr auto tile_n_interleave = [](index_t x, index_t n) {
if constexpr(kHasMask == false)
return x;
else
return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
};
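// (Explanatory note: under a triangular mask, low-index K-chunks see the most
// unmasked query rows and high-index chunks the fewest, so pairing x with
// n - 1 - x gives any contiguous range of x a balanced mix of heavy and light
// chunks.)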
if constexpr(!kIsGroupMode)
{
// Batch mode persistent: uniform seqlen_k across all batches
@@ -1027,14 +1135,6 @@ struct FmhaBwdDQDKDVKernel
return; // worker_id exceeds total jobs, exit early
const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
// 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
constexpr auto tile_n_interleave = [](index_t x, index_t n) {
if constexpr(kHasMask == false)
return x;
else
return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
};
const auto n_splits = kargs.nsplits_ptr[0];
index_t job_id = begin_job_id;
index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
@@ -1054,68 +1154,60 @@ struct FmhaBwdDQDKDVKernel
else
{
// Group mode persistent: variable seqlen per batch, dispatch via gist algo.
// Each CU independently determines its workload interval using prefix_batch.
const index_t cu_id = blockIdx.x;
const index_t num_cu = gridDim.x;
const index_t nbatch = kargs.batch;
// Per-CU state (w_lo, w_hi, ibatch, isplit, head_start, c_start) is packed
// in a single struct array to minimise kargs pointer count (5 → 1).
// Remap block→CU: interleave SEs so consecutive blocks hit different SEs,
// spreading dq_acc writes across HBM channels.
const index_t cu_id = blockIdx.x / 8 + (blockIdx.x % 8) * 32;
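// (Illustration, assuming a hypothetical 256-block grid: blocks 0..7 map to
// CUs 0, 32, 64, ..., 224 and block 8 maps to CU 1, so launch-adjacent blocks
// land 32 slots apart.)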
// prefix_batch[nbatch] = total workload (nhead * sum(nc[b] * sq[b]))
const index_t total_w = kargs.prefix_batch_ptr[nbatch];
if(total_w == 0)
return;
const index_t target_w = integer_divide_ceil(total_w, num_cu);
// Load all per-CU fields through a single pointer; pointer dies after loads.
const FmhaBwdGroupPersistentCuState* cs = kargs.cu_state_ptr + cu_id;
const index_t w_hi = amd_wave_read_first_lane(cs->w_hi);
index_t ibatch = amd_wave_read_first_lane(cs->ibatch);
index_t isplit = amd_wave_read_first_lane(cs->isplit);
index_t head_start = amd_wave_read_first_lane(cs->head_start);
index_t c_start_0 = amd_wave_read_first_lane(cs->c_start);
index_t w_chunk = amd_wave_read_first_lane(cs->w_lo);
if(ibatch >= kargs.batch)
return; // this CU has no work (sentinel: ibatch == batch_size)
const index_t w_lo = amd_wave_read_first_lane(cu_id * target_w);
const index_t w_hi = amd_wave_read_first_lane(
min(static_cast<index_t>((cu_id + 1) * target_w), total_w));
if(w_lo >= total_w)
return; // this CU has no work
for(index_t ibatch = kargs.cu_start_ibatch_ptr[cu_id]; ibatch < nbatch;
++ibatch)
// w_chunk tracks the global K-chunk position; the loop exits when w_chunk
// reaches w_hi (this CU's exclusive upper bound). ibatch < batch is guaranteed
// on entry; the check inside guards against the rare case where head_start
// reaches nhead_q and ibatch is incremented past batch before w_chunk catches
// up.
do
{
const index_t pb = kargs.prefix_batch_ptr[ibatch];
if(pb >= w_hi)
break; // all remaining batches are past this CU's interval
if(ibatch >= kargs.batch)
return;
// sq/nc/nsplits are read inline (not pre-hoisted) to shorten their live
// range across the inlined run_() body, reducing SGPR pressure.
const FmhaBwdBatchState* bs = kargs.batch_state_ptr + ibatch;
// per-batch seqlen: prefer seqlen_ptr if available, else diff seqstart
const index_t sq = amd_wave_read_first_lane(
kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[ibatch]
: (kargs.seqstart_q_ptr[ibatch + 1] -
kargs.seqstart_q_ptr[ibatch]));
const index_t sk = amd_wave_read_first_lane(
kargs.seqlen_k_ptr ? kargs.seqlen_k_ptr[ibatch]
: (kargs.seqstart_k_ptr[ibatch + 1] -
kargs.seqstart_k_ptr[ibatch]));
const index_t nc = integer_divide_ceil(sk, FmhaPipeline::kN0);
const index_t hw = nc * sq; // workload per (batch, head) pair
if(hw == 0)
continue;
const index_t nsplits_b =
amd_wave_read_first_lane(kargs.nsplits_ptr[ibatch]);
// first head whose interval overlaps [w_lo, w_hi)
const index_t head_start = max(static_cast<index_t>((w_lo - pb) / hw), 0);
for(index_t head_idx = head_start; head_idx < kargs.nhead_q; ++head_idx)
while(head_start < kargs.nhead_q)
{
const index_t w_head = pb + head_idx * hw;
if(w_head >= w_hi)
return; // remaining heads are past the interval
// wc_start: workload offset of this CU's start relative to head start.
// Used for both isplit and c_start.
const index_t wc_start = max(static_cast<index_t>(w_lo - w_head), 0);
// isplit = rank of this CU among all CUs touching this head
const index_t isplit = integer_divide_ceil(wc_start, target_w);
const index_t c_start =
wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
const index_t c_end = integer_divide_ceil(min(hw, w_hi - w_head), sq);
for(index_t chunk_idx = c_start; chunk_idx < c_end; ++chunk_idx)
run_(kargs, dim3(chunk_idx, head_idx, ibatch), isplit, nsplits_b);
while(c_start_0 < amd_wave_read_first_lane(bs->nc) && w_chunk < w_hi)
{
run_(kargs,
dim3(tile_n_interleave(c_start_0,
amd_wave_read_first_lane(bs->nc)),
head_start,
ibatch),
isplit,
amd_wave_read_first_lane(bs->nsplits));
w_chunk += amd_wave_read_first_lane(bs->sq);
++c_start_0;
}
if(w_chunk >= w_hi)
return;
// w_chunk is now at the start of the next head
c_start_0 = 0;
isplit = 0;
++head_start;
}
}
head_start = 0;
++ibatch;
} while(w_chunk < w_hi);
}
}
}
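For reference, a host-side sketch (hypothetical, not part of the commit) that replays the GPU do/while walk over the cu_state/batch_state tables and asserts every (batch, head, chunk) triple is emitted exactly once. Field names mirror the structs above; the checker itself, its struct definitions, and check_schedule are made up for illustration. Chunk indices are counted pre-interleave; tile_n_interleave is a bijection over [0, nc), so coverage is unchanged.

#include <algorithm>
#include <cassert>
#include <vector>

struct CuState    { int w_lo, w_hi, ibatch, isplit, head_start, c_start; };
struct BatchState { int sq, nc, nsplits; };

inline void check_schedule(const std::vector<CuState>& cus,
                           const std::vector<BatchState>& bss,
                           int batch, int nhead)
{
    int max_nc = 0;
    for(const BatchState& b : bss)
        max_nc = std::max(max_nc, b.nc);
    std::vector<int> visits(static_cast<size_t>(batch) * nhead * max_nc, 0);

    for(const CuState& cs : cus)
    {
        int ibatch = cs.ibatch, head = cs.head_start, chunk = cs.c_start;
        int w      = cs.w_lo;
        if(ibatch >= batch)
            continue; // sentinel: this CU has no work
        do
        {
            if(ibatch >= batch)
                break;
            while(head < nhead)
            {
                while(chunk < bss[ibatch].nc && w < cs.w_hi)
                {
                    ++visits[(ibatch * nhead + head) * max_nc + chunk];
                    w += bss[ibatch].sq; // advance to the next K-chunk's start
                    ++chunk;
                }
                if(w >= cs.w_hi)
                    break; // this CU's slice is exhausted
                chunk = 0;
                ++head;
            }
            head = 0;
            ++ibatch;
        } while(w < cs.w_hi);
    }
    for(int b = 0; b < batch; ++b)
        for(int h = 0; h < nhead; ++h)
            for(int c = 0; c < bss[b].nc; ++c)
                assert(visits[(b * nhead + h) * max_nc + c] == 1);
}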