optimize batch_prefill pipeline based on fwd one

This commit is contained in:
msaffari-amd
2026-05-27 15:59:40 +00:00
parent b5cd209196
commit 28814b4cfc

View File

@@ -663,6 +663,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}
k_dram_window.init_raw();
// Track the K/V SRD's current physical page so consecutive same-page rebases
// (the common case when kPageBlockSize >> kN0, e.g. page_size=1024 / kN0=128
// gives 8 same-page tiles in a row) collapse to a no-op. -1 is a never-valid
// sentinel that forces the first rebase to run.
index_t last_k_page = static_cast<index_t>(-1);
index_t last_v_page = static_cast<index_t>(-1);
// SRD rebasing for K: only for page_size >= kN0 (all threads on same page).
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
// addressing.
@@ -671,7 +678,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
// readfirstlane: make physical_page provably wave-uniform so the
// resulting SRD lands in SGPRs (required by buffer load instructions).
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
// Skip the SRD reset (data ptr + num_records + init_raw) when the
// target page is the same as the one currently encoded in the SRD.
// last_k_page is also wave-uniform (initialized to -1 and only ever
// assigned from readfirstlane'd values), so the branch is wave-uniform
// and the compiler can keep last_k_page in SGPRs.
if(physical_page == last_k_page)
return;
last_k_page = physical_page;
const auto* base_ptr = k_dram_block_window.get_bottom_tensor_view().buf_.p_data_;
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
@@ -695,6 +710,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// readfirstlane: make physical_page provably wave-uniform so the
// resulting SRD lands in SGPRs (required by buffer load instructions).
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
// Same same-page-skip trick as rebase_k_window above. The V SRD is
// rebased multiple times per K-loop iter (initial setup, post-GEMM0,
// sink boundary, next-iter prep); for large page_size all of those
// typically stay on the same page.
if(physical_page == last_v_page)
return;
last_v_page = physical_page;
const auto* base_ptr =
v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
const auto* page_ptr =
@@ -1113,17 +1135,30 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}
// PER_TOKEN_HEAD: dequantize QK result with per-row Q descale and per-column K descale.
// s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page(j), k_slot(j), kv_head]
// Supports cross-page tiles (kPageBlockSize < kN0): column j is looked up in the
// page covering token (k_origin + j).
//
// Implementation notes:
// The naive form (one global load of qd + kd inside the per-element
// s_acc sweep) has two problems on gfx9 qr_async: (a) it inflates
// the inner-loop instruction footprint with multi-component address
// arithmetic that the compiler must keep alive per element, and (b)
// it puts the K SRD under SGPR pressure (same class of issue we hit
// in fmha_fwd qr_async). To dodge both we stage Q-row and K-col
// descales through LDS once per K-loop iteration; the per-element
// sweep then collapses to a pure LDS read + FP multiply, identical
// to the fmha_fwd qr_async PER_TOKEN_HEAD path
// (block_fmha_pipeline_qr_ks_vs_async.hpp).
//
// Supports cross-page tiles (kPageBlockSize < kN0): each LDS-load
// thread resolves its page index from `tile_k_pages` (1 entry on
// the fast path, kN0/kPageBlockSize entries on the cross-page path).
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
{
const auto k_origin = k_dram_block_window.get_window_origin();
const index_t qo_head = block_indices.qo_head_idx;
const index_t kv_head = block_indices.kv_head_idx;
const index_t q_row_base = q_origin.at(number<0>{});
const index_t qo_head = block_indices.qo_head_idx;
const index_t kv_head = block_indices.kv_head_idx;
const index_t q_row_base = q_origin.at(number<0>{});
// Number of distinct pages this tile spans.
// page_size >= kN0 -> 1 (fast path, identical to original behavior)
// page_size >= kN0 -> 1 (fast path: single page covers the tile)
// page_size < kN0 -> kN0 / page_size (cross-page tile)
constexpr index_t kPagesPerTile =
(kPageBlockSize >= kN0) ? 1 : (kN0 / kPageBlockSize);
@@ -1140,12 +1175,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
index_t tile_k_pages[kPagesPerTile];
if constexpr(kPagesPerTile == 1)
{
// Single-page tile: reuse the page already loaded for K-gemm.
// Single-page tile: reuse the page already resolved for K-gemm.
tile_k_pages[0] = k_physical_pages[number<0>{}];
}
else
{
const index_t k_origin_n = k_origin.at(number<0>{});
// Only read k_dram_block_window origin in the cross-page case
// where we actually need it (kPagesPerTile == 1 already has the
// page via k_physical_pages, and avoiding the window-origin read
// here keeps the K SRD off VGPRs on the fast path -- the same
// discipline the fmha_fwd qr_async PER_TOKEN_HEAD branch uses).
const index_t k_origin_n =
k_dram_block_window.get_window_origin().at(number<0>{});
static_for<0, kPagesPerTile, 1>{}([&](auto p) {
const index_t gp = (k_origin_n + p.value * kPageBlockSize)
>> kLog2PageBlockSize;
@@ -1154,6 +1195,44 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
});
}
// LDS staging tiles. Per-block allocations sized to the descale
// working set (kM0 + kN0 fp32 = 1024 B for the d128 tile).
__shared__ float lds_q_descale[kM0];
__shared__ float lds_k_descale[kN0];
const index_t tid_in_block =
static_cast<index_t>(threadIdx.x + threadIdx.y * blockDim.x +
threadIdx.z * blockDim.x * blockDim.y);
const index_t threads_per_block =
static_cast<index_t>(blockDim.x * blockDim.y * blockDim.z);
__builtin_amdgcn_sched_barrier(0);
// Q-row descales (kM0 entries).
for(index_t off = tid_in_block; off < kM0; off += threads_per_block)
{
lds_q_descale[off] = q_descale_per_token_ptr[
(q_row_base + off) * stride_q_descale_token +
qo_head * nhead_stride_q_descale];
}
// K-col descales (kN0 entries).
// Fast path (kPagesPerTile == 1): k_page folds to tile_k_pages[0]
// so the inner address is a single stride product over k_slot.
// Cross-page path: k_page switches every kPageBlockSize columns.
for(index_t off = tid_in_block; off < kN0; off += threads_per_block)
{
const index_t k_page = tile_k_pages[
(kPagesPerTile == 1) ? index_t{0}
: (off >> kLog2PageBlockSize)];
const index_t k_slot = off & kPageSlotMask;
lds_k_descale[off] = k_descale_ptr[
k_page * nblock_stride_k_descale_page +
kv_head * nhead_stride_k_descale +
k_slot * stride_k_descale_token];
}
__builtin_amdgcn_s_barrier();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
@@ -1162,22 +1241,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const index_t i = tile_idx.at(number<0>{});
const index_t j = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const float qd = q_descale_per_token_ptr[
(q_row_base + i) * stride_q_descale_token +
qo_head * nhead_stride_q_descale];
// Per-column page + slot. For kPagesPerTile==1 the
// selector folds to 0 at compile time.
const index_t k_page = tile_k_pages[
(kPagesPerTile == 1) ? index_t{0}
: (j >> kLog2PageBlockSize)];
const index_t k_slot = j & kPageSlotMask;
const float kd = k_descale_ptr[
k_page * nblock_stride_k_descale_page +
kv_head * nhead_stride_k_descale +
k_slot * stride_k_descale_token];
s_acc(i_j_idx) *= qd * kd;
s_acc(i_j_idx) *= lds_q_descale[i] * lds_k_descale[j];
});
});
__builtin_amdgcn_sched_barrier(0);
}
const auto p = [&]() {