diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 4ef7aaea21..22dca51734 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -663,6 +663,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } k_dram_window.init_raw(); + // Track the K/V SRD's current physical page so consecutive same-page rebases + // (the common case when kPageBlockSize >> kN0, e.g. page_size=1024 / kN0=128 + // gives 8 same-page tiles in a row) collapse to a no-op. -1 is a never-valid + // sentinel that forces the first rebase to run. + index_t last_k_page = static_cast(-1); + index_t last_v_page = static_cast(-1); + // SRD rebasing for K: only for page_size >= kN0 (all threads on same page). // For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle // addressing. @@ -671,7 +678,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { // readfirstlane: make physical_page provably wave-uniform so the // resulting SRD lands in SGPRs (required by buffer load instructions). - physical_page = __builtin_amdgcn_readfirstlane(physical_page); + physical_page = __builtin_amdgcn_readfirstlane(physical_page); + // Skip the SRD reset (data ptr + num_records + init_raw) when the + // target page is the same as the one currently encoded in the SRD. + // last_k_page is also wave-uniform (initialized to -1 and only ever + // assigned from readfirstlane'd values), so the branch is wave-uniform + // and the compiler can keep last_k_page in SGPRs. + if(physical_page == last_k_page) + return; + last_k_page = physical_page; const auto* base_ptr = k_dram_block_window.get_bottom_tensor_view().buf_.p_data_; const auto* page_ptr = base_ptr + static_cast(physical_page) * page_stride_k; @@ -695,6 +710,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // readfirstlane: make physical_page provably wave-uniform so the // resulting SRD lands in SGPRs (required by buffer load instructions). physical_page = __builtin_amdgcn_readfirstlane(physical_page); + // Same same-page-skip trick as rebase_k_window above. The V SRD is + // rebased multiple times per K-loop iter (initial setup, post-GEMM0, + // sink boundary, next-iter prep); for large page_size all of those + // typically stay on the same page. + if(physical_page == last_v_page) + return; + last_v_page = physical_page; const auto* base_ptr = v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_; const auto* page_ptr = @@ -1113,17 +1135,30 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } // PER_TOKEN_HEAD: dequantize QK result with per-row Q descale and per-column K descale. // s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page(j), k_slot(j), kv_head] - // Supports cross-page tiles (kPageBlockSize < kN0): column j is looked up in the - // page covering token (k_origin + j). + // + // Implementation notes: + // The naive form (one global load of qd + kd inside the per-element + // s_acc sweep) has two problems on gfx9 qr_async: (a) it inflates + // the inner-loop instruction footprint with multi-component address + // arithmetic that the compiler must keep alive per element, and (b) + // it puts the K SRD under SGPR pressure (same class of issue we hit + // in fmha_fwd qr_async). To dodge both we stage Q-row and K-col + // descales through LDS once per K-loop iteration; the per-element + // sweep then collapses to a pure LDS read + FP multiply, identical + // to the fmha_fwd qr_async PER_TOKEN_HEAD path + // (block_fmha_pipeline_qr_ks_vs_async.hpp). + // + // Supports cross-page tiles (kPageBlockSize < kN0): each LDS-load + // thread resolves its page index from `tile_k_pages` (1 entry on + // the fast path, kN0/kPageBlockSize entries on the cross-page path). else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) { - const auto k_origin = k_dram_block_window.get_window_origin(); - const index_t qo_head = block_indices.qo_head_idx; - const index_t kv_head = block_indices.kv_head_idx; - const index_t q_row_base = q_origin.at(number<0>{}); + const index_t qo_head = block_indices.qo_head_idx; + const index_t kv_head = block_indices.kv_head_idx; + const index_t q_row_base = q_origin.at(number<0>{}); // Number of distinct pages this tile spans. - // page_size >= kN0 -> 1 (fast path, identical to original behavior) + // page_size >= kN0 -> 1 (fast path: single page covers the tile) // page_size < kN0 -> kN0 / page_size (cross-page tile) constexpr index_t kPagesPerTile = (kPageBlockSize >= kN0) ? 1 : (kN0 / kPageBlockSize); @@ -1140,12 +1175,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync index_t tile_k_pages[kPagesPerTile]; if constexpr(kPagesPerTile == 1) { - // Single-page tile: reuse the page already loaded for K-gemm. + // Single-page tile: reuse the page already resolved for K-gemm. tile_k_pages[0] = k_physical_pages[number<0>{}]; } else { - const index_t k_origin_n = k_origin.at(number<0>{}); + // Only read k_dram_block_window origin in the cross-page case + // where we actually need it (kPagesPerTile == 1 already has the + // page via k_physical_pages, and avoiding the window-origin read + // here keeps the K SRD off VGPRs on the fast path -- the same + // discipline the fmha_fwd qr_async PER_TOKEN_HEAD branch uses). + const index_t k_origin_n = + k_dram_block_window.get_window_origin().at(number<0>{}); static_for<0, kPagesPerTile, 1>{}([&](auto p) { const index_t gp = (k_origin_n + p.value * kPageBlockSize) >> kLog2PageBlockSize; @@ -1154,6 +1195,44 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync }); } + // LDS staging tiles. Per-block allocations sized to the descale + // working set (kM0 + kN0 fp32 = 1024 B for the d128 tile). + __shared__ float lds_q_descale[kM0]; + __shared__ float lds_k_descale[kN0]; + + const index_t tid_in_block = + static_cast(threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * blockDim.x * blockDim.y); + const index_t threads_per_block = + static_cast(blockDim.x * blockDim.y * blockDim.z); + + __builtin_amdgcn_sched_barrier(0); + + // Q-row descales (kM0 entries). + for(index_t off = tid_in_block; off < kM0; off += threads_per_block) + { + lds_q_descale[off] = q_descale_per_token_ptr[ + (q_row_base + off) * stride_q_descale_token + + qo_head * nhead_stride_q_descale]; + } + + // K-col descales (kN0 entries). + // Fast path (kPagesPerTile == 1): k_page folds to tile_k_pages[0] + // so the inner address is a single stride product over k_slot. + // Cross-page path: k_page switches every kPageBlockSize columns. + for(index_t off = tid_in_block; off < kN0; off += threads_per_block) + { + const index_t k_page = tile_k_pages[ + (kPagesPerTile == 1) ? index_t{0} + : (off >> kLog2PageBlockSize)]; + const index_t k_slot = off & kPageSlotMask; + lds_k_descale[off] = k_descale_ptr[ + k_page * nblock_stride_k_descale_page + + kv_head * nhead_stride_k_descale + + k_slot * stride_k_descale_token]; + } + __builtin_amdgcn_s_barrier(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { @@ -1162,22 +1241,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t i = tile_idx.at(number<0>{}); const index_t j = tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); - const float qd = q_descale_per_token_ptr[ - (q_row_base + i) * stride_q_descale_token + - qo_head * nhead_stride_q_descale]; - // Per-column page + slot. For kPagesPerTile==1 the - // selector folds to 0 at compile time. - const index_t k_page = tile_k_pages[ - (kPagesPerTile == 1) ? index_t{0} - : (j >> kLog2PageBlockSize)]; - const index_t k_slot = j & kPageSlotMask; - const float kd = k_descale_ptr[ - k_page * nblock_stride_k_descale_page + - kv_head * nhead_stride_k_descale + - k_slot * stride_k_descale_token]; - s_acc(i_j_idx) *= qd * kd; + s_acc(i_j_idx) *= lds_q_descale[i] * lds_k_descale[j]; }); }); + __builtin_amdgcn_sched_barrier(0); } const auto p = [&]() {