From 1a9404ac96ca63ca85c6efc76c2bd7f690a24072 Mon Sep 17 00:00:00 2001
From: "Ding, Yi"
Date: Wed, 8 Apr 2026 03:38:17 -0500
Subject: [PATCH] [CK_TILE] Use Persistent Scheduling for FMHA BWD Group
 Deterministic

---
 .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 277 ++++++++++++++----
 1 file changed, 219 insertions(+), 58 deletions(-)

diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
index 3c33caf338..5c0763e348 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
@@ -39,6 +39,12 @@ struct FmhaBwdWorkspaceManager
     // long_index_t dq_acc_offsets[batch]
     //   — per-batch offset array
 
+    // [OPTIONAL, only for deterministic group mode persistent]
+    // index_t prefix_batch[batch+1]
+    //   — prefix sum of nhead * num_chunks[b] * seqlen_q[b] across batches
+    // index_t cu_start_ibatch[num_cus]
+    //   — first batch index that overlaps each CU's workload interval
+
     // GPU WORKSPACE BELOW (read & written by kernels):
 
     // [OPTIONAL, only for !kUseQrQtrDorPipeline]
@@ -51,7 +57,7 @@ struct FmhaBwdWorkspaceManager
     static constexpr size_t ALIGNMENT = 16;
 
     template <bool kUseQrQtrDorPipeline>
-    CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsSize(const int batch)
+    CK_TILE_HOST static size_t GetDqAccSplitsSize(const int batch)
    {
         if constexpr(kUseQrQtrDorPipeline)
             return 0;
@@ -59,28 +65,50 @@ struct FmhaBwdWorkspaceManager
             (kIsGroupMode && kIsDeterministic) ? static_cast<size_t>(batch) : 1;
         return integer_least_multiple(sizeof(index_t) * dqAccSplitsElems, ALIGNMENT);
     }
-    CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsSize(const int batch)
+    CK_TILE_HOST static size_t GetDqAccOffsetsSize(const int batch)
     {
         const auto dqAccOffsetsElems =
             (kIsGroupMode && kIsDeterministic) ? static_cast<size_t>(batch) : 0;
         return integer_least_multiple(sizeof(long_index_t) * dqAccOffsetsElems, ALIGNMENT);
     }
+    CK_TILE_HOST static size_t GetPrefixBatchSize(const int batch)
+    {
+        if constexpr(kIsGroupMode && kIsDeterministic)
+            return integer_least_multiple(sizeof(index_t) * (batch + 1), ALIGNMENT);
+        return 0;
+    }
+    CK_TILE_HOST static size_t GetCuStartIbatchSize(const int num_cus)
+    {
+        if constexpr(kIsGroupMode && kIsDeterministic)
+            return integer_least_multiple(sizeof(index_t) * num_cus, ALIGNMENT);
+        return 0;
+    }
+
     template <bool kUseQrQtrDorPipeline>
-    CK_TILE_HOST_DEVICE static size_t GetWorkspaceHostSize(const int batch)
+    CK_TILE_HOST static size_t GetWorkspaceHostSize(const int batch)
     {
         if constexpr(kUseQrQtrDorPipeline)
             return 0;
-        return GetDqAccSplitsSize(batch) + GetDqAccOffsetsSize(batch);
+        return GetDqAccSplitsSize(batch) + GetDqAccOffsetsSize(batch) +
+               GetPrefixBatchSize(batch) + GetCuStartIbatchSize(get_num_cus());
     }
-    CK_TILE_HOST_DEVICE static size_t GetDqAccSplitsOffset(const int) { return 0; }
+    CK_TILE_HOST static size_t GetDqAccSplitsOffset(const int) { return 0; }
     template <bool kUseQrQtrDorPipeline>
-    CK_TILE_HOST_DEVICE static size_t GetDqAccOffsetsOffset(const int batch)
+    CK_TILE_HOST static size_t GetDqAccOffsetsOffset(const int batch)
     {
         return GetDqAccSplitsSize(batch);
     }
+    CK_TILE_HOST static size_t GetPrefixBatchOffset(const int batch)
+    {
+        return GetDqAccSplitsSize(batch) + GetDqAccOffsetsSize(batch);
+    }
+    CK_TILE_HOST static size_t GetCuStartIbatchOffset(const int batch)
+    {
+        return GetPrefixBatchOffset(batch) + GetPrefixBatchSize(batch);
+    }
     template <bool kUseQrQtrDorPipeline>
-    CK_TILE_HOST_DEVICE static size_t GetDqAccDataOffset(const int batch)
+    CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
     {
         return GetWorkspaceHostSize(batch);
     }
@@ -121,21 +149,64 @@ struct FmhaBwdWorkspaceManager
                        seqstart_qs[batch_size] * hdim_q;
         }
         else if constexpr(kIsGroupMode)
-        { // deterministic group mode
+        { // deterministic group mode (persistent)
+            // Step 1: compute prefix_batch and target_w using per-batch seqlens
+            const index_t num_cus = get_num_cus();
+            auto* prefix_batch    = reinterpret_cast<index_t*>(
+                reinterpret_cast<char*>(cpu_ws) + GetDqAccSplitsSize(batch_size) +
+                GetDqAccOffsetsSize(batch_size));
+            auto* cu_start_ibatch = reinterpret_cast<index_t*>(
+                reinterpret_cast<char*>(cpu_ws) + GetDqAccSplitsSize(batch_size) +
+                GetDqAccOffsetsSize(batch_size) + GetPrefixBatchSize(batch_size));
+
+            prefix_batch[0] = 0;
+            for(index_t b = 0; b < batch_size; ++b)
+            {
+                const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
+                const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
+                prefix_batch[b + 1] = prefix_batch[b] + nhead_q * nc * sq;
+            }
+            const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);
+
+            // Step 2: nsplits[b] = max number of CUs that may write one head of
+            // batch b (the dq_acc split-dimension extent): one for the CU that
+            // starts inside the head, plus enough CUs to cover the remaining
+            // (nc - 1) * sq rows of workload
+            for(index_t b = 0; b < batch_size; ++b)
+            {
+                const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
+                const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
+                const index_t rest_workload = (nc > 0) ? (nc - 1) * sq : 0;
+                nsplits[b] = 1 + (rest_workload > 0 && target_w > 0
+                                      ? integer_divide_ceil(rest_workload, target_w)
+                                      : 0);
+            }
+
+            // Step 3: compute per-batch dq_acc offsets (compact layout, depends on nsplits)
             offsets[0] = 0;
             index_t i  = 0;
             for(; i < batch_size - 1; ++i)
             {
-                nsplits[i] = integer_divide_ceil(seqstart_ks[i + 1] - seqstart_ks[i], kN0);
                 offsets[i + 1] = offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
                                                   (seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
             }
-            nsplits[i] = integer_divide_ceil(seqstart_ks[i + 1] - seqstart_ks[i], kN0);
-            return sizeof(AccDataType) *
-                   (offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
-                                     (seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q);
+            const long_index_t dq_acc_elems =
+                offsets[i] + static_cast<long_index_t>(nhead_q) * nsplits[i] *
+                                 (seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
+
+            // Step 4: fill cu_start_ibatch via two-pointer scan O(batch + num_cus)
+            index_t cu_lo = 0;
+            for(index_t b = 0; b < batch_size; ++b)
+            {
+                const index_t cu_hi =
+                    min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
+                for(index_t c = cu_lo; c < cu_hi; ++c)
+                    cu_start_ibatch[c] = b;
+                cu_lo = cu_hi;
+            }
+            for(index_t c = cu_lo; c < num_cus; ++c)
+                cu_start_ibatch[c] = batch_size; // sentinel: this CU has no work
+
+            return sizeof(AccDataType) * dq_acc_elems;
         }
-        else // deterministic non-group mode (kUsePersistent)
+        else // deterministic batch mode (kUsePersistent)
         {
             const index_t dqdqkdv_workers = get_num_cus();
             const index_t jobs_per_head   = integer_divide_ceil(seqlen_k, kN0);
@@ -159,13 +230,14 @@ struct FmhaBwdWorkspaceManager
                                           size_t host_ws_size)
     {
         constexpr bool NeedsZeroDqAcc = []() {
-            constexpr bool kUsePersistent =
-                !kUseQrQtrDorPipeline && kIsDeterministic && !kIsGroupMode;
-            // non-deterministic and persistent kernels use atomic-add to write dq
+            constexpr bool kUsePersistent = !kUseQrQtrDorPipeline && kIsDeterministic;
+            // Persistent (batch and group): uses atomic_add → buffer must start at zero
+            //   so that accumulated dq values are correct.
+            // Non-deterministic: uses atomic_add → buffer must start at zero.
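+            //
+            // Illustrative summary of the cases handled here (write op per the
+            // DstInMemOp selection further below):
+            //   pipeline       deterministic   dq_acc write   zero-init needed?
+            //   QrQtrDor       yes             set            only if kHasMask
+            //   any            no              atomic_add     yes
+            //   non-QrQtrDor   yes             atomic_add     yes (persistent)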
             if constexpr(kUsePersistent || !kIsDeterministic)
                 return true;
-            // Some block may be skipped with causal mask and dq are not set to zeros
-            // In these cases we need to zero out it first
+            // Non-persistent deterministic: uses set, but causal mask may skip some tiles
+            // leaving dq_acc slots unwritten — zero them out first.
             return kHasMask;
         }();
         if(host_ws_size > 0)
@@ -228,8 +300,7 @@ struct FmhaBwdDQDKDVKernel
 #else
     static constexpr bool kIsAvailable = !kUseTrLoad;
 #endif
-    static constexpr bool kUsePersistent =
-        kIsDeterministic && !kIsGroupMode && !kUseQrQtrDorPipeline;
+    static constexpr bool kUsePersistent = kIsDeterministic && !kUseQrQtrDorPipeline;
     using WorkspaceManager = FmhaBwdWorkspaceManager;
 
     // clang-format off
@@ -452,7 +523,10 @@ struct FmhaBwdDQDKDVKernel
     struct FmhaBwdDeterministicKargs
     {
         ck_tile::index_t batch; // used for persistent kernel implementation
-        const ck_tile::index_t* nsplits_ptr; // points to nsplits[0] in workspace (batch mode)
+        const ck_tile::index_t* nsplits_ptr; // per-batch nsplits (group) or single scalar (batch)
+        // group mode persistent scheduling tables (filled on the host side of the
+        // workspace, read by the kernel):
+        const ck_tile::index_t* prefix_batch_ptr;    // prefix sum of nhead*hw[b], size [batch+1]
+        const ck_tile::index_t* cu_start_ibatch_ptr; // first batch for each CU, size [num_cus]
     };
 
     struct FmhaBwdBatchModeKargs
@@ -872,7 +946,15 @@ struct FmhaBwdDQDKDVKernel
         }
 
         if constexpr(kUsePersistent)
-            kargs.batch = batch;
+        {
+            kargs.batch       = batch;
+            kargs.nsplits_ptr = reinterpret_cast<const ck_tile::index_t*>(
+                ws + WorkspaceManager::GetDqAccSplitsOffset(batch));
+            kargs.prefix_batch_ptr = reinterpret_cast<const ck_tile::index_t*>(
+                ws + WorkspaceManager::GetPrefixBatchOffset(batch));
+            kargs.cu_start_ibatch_ptr = reinterpret_cast<const ck_tile::index_t*>(
+                ws + WorkspaceManager::GetCuStartIbatchOffset(batch));
+        }
 
         return kargs;
     }
@@ -928,43 +1010,113 @@ struct FmhaBwdDQDKDVKernel
     {
         static_assert(!kUseQrQtrDorPipeline,
                       "Persistent kernel is not compatible with QR/QTR/DOR pipeline");
-        const index_t worker_id  = blockIdx.x;
-        const index_t worker_num = gridDim.x;
+        if constexpr(!kIsGroupMode)
+        {
+            // Batch mode persistent: uniform seqlen_k across all batches
+            const index_t worker_id  = blockIdx.x;
+            const index_t worker_num = gridDim.x;
 
-        const index_t jobs_per_head =
-            integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
-        const index_t total_heads     = kargs.batch * kargs.nhead_q;
-        const index_t total_jobs      = jobs_per_head * total_heads;
-        const index_t jobs_per_worker = integer_divide_ceil(total_jobs, worker_num);
+            const index_t jobs_per_head =
+                integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
+            const index_t total_heads     = kargs.batch * kargs.nhead_q;
+            const index_t total_jobs      = jobs_per_head * total_heads;
+            const index_t jobs_per_worker = integer_divide_ceil(total_jobs, worker_num);
 
-        const index_t begin_job_id = worker_id * jobs_per_worker;
-        if(begin_job_id >= total_jobs)
-            return; // worker_id exceeds total jobs, exit early
-        const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
+            const index_t begin_job_id = worker_id * jobs_per_worker;
+            if(begin_job_id >= total_jobs)
+                return; // worker_id exceeds total jobs, exit early
+            const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
 
-        // 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
-        constexpr auto tile_n_interleave = [](index_t x, index_t n) {
-            if constexpr(kHasMask == false)
-                return x;
-            else
-                return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
-        };
+            // 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
+            constexpr auto tile_n_interleave = [](index_t x, index_t n) {
+                if constexpr(kHasMask == false)
+                    return x;
+                else
+                    return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
+            };
 
-        const auto n_splits = kargs.nsplits_ptr[0];
-        index_t job_id      = begin_job_id;
-        index_t i_split     = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
-        do
-        { // loop over jobs assigned to this worker
-            const index_t i_head_flatten = job_id / jobs_per_head;
-            const index_t i_tile_n_      = job_id % jobs_per_head;
-            const index_t i_tile_n       = tile_n_interleave(i_tile_n_, jobs_per_head);
-            const index_t i_batch        = i_head_flatten / kargs.nhead_q;
-            const index_t i_nhead        = i_head_flatten % kargs.nhead_q;
+            const auto n_splits = kargs.nsplits_ptr[0];
+            index_t job_id      = begin_job_id;
+            index_t i_split     = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
+            do
+            { // loop over jobs assigned to this worker
+                const index_t i_head_flatten = job_id / jobs_per_head;
+                const index_t i_tile_n_      = job_id % jobs_per_head;
+                const index_t i_tile_n       = tile_n_interleave(i_tile_n_, jobs_per_head);
+                const index_t i_batch        = i_head_flatten / kargs.nhead_q;
+                const index_t i_nhead        = i_head_flatten % kargs.nhead_q;
 
-            if(i_tile_n_ == 0) // reset dq_acc writing idx when starting a new head
-                i_split = 0;
-            run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split, n_splits);
-        } while(++job_id < end_job_id);
+                if(i_tile_n_ == 0) // reset dq_acc writing idx when starting a new head
+                    i_split = 0;
+                run_(kargs, dim3(i_tile_n, i_nhead, i_batch), i_split, n_splits);
+            } while(++job_id < end_job_id);
+        }
+        else
+        {
+            // Group mode persistent: variable seqlen per batch. The flattened
+            // workload is statically split into near-equal intervals, one per CU;
+            // each CU locates its own interval independently via prefix_batch.
+            const index_t cu_id  = blockIdx.x;
+            const index_t num_cu = gridDim.x;
+            const index_t nbatch = kargs.batch;
+
+            // prefix_batch[nbatch] = total workload (nhead * sum(nc[b] * sq[b]))
+            const index_t total_w = kargs.prefix_batch_ptr[nbatch];
+            if(total_w == 0)
+                return;
+            const index_t target_w = integer_divide_ceil(total_w, num_cu);
+
+            const index_t w_lo = amd_wave_read_first_lane(cu_id * target_w);
+            const index_t w_hi = amd_wave_read_first_lane(
+                min(static_cast<index_t>((cu_id + 1) * target_w), total_w));
+            if(w_lo >= total_w)
+                return; // this CU has no work
+
+            for(index_t ibatch = kargs.cu_start_ibatch_ptr[cu_id]; ibatch < nbatch;
+                ++ibatch)
+            {
+                const index_t pb = kargs.prefix_batch_ptr[ibatch];
+                if(pb >= w_hi)
+                    break; // all remaining batches are past this CU's interval
+
+                // per-batch seqlen: prefer seqlen_ptr if available, else diff seqstart
+                const index_t sq = amd_wave_read_first_lane(
+                    kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[ibatch]
                                       : (kargs.seqstart_q_ptr[ibatch + 1] -
+                                          kargs.seqstart_q_ptr[ibatch]));
+                const index_t sk = amd_wave_read_first_lane(
+                    kargs.seqlen_k_ptr ? kargs.seqlen_k_ptr[ibatch]
+                                       : (kargs.seqstart_k_ptr[ibatch + 1] -
+                                          kargs.seqstart_k_ptr[ibatch]));
+                const index_t nc = integer_divide_ceil(sk, FmhaPipeline::kN0);
+                const index_t hw = nc * sq; // workload per (batch, head) pair
+                if(hw == 0)
+                    continue;
+                const index_t nsplits_b =
+                    amd_wave_read_first_lane(kargs.nsplits_ptr[ibatch]);
+
+                // first head whose interval overlaps [w_lo, w_hi)
+                const index_t head_start = max(static_cast<index_t>((w_lo - pb) / hw), 0);
+
+                for(index_t head_idx = head_start; head_idx < kargs.nhead_q; ++head_idx)
+                {
+                    const index_t w_head = pb + head_idx * hw;
+                    if(w_head >= w_hi)
+                        return; // remaining heads are past the interval
+
+                    // wc_start: workload offset of this CU's start relative to head start.
+                    // Used for both isplit and c_start.
+                    const index_t wc_start = max(static_cast<index_t>(w_lo - w_head), 0);
+                    // isplit = rank of this CU among all CUs touching this head
+                    const index_t isplit = integer_divide_ceil(wc_start, target_w);
+                    const index_t c_start =
+                        wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
+                    const index_t c_end = integer_divide_ceil(min(hw, w_hi - w_head), sq);
+
+                    for(index_t chunk_idx = c_start; chunk_idx < c_end; ++chunk_idx)
+                        run_(kargs, dim3(chunk_idx, head_idx, ibatch), isplit, nsplits_b);
+                }
+            }
+        }
         }
     }
 }
@@ -1229,14 +1381,23 @@ struct FmhaBwdDQDKDVKernel
             const long_index_t split_stride = kargs.seqlen_q * kargs.hdim_q;
             const auto nsplits              = [&]() {
                 if constexpr(!kIsGroupMode)
-                    return n_splits;
+                    return n_splits; // batch persistent: passed from nsplits_ptr[0]
+                else if constexpr(kUsePersistent)
+                    return n_splits; // group persistent: passed from nsplits_ptr[ibatch]
                 else
-                    return integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0);
+                    return integer_divide_ceil(kargs.seqlen_k,
+                                               FmhaPipeline::kN0); // group non-persistent
            }();
             return batch_offset_dq_acc + (i_nhead_ * nsplits + i_split) * split_stride;
         }
     }();
+    // kUseKSplit && !kUsePersistent is true only for QrQtrDor+deterministic,
+    // which writes dq directly (not through dq_acc splits) — use 'set'.
+    // All other deterministic paths are persistent and use 'atomic_add':
+    // a single CU may process multiple chunks of the same (batch, head, isplit)
+    // sequentially, so contributions must accumulate rather than overwrite.
+    // Non-deterministic paths also use 'atomic_add' (kUseKSplit=false).
     constexpr auto DstInMemOp = conditional_expr<(kUseKSplit && !kUsePersistent)>(
         memory_operation_enum::set, memory_operation_enum::atomic_add);
     const index_t stride_dq_acc = [&]() {
@@ -1879,7 +2040,7 @@ struct FmhaBwdConvertQGradKernel
     static constexpr bool kPadSeqLenQ      = FmhaBwdConvertQGrad::kPadSeqLenQ;
     static constexpr bool kPadHeadDimQ     = FmhaBwdConvertQGrad::kPadHeadDimQ;
     static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
-    static constexpr bool kUsePersistent   = kIsDeterministic && !kIsGroupMode;
+    static constexpr bool kUsePersistent   = kIsDeterministic;
     using WorkspaceManager = FmhaBwdWorkspaceManager;
 
     // clang-format off
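
Scheduling sketch (illustrative only, not part of the patch): the host-side
tables and the per-CU dispatch above can be exercised in isolation. The
standalone mock below mirrors the patch's logic with plain integers. `kN0`,
`target_w`, the prefix-sum table, and the CU-start table follow the patch;
the toy sizes, `ceil_div`, and `main` are assumptions made only for this
example.

// Standalone mock of the persistent group-mode schedule (illustrative only).
#include <algorithm>
#include <cstdio>
#include <vector>

using index_t = int;
static constexpr index_t kN0 = 128; // tile size along seqlen_k, as in the patch

static index_t ceil_div(index_t a, index_t b) { return (a + b - 1) / b; }

int main()
{
    // Toy group-mode problem: 3 batches with ragged seqlens, 2 heads, 4 CUs.
    const std::vector<index_t> seqlen_q = {64, 256, 32};
    const std::vector<index_t> seqlen_k = {128, 512, 64};
    const index_t nhead = 2, num_cus = 4, nbatch = 3;

    // Host side: prefix[b] = total workload before batch b, where one unit of
    // workload is one q-row of one k-chunk of one head.
    std::vector<index_t> prefix(nbatch + 1, 0);
    for(index_t b = 0; b < nbatch; ++b)
        prefix[b + 1] = prefix[b] + nhead * ceil_div(seqlen_k[b], kN0) * seqlen_q[b];
    const index_t target_w = ceil_div(prefix[nbatch], num_cus);

    // Host side: first batch overlapping each CU's interval (two-pointer scan),
    // with nbatch as the "no work" sentinel, as in the patch.
    std::vector<index_t> cu_start(num_cus, nbatch);
    for(index_t b = 0, lo = 0; b < nbatch; ++b)
    {
        const index_t hi = std::min(num_cus, ceil_div(prefix[b + 1], target_w));
        for(index_t c = lo; c < hi; ++c)
            cu_start[c] = b;
        lo = hi;
    }

    // "Device" side: each CU claims [w_lo, w_hi) and walks batches/heads/chunks.
    for(index_t cu = 0; cu < num_cus; ++cu)
    {
        const index_t w_lo = cu * target_w;
        const index_t w_hi = std::min((cu + 1) * target_w, prefix[nbatch]);
        if(w_lo >= prefix[nbatch])
            continue;
        for(index_t b = cu_start[cu]; b < nbatch; ++b)
        {
            const index_t pb = prefix[b];
            if(pb >= w_hi)
                break; // remaining batches are past this CU's interval
            const index_t sq = seqlen_q[b];
            const index_t nc = ceil_div(seqlen_k[b], kN0);
            const index_t hw = nc * sq; // workload per (batch, head) pair
            if(hw == 0)
                continue;
            const index_t head_start = std::max((w_lo - pb) / hw, 0);
            for(index_t h = head_start; h < nhead; ++h)
            {
                const index_t w_head = pb + h * hw;
                if(w_head >= w_hi)
                    break;
                const index_t wc     = std::max(w_lo - w_head, 0);
                const index_t isplit = ceil_div(wc, target_w);
                const index_t c0     = wc > 0 ? ceil_div(wc, sq) : 0;
                const index_t c1     = ceil_div(std::min(hw, w_hi - w_head), sq);
                for(index_t c = c0; c < c1; ++c)
                    std::printf("cu=%d batch=%d head=%d chunk=%d isplit=%d\n",
                                cu, b, h, c, isplit);
            }
        }
    }
    return 0;
}

For the toy sizes above, head (batch=1, head=0) spans two CU intervals: CU 0
writes chunks 0-1 with isplit=0 and CU 1 writes chunks 2-3 with isplit=1, so
distinct CUs land in distinct dq_acc split slots and the convert kernel can
later reduce the per-head splits deterministically.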