diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp
index 8871167c38..ce3a5491ff 100644
--- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp
+++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp
@@ -9,6 +9,7 @@
 #include "ck_tile/utility/json_dump.hpp"
 
 #include
+#include <chrono>
 #include
 #include
 #include
@@ -243,30 +244,6 @@ bwd_result fmha_bwd_run(mode_enum mode,
     const ck_tile::index_t shape_seqlen_k =
         (mode == mode_enum::batch ? seqlen_ks[0] : seqstart_k_host.back());
 
-    const fmha_bwd_traits fmha_traits{
-        shape_seqlen_q,
-        shape_seqlen_k,
-        batch,
-        max_seqlen_q,
-        max_seqlen_k,
-        hdim_q,
-        hdim_v,
-        nhead,
-        nhead_k,
-        data_type,
-        mode == mode_enum::group,
-        mask.type,
-        bias.type,
-        use_dbias,
-        p_drop > 0.0f,
-        s_randval,
-        deterministic,
-        (mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
-        (mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
-    };
-    fmha_bwd_launcher launcher(fmha_traits);
-    const size_t ws_size = launcher.workspace_size;
-
     ck_tile::HostTensor<QDataType> q_host(
         get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
     ck_tile::HostTensor<KDataType> k_host(
@@ -395,8 +372,37 @@ bwd_result fmha_bwd_run(mode_enum mode,
     ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
     ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
+    const auto t0_launcher = std::chrono::high_resolution_clock::now();
+    fmha_bwd_launcher launcher(fmha_bwd_traits{
+        shape_seqlen_q,
+        shape_seqlen_k,
+        batch,
+        max_seqlen_q,
+        max_seqlen_k,
+        hdim_q,
+        hdim_v,
+        nhead,
+        nhead_k,
+        data_type,
+        mode == mode_enum::group,
+        mask.type,
+        bias.type,
+        use_dbias,
+        p_drop > 0.0f,
+        s_randval,
+        deterministic,
+        (mode == mode_enum::group) ? seqstart_q_host.data() : nullptr,
+        (mode == mode_enum::group) ? seqstart_k_host.data() : nullptr,
+    });
+    const auto t1_launcher = std::chrono::high_resolution_clock::now();
+    const double launcher_ctor_ms =
+        std::chrono::duration<double, std::milli>(t1_launcher - t0_launcher).count();
+    const size_t ws_size = launcher.workspace_size;
     ck_tile::DeviceMem ws_buf(ws_size);
+    ck_tile::gpu_timer prepare_ws_timer;
+    prepare_ws_timer.start(nullptr);
     launcher.prepare_workspace(ws_buf.GetDeviceBuffer());
+    prepare_ws_timer.stop(nullptr);
 
     q_buf.ToDevice(q_host.data());
     k_buf.ToDevice(k_host.data());
@@ -442,7 +448,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
               << (sink_grad ? ", sink:(rand[30,60], grad)" : "")
               << ", s_randval:" << s_randval << ", deterministic:" << deterministic
               << ", workspace:" << std::to_string(workspace_size_in_megabytes) << "MiB"
-              << ", mask:" << mask << std::flush;
+              << ", mask:" << mask
+              << ", init:" << launcher_ctor_ms << "ms"
+              << ", prws:" << prepare_ws_timer.duration() << "ms" << std::flush;
 
     auto fmha_args = [&]() {
         /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
index 5c0763e348..c69e52b0af 100644
--- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
@@ -27,6 +27,31 @@
 
 namespace ck_tile {
 
+// Per-CU state for group-mode deterministic persistent scheduling.
+// Packed into a single array to reduce kargs pointer count (5 pointers → 1).
+// alignas(16): enables aligned 128-bit loads; sizeof == 32 (6×4 + 8 pad).
+struct alignas(16) FmhaBwdGroupPersistentCuState
+{
+    index_t w_lo;       // global position of this CU's first K-chunk (= pb + head*hw + c*sq)
+    index_t w_hi;       // global position of next CU's first K-chunk (exclusive upper bound)
+    index_t ibatch;     // first batch this CU touches (batch_size = no work sentinel)
+    index_t isplit;     // isplit for the first (batch, head) this CU touches
+    index_t head_start; // head index for the first batch this CU touches
+    index_t c_start;    // chunk index for the first (batch, head) this CU touches
+    // 8 bytes implicit padding
+};
+
+// Per-batch precomputed values used in the group-mode persistent dispatch loop.
+// Avoids per-iteration reads from seqstart_q/k_ptr and nsplits_ptr.
+// alignas(16): sizeof == 16 (3×4 + 4 pad), fits in a single 128-bit load.
+struct alignas(16) FmhaBwdBatchState
+{
+    index_t sq;      // seqlen_q for this batch (seqstart_q[b+1] - seqstart_q[b])
+    index_t nc;      // number of K-chunks: ceil(seqlen_k / kN0)
+    index_t nsplits; // dq_acc split count for this batch
+    // 4 bytes implicit padding
+};
+
 template
 struct FmhaBwdWorkspaceManager
 {
@@ -77,10 +102,20 @@ struct FmhaBwdWorkspaceManager
             return integer_least_multiple(sizeof(index_t) * (batch + 1), ALIGNMENT);
         return 0;
     }
-    CK_TILE_HOST static size_t GetCuStartIbatchSize(const int num_cus)
+    // cu_state[num_cus]: per-CU persistent state packed into one array (group det only).
+    // Replaces separate cu_start_ibatch / cu_wlo / cu_isplit / cu_head_start / cu_c_start arrays.
+    CK_TILE_HOST static size_t GetCuStateSize(const int num_cus)
     {
         if constexpr(kIsGroupMode && kIsDeterministic)
-            return integer_least_multiple(sizeof(index_t) * num_cus, ALIGNMENT);
+            return integer_least_multiple(sizeof(FmhaBwdGroupPersistentCuState) * num_cus,
+                                          ALIGNMENT);
        return 0;
     }
+    // batch_state[batch]: per-batch sq/nc/nsplits for group det dispatch loop.
+    CK_TILE_HOST static size_t GetBatchStateSize(const int batch)
+    {
+        if constexpr(kIsGroupMode && kIsDeterministic)
+            return integer_least_multiple(sizeof(FmhaBwdBatchState) * batch, ALIGNMENT);
+        return 0;
+    }
 
@@ -89,8 +124,11 @@
     {
         if constexpr(kUseQrQtrDorPipeline)
             return 0;
-        return GetDqAccSplitsSize(batch) + GetDqAccOffsetsSize(batch) +
-               GetPrefixBatchSize(batch) + GetCuStartIbatchSize(get_num_cus());
+        const size_t raw = GetDqAccSplitsSize(batch) +
+                           GetDqAccOffsetsSize(batch) + GetPrefixBatchSize(batch) +
+                           GetCuStateSize(get_num_cus()) + GetBatchStateSize(batch);
+        // Pad to 4K so dq_acc buffer always starts on a page-aligned boundary.
+        return integer_least_multiple(raw, static_cast<size_t>(4096));
     }
 
     CK_TILE_HOST static size_t GetDqAccSplitsOffset(const int) { return 0; }
@@ -103,10 +141,14 @@
     {
         return GetDqAccSplitsSize(batch) + GetDqAccOffsetsSize(batch);
     }
-    CK_TILE_HOST static size_t GetCuStartIbatchOffset(const int batch)
+    CK_TILE_HOST static size_t GetCuStateOffset(const int batch)
     {
         return GetPrefixBatchOffset(batch) + GetPrefixBatchSize(batch);
     }
+    CK_TILE_HOST static size_t GetBatchStateOffset(const int batch)
+    {
+        return GetCuStateOffset(batch) + GetCuStateSize(get_num_cus());
+    }
     template
     CK_TILE_HOST static size_t GetDqAccDataOffset(const int batch)
     {
@@ -155,9 +197,10 @@ struct FmhaBwdWorkspaceManager
         auto* prefix_batch = reinterpret_cast<index_t*>(
             reinterpret_cast<char*>(cpu_ws) + GetDqAccSplitsSize(batch_size) +
             GetDqAccOffsetsSize(batch_size));
-        auto* cu_start_ibatch = reinterpret_cast<index_t*>(
-            reinterpret_cast<char*>(cpu_ws) + GetDqAccSplitsSize(batch_size) +
-            GetDqAccOffsetsSize(batch_size) + GetPrefixBatchSize(batch_size));
+        auto* cu_states = reinterpret_cast<FmhaBwdGroupPersistentCuState*>(
+            reinterpret_cast<char*>(cpu_ws) + GetCuStateOffset(batch_size));
+        auto* batch_states = reinterpret_cast<FmhaBwdBatchState*>(
+            reinterpret_cast<char*>(cpu_ws) + GetBatchStateOffset(batch_size));
 
         prefix_batch[0] = 0;
         for(index_t b = 0; b < batch_size; ++b)
@@ -168,15 +211,19 @@ struct FmhaBwdWorkspaceManager
        }
        const index_t target_w = integer_divide_ceil(prefix_batch[batch_size], num_cus);
 
-        // Step 2: compute nsplits[b] = per_batch_max_cus[b] (for dq_acc split dimension)
+        // Step 2: compute nsplits[b] and fill batch_states[b] (sq, nc, nsplits per batch)
        for(index_t b = 0; b < batch_size; ++b)
        {
            const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
            const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
            const index_t rest_workload = (nc > 0) ? (nc - 1) * sq : 0;
-            nsplits[b] = 1 + (rest_workload > 0 && target_w > 0
+            const index_t ns = 1 + (rest_workload > 0 && target_w > 0
                                   ? integer_divide_ceil(rest_workload, target_w)
                                   : 0);
+            nsplits[b] = ns;
+            batch_states[b].sq      = sq;
+            batch_states[b].nc      = nc;
+            batch_states[b].nsplits = ns;
        }
 
        // Step 3: compute per-batch dq_acc offsets (compact layout, depends on nsplits)
@@ -191,18 +238,65 @@ struct FmhaBwdWorkspaceManager
                offsets[i] + static_cast<size_t>(nhead_q) * nsplits[i] *
                                 (seqstart_qs[i + 1] - seqstart_qs[i]) * hdim_q;
 
-        // Step 4: fill cu_start_ibatch via two-pointer scan O(batch + num_cus)
+        // Step 4: fill cu_states via two-pointer scan.
+        // w_lo = global position of the first K-chunk: pb + head_start*hw + c_start*sq.
+        // This makes w_chunk track true global K-chunk positions on GPU, so w_chunk < w_hi
+        // correctly identifies boundaries without off-by-one overlap between adjacent CUs.
+        // w_hi is set in a post-pass to cu_states[c+1].w_lo.
        index_t cu_lo = 0;
        for(index_t b = 0; b < batch_size; ++b)
        {
+            const index_t sq = seqstart_qs[b + 1] - seqstart_qs[b];
+            const index_t nc = integer_divide_ceil(seqstart_ks[b + 1] - seqstart_ks[b], kN0);
+            const index_t hw = nc * sq;
+            const index_t pb = prefix_batch[b];
            const index_t cu_hi =
                min(num_cus, integer_divide_ceil(prefix_batch[b + 1], target_w));
            for(index_t c = cu_lo; c < cu_hi; ++c)
-                cu_start_ibatch[c] = b;
+            {
+                const index_t w_lo = c * target_w;
+                cu_states[c].ibatch = b;
+                if(hw > 0)
+                {
+                    const index_t head_start =
+                        max(static_cast<index_t>((w_lo - pb) / hw), index_t(0));
+                    const index_t w_head   = pb + head_start * hw;
+                    const index_t wc_start = max(w_lo - w_head, index_t(0));
+                    const index_t c_start =
+                        wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
+                    cu_states[c].isplit =
+                        wc_start > 0 ? integer_divide_ceil(wc_start, target_w) : 0;
+                    cu_states[c].head_start = head_start;
+                    cu_states[c].c_start    = c_start;
+                    // w_lo = true global start of first K-chunk for this CU
+                    cu_states[c].w_lo = pb + head_start * hw + c_start * sq;
+                }
+                else
+                {
+                    cu_states[c].isplit     = 0;
+                    cu_states[c].head_start = 0;
+                    cu_states[c].c_start    = 0;
+                    cu_states[c].w_lo       = pb; // hw==0: degenerate, w_lo = batch start
+                }
+            }
            cu_lo = cu_hi;
        }
+        // Inactive CUs: use total_w as w_lo sentinel so the post-pass sets
+        // the last active CU's w_hi = total_w correctly.
+        const index_t total_w = prefix_batch[batch_size];
        for(index_t c = cu_lo; c < num_cus; ++c)
-            cu_start_ibatch[c] = batch_size; // sentinel: this CU has no work
+        {
+            cu_states[c].w_lo   = total_w;
+            cu_states[c].w_hi   = total_w;
+            cu_states[c].ibatch = batch_size; // sentinel → early return on GPU
+            cu_states[c].isplit     = 0;
+            cu_states[c].head_start = 0;
+            cu_states[c].c_start    = 0;
+        }
+        // Post-pass: set w_hi[c] = w_lo[c+1] (global start of next CU's first K-chunk).
+        for(index_t c = 0; c < num_cus - 1; ++c)
+            cu_states[c].w_hi = cu_states[c + 1].w_lo;
+        cu_states[num_cus - 1].w_hi = total_w;
 
        return sizeof(AccDataType) * dq_acc_elems;
    }
@@ -525,8 +619,9 @@ struct FmhaBwdDQDKDVKernel
        ck_tile::index_t batch; // used for persistent kernel implementation
        const ck_tile::index_t* nsplits_ptr; // per-batch nsplits (group) or single scalar (batch)
        // group mode persistent scheduling tables (read from CPU workspace by GPU):
-        const ck_tile::index_t* prefix_batch_ptr;    // prefix sum of nhead*hw[b], size [batch+1]
-        const ck_tile::index_t* cu_start_ibatch_ptr; // first batch for each CU, size [num_cus]
+        const ck_tile::index_t* prefix_batch_ptr; // prefix sum of nhead*hw[b], size [batch+1]
+        const FmhaBwdGroupPersistentCuState* cu_state_ptr; // per-CU packed state, size [num_cus]
+        const FmhaBwdBatchState* batch_state_ptr; // per-batch sq/nc/nsplits, size [batch]
    };
 
    struct FmhaBwdBatchModeKargs
@@ -952,8 +1047,13 @@ struct FmhaBwdDQDKDVKernel
                ws + WorkspaceManager::GetDqAccSplitsOffset(batch));
            kargs.prefix_batch_ptr = reinterpret_cast<const ck_tile::index_t*>(
                ws + WorkspaceManager::GetPrefixBatchOffset(batch));
-            kargs.cu_start_ibatch_ptr = reinterpret_cast<const ck_tile::index_t*>(
-                ws + WorkspaceManager::GetCuStartIbatchOffset(batch));
+            if constexpr(kIsGroupMode)
+            {
+                kargs.cu_state_ptr = reinterpret_cast<const FmhaBwdGroupPersistentCuState*>(
+                    ws + WorkspaceManager::GetCuStateOffset(batch));
+                kargs.batch_state_ptr = reinterpret_cast<const FmhaBwdBatchState*>(
+                    ws + WorkspaceManager::GetBatchStateOffset(batch));
+            }
        }
 
        return kargs;
@@ -1010,6 +1110,14 @@ struct FmhaBwdDQDKDVKernel
        {
            static_assert(!kUseQrQtrDorPipeline,
                          "Persistent kernel is not compatible with QR/QTR/DOR pipeline");
+
+            // 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
+            constexpr auto tile_n_interleave = [](index_t x, index_t n) {
+                if constexpr(kHasMask == false)
+                    return x;
+                else
+                    return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
+            };
            if constexpr(!kIsGroupMode)
            {
                // Batch mode persistent: uniform seqlen_k across all batches
@@ -1027,14 +1135,6 @@ struct FmhaBwdDQDKDVKernel
                    return; // worker_id exceeds total jobs, exit early
                const index_t end_job_id = min((worker_id + 1) * jobs_per_worker, total_jobs);
 
-                // 0,1,2,3,4,5 ==> 0,5,1,4,2,3 for load balance in triangular mask case
-                constexpr auto tile_n_interleave = [](index_t x, index_t n) {
-                    if constexpr(kHasMask == false)
-                        return x;
-                    else
-                        return x % 2 == 0 ? (x / 2) : (n - 1 - x / 2);
-                };
-
                const auto n_splits = kargs.nsplits_ptr[0];
                index_t job_id  = begin_job_id;
                index_t i_split = integer_divide_ceil(job_id % jobs_per_head, jobs_per_worker);
@@ -1054,68 +1154,60 @@ struct FmhaBwdDQDKDVKernel
            else
            {
                // Group mode persistent: variable seqlen per batch, dispatch via gist algo.
-                // Each CU independently determines its workload interval using prefix_batch.
-                const index_t cu_id  = blockIdx.x;
-                const index_t num_cu = gridDim.x;
-                const index_t nbatch = kargs.batch;
+                // Per-CU state (w_lo, w_hi, ibatch, isplit, head_start, c_start) is packed
+                // in a single struct array to minimise kargs pointer count (5 → 1).
+                // Remap block→CU: interleave SEs so consecutive blocks hit different SEs,
+                // spreading dq_acc writes across HBM channels.
+                const index_t cu_id = blockIdx.x / 8 + (blockIdx.x % 8) * 32;
 
-                // prefix_batch[nbatch] = total workload (nhead * sum(nc[b] * sq[b]))
-                const index_t total_w = kargs.prefix_batch_ptr[nbatch];
-                if(total_w == 0)
-                    return;
-                const index_t target_w = integer_divide_ceil(total_w, num_cu);
+                // Load all per-CU fields through a single pointer; pointer dies after loads.
+                const FmhaBwdGroupPersistentCuState* cs = kargs.cu_state_ptr + cu_id;
+                const index_t w_hi = amd_wave_read_first_lane(cs->w_hi);
+                index_t ibatch     = amd_wave_read_first_lane(cs->ibatch);
+                index_t isplit     = amd_wave_read_first_lane(cs->isplit);
+                index_t head_start = amd_wave_read_first_lane(cs->head_start);
+                index_t c_start_0  = amd_wave_read_first_lane(cs->c_start);
+                index_t w_chunk    = amd_wave_read_first_lane(cs->w_lo);
+                if(ibatch >= kargs.batch)
+                    return; // this CU has no work (sentinel: ibatch == batch_size)
 
-                const index_t w_lo = amd_wave_read_first_lane(cu_id * target_w);
-                const index_t w_hi = amd_wave_read_first_lane(
-                    min(static_cast<index_t>((cu_id + 1) * target_w), total_w));
-                if(w_lo >= total_w)
-                    return; // this CU has no work
-
-                for(index_t ibatch = kargs.cu_start_ibatch_ptr[cu_id]; ibatch < nbatch;
-                    ++ibatch)
+                // w_chunk tracks the global K-chunk position; the loop exits when w_chunk
+                // reaches w_hi (this CU's exclusive upper bound). ibatch < batch is guaranteed
+                // on entry; the check inside guards against the rare case where head_start
+                // reaches nhead_q and ibatch is incremented past batch before w_chunk catches
+                // up.
+                do
                {
-                    const index_t pb = kargs.prefix_batch_ptr[ibatch];
-                    if(pb >= w_hi)
-                        break; // all remaining batches are past this CU's interval
+                    if(ibatch >= kargs.batch)
+                        return;
+                    // sq/nc/nsplits are read inline (not pre-hoisted) to shorten their live
+                    // range across the inlined run_() body, reducing SGPR pressure.
+                    const FmhaBwdBatchState* bs = kargs.batch_state_ptr + ibatch;
 
-                    // per-batch seqlen: prefer seqlen_ptr if available, else diff seqstart
-                    const index_t sq = amd_wave_read_first_lane(
-                        kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[ibatch]
-                                           : (kargs.seqstart_q_ptr[ibatch + 1] -
-                                              kargs.seqstart_q_ptr[ibatch]));
-                    const index_t sk = amd_wave_read_first_lane(
-                        kargs.seqlen_k_ptr ? kargs.seqlen_k_ptr[ibatch]
-                                           : (kargs.seqstart_k_ptr[ibatch + 1] -
-                                              kargs.seqstart_k_ptr[ibatch]));
-                    const index_t nc = integer_divide_ceil(sk, FmhaPipeline::kN0);
-                    const index_t hw = nc * sq; // workload per (batch, head) pair
-                    if(hw == 0)
-                        continue;
-                    const index_t nsplits_b =
-                        amd_wave_read_first_lane(kargs.nsplits_ptr[ibatch]);
-
-                    // first head whose interval overlaps [w_lo, w_hi)
-                    const index_t head_start = max(static_cast<index_t>((w_lo - pb) / hw), 0);
-
-                    for(index_t head_idx = head_start; head_idx < kargs.nhead_q; ++head_idx)
+                    while(head_start < kargs.nhead_q)
                    {
-                        const index_t w_head = pb + head_idx * hw;
-                        if(w_head >= w_hi)
-                            return; // remaining heads are past the interval
-
-                        // wc_start: workload offset of this CU's start relative to head start.
-                        // Used for both isplit and c_start.
-                        const index_t wc_start = max(static_cast<index_t>(w_lo - w_head), 0);
-                        // isplit = rank of this CU among all CUs touching this head
-                        const index_t isplit = integer_divide_ceil(wc_start, target_w);
-                        const index_t c_start =
-                            wc_start > 0 ? integer_divide_ceil(wc_start, sq) : 0;
-                        const index_t c_end = integer_divide_ceil(min(hw, w_hi - w_head), sq);
-
-                        for(index_t chunk_idx = c_start; chunk_idx < c_end; ++chunk_idx)
-                            run_(kargs, dim3(chunk_idx, head_idx, ibatch), isplit, nsplits_b);
+                        while(c_start_0 < amd_wave_read_first_lane(bs->nc) && w_chunk < w_hi)
+                        {
+                            run_(kargs,
+                                 dim3(tile_n_interleave(c_start_0,
+                                                        amd_wave_read_first_lane(bs->nc)),
+                                      head_start,
+                                      ibatch),
+                                 isplit,
+                                 amd_wave_read_first_lane(bs->nsplits));
+                            w_chunk += amd_wave_read_first_lane(bs->sq);
+                            ++c_start_0;
+                        }
+                        if(w_chunk >= w_hi)
+                            return;
+                        // w_chunk is now at the start of the next head
+                        c_start_0 = 0;
+                        isplit    = 0;
+                        ++head_start;
                    }
-                }
+                    head_start = 0;
+                    ++ibatch;
+                } while(w_chunk < w_hi);
            }
        }
    }
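
A minimal standalone sketch (not part of the patch) of the workspace metadata layout implied by the new Get*Size/Get*Offset helpers: nsplits, dq_acc offsets, prefix_batch, packed cu_states, batch_states, then 4 KiB padding so the dq_acc data that follows starts page-aligned. The concrete element sizes, the ALIGNMENT value, and the align_up helper below are assumptions for illustration, not CK-Tile APIs.

#include <cstddef>
#include <cstdio>

static std::size_t align_up(std::size_t x, std::size_t a) { return (x + a - 1) / a * a; }

int main()
{
    constexpr std::size_t kAlign       = 128; // stand-in for ALIGNMENT
    constexpr std::size_t kIndexSize   = 4;   // sizeof(index_t)
    constexpr std::size_t kCuStateSize = 32;  // sizeof(FmhaBwdGroupPersistentCuState)
    constexpr std::size_t kBatchStSize = 16;  // sizeof(FmhaBwdBatchState)
    const std::size_t batch = 3, num_cus = 304;

    const std::size_t splits_sz  = align_up(kIndexSize * batch, kAlign);       // nsplits[b]
    const std::size_t offsets_sz = align_up(kIndexSize * batch, kAlign);       // dq_acc offsets
    const std::size_t prefix_sz  = align_up(kIndexSize * (batch + 1), kAlign); // prefix_batch
    const std::size_t cu_sz      = align_up(kCuStateSize * num_cus, kAlign);   // cu_states
    const std::size_t batch_sz   = align_up(kBatchStSize * batch, kAlign);     // batch_states

    const std::size_t cu_off    = splits_sz + offsets_sz + prefix_sz; // cf. GetCuStateOffset
    const std::size_t batch_off = cu_off + cu_sz;                     // cf. GetBatchStateOffset
    const std::size_t meta_sz   = align_up(batch_off + batch_sz, std::size_t(4096));

    std::printf("cu_states @ %zu, batch_states @ %zu, dq_acc data @ %zu\n",
                cu_off, batch_off, meta_sz);
    return 0;
}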
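The arithmetic that Step 4 precomputes per CU, and that the GPU do/while loop then walks chunk by chunk, can also be exercised in isolation. Below is a minimal host-side sketch, assuming made-up kN0, head count, and sequence lengths and a hypothetical ceil_div helper; it decomposes each CU's nominal start c * target_w into (ibatch, head_start, c_start) and snaps it to the global chunk start w_lo, mirroring the patch's naming but not its actual implementation.

#include <cstdio>
#include <vector>

using index_t = int;

static index_t ceil_div(index_t a, index_t b) { return (a + b - 1) / b; }

int main()
{
    const index_t kN0     = 128; // K-chunk size along seqlen_k (assumed)
    const index_t nhead   = 4;
    const index_t num_cus = 8;
    const std::vector<index_t> seqlen_q = {384, 1024, 64};
    const std::vector<index_t> seqlen_k = {384, 1024, 64};
    const index_t nbatch = static_cast<index_t>(seqlen_q.size());

    // prefix_batch[b]: workload (nhead * nc * sq) accumulated over batches before b.
    std::vector<index_t> prefix_batch(nbatch + 1, 0);
    for(index_t b = 0; b < nbatch; ++b)
    {
        const index_t nc    = ceil_div(seqlen_k[b], kN0);
        prefix_batch[b + 1] = prefix_batch[b] + nhead * nc * seqlen_q[b];
    }
    const index_t total_w  = prefix_batch[nbatch];
    const index_t target_w = ceil_div(total_w, num_cus);

    // Two-pointer scan: decompose each CU's nominal start c*target_w into
    // (batch, head_start, c_start), then snap w_lo to the true global chunk start.
    for(index_t c = 0, b = 0; c < num_cus; ++c)
    {
        const index_t nominal = c * target_w;
        if(nominal >= total_w)
        {
            std::printf("cu %d: no work\n", c);
            continue;
        }
        while(prefix_batch[b + 1] <= nominal)
            ++b;
        const index_t sq         = seqlen_q[b];
        const index_t nc         = ceil_div(seqlen_k[b], kN0);
        const index_t hw         = nc * sq; // workload per (batch, head) pair
        const index_t head_start = (nominal - prefix_batch[b]) / hw;
        const index_t wc_start   = nominal - prefix_batch[b] - head_start * hw;
        const index_t c_start    = wc_start > 0 ? ceil_div(wc_start, sq) : 0;
        const index_t w_lo       = prefix_batch[b] + head_start * hw + c_start * sq;
        std::printf("cu %d: ibatch %d head_start %d c_start %d w_lo %d\n",
                    c, b, head_start, c_start, w_lo);
    }
    return 0;
}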