mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Optimize the batch_prefill pipeline based on the fwd pipeline
This commit is contained in:
@@ -663,6 +663,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}
|
||||
k_dram_window.init_raw();
|
||||
|
||||
// Track the K/V SRD's current physical page so consecutive same-page rebases
|
||||
// (the common case when kPageBlockSize >> kN0, e.g. page_size=1024 / kN0=128
|
||||
// gives 8 same-page tiles in a row) collapse to a no-op. -1 is a never-valid
|
||||
// sentinel that forces the first rebase to run.
|
||||
index_t last_k_page = static_cast<index_t>(-1);
|
||||
index_t last_v_page = static_cast<index_t>(-1);
|
||||
|
||||
// SRD rebasing for K: only for page_size >= kN0 (all threads on same page).
|
||||
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
|
||||
// addressing.
|
||||
@@ -671,7 +678,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
// readfirstlane: make physical_page provably wave-uniform so the
|
||||
// resulting SRD lands in SGPRs (required by buffer load instructions).
|
||||
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
|
||||
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
|
||||
// Skip the SRD reset (data ptr + num_records + init_raw) when the
|
||||
// target page is the same as the one currently encoded in the SRD.
|
||||
// last_k_page is also wave-uniform (initialized to -1 and only ever
|
||||
// assigned from readfirstlane'd values), so the branch is wave-uniform
|
||||
// and the compiler can keep last_k_page in SGPRs.
|
||||
if(physical_page == last_k_page)
|
||||
return;
|
||||
last_k_page = physical_page;
|
||||
const auto* base_ptr = k_dram_block_window.get_bottom_tensor_view().buf_.p_data_;
|
||||
const auto* page_ptr =
|
||||
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
|
||||
@@ -695,6 +710,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// readfirstlane: make physical_page provably wave-uniform so the
|
||||
// resulting SRD lands in SGPRs (required by buffer load instructions).
|
||||
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
|
||||
// Uses the same same-page skip as rebase_k_window above. The V SRD is
|
||||
// rebased multiple times per K-loop iter (initial setup, post-GEMM0,
|
||||
// sink boundary, next-iter prep); for large page_size all of those
|
||||
// typically stay on the same page.
|
||||
if(physical_page == last_v_page)
|
||||
return;
|
||||
last_v_page = physical_page;
|
||||
const auto* base_ptr =
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
|
||||
const auto* page_ptr =
|
||||
@@ -1113,17 +1135,30 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}
|
||||
// PER_TOKEN_HEAD: dequantize QK result with per-row Q descale and per-column K descale.
|
||||
// s_acc[i,j] *= q_descale[q_origin+i, qo_head] * k_descale[k_page(j), k_slot(j), kv_head]
|
||||
// Supports cross-page tiles (kPageBlockSize < kN0): column j is looked up in the
|
||||
// page covering token (k_origin + j).
|
||||
//
|
||||
// Implementation notes:
|
||||
// The naive form (one global load of qd + kd inside the per-element
|
||||
// s_acc sweep) has two problems on gfx9 qr_async: (a) it inflates
|
||||
// the inner-loop instruction footprint with multi-component address
|
||||
// arithmetic that the compiler must keep alive per element, and (b)
|
||||
// it puts the K SRD under SGPR pressure (same class of issue we hit
|
||||
// in fmha_fwd qr_async). To dodge both we stage Q-row and K-col
|
||||
// descales through LDS once per K-loop iteration; the per-element
|
||||
// sweep then collapses to a pure LDS read + FP multiply, identical
|
||||
// to the fmha_fwd qr_async PER_TOKEN_HEAD path
|
||||
// (block_fmha_pipeline_qr_ks_vs_async.hpp).
|
||||
//
|
||||
// Supports cross-page tiles (kPageBlockSize < kN0): each LDS-load
|
||||
// thread resolves its page index from `tile_k_pages` (1 entry on
|
||||
// the fast path, kN0/kPageBlockSize entries on the cross-page path).
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
const index_t qo_head = block_indices.qo_head_idx;
|
||||
const index_t kv_head = block_indices.kv_head_idx;
|
||||
const index_t q_row_base = q_origin.at(number<0>{});
|
||||
const index_t qo_head = block_indices.qo_head_idx;
|
||||
const index_t kv_head = block_indices.kv_head_idx;
|
||||
const index_t q_row_base = q_origin.at(number<0>{});
|
||||
|
||||
// Number of distinct pages this tile spans.
|
||||
// page_size >= kN0 -> 1 (fast path, identical to original behavior)
|
||||
// page_size >= kN0 -> 1 (fast path: single page covers the tile)
|
||||
// page_size < kN0 -> kN0 / page_size (cross-page tile)
|
||||
constexpr index_t kPagesPerTile =
|
||||
(kPageBlockSize >= kN0) ? 1 : (kN0 / kPageBlockSize);
|
||||
@@ -1140,12 +1175,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
index_t tile_k_pages[kPagesPerTile];
|
||||
if constexpr(kPagesPerTile == 1)
|
||||
{
|
||||
// Single-page tile: reuse the page already loaded for K-gemm.
|
||||
// Single-page tile: reuse the page already resolved for K-gemm.
|
||||
tile_k_pages[0] = k_physical_pages[number<0>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t k_origin_n = k_origin.at(number<0>{});
|
||||
// Only read k_dram_block_window origin in the cross-page case
|
||||
// where we actually need it (kPagesPerTile == 1 already has the
|
||||
// page via k_physical_pages, and avoiding the window-origin read
|
||||
// here keeps the K SRD off VGPRs on the fast path -- the same
|
||||
// discipline the fmha_fwd qr_async PER_TOKEN_HEAD branch uses).
|
||||
const index_t k_origin_n =
|
||||
k_dram_block_window.get_window_origin().at(number<0>{});
|
||||
static_for<0, kPagesPerTile, 1>{}([&](auto p) {
|
||||
const index_t gp = (k_origin_n + p.value * kPageBlockSize)
|
||||
>> kLog2PageBlockSize;
|
||||
@@ -1154,6 +1195,44 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
});
|
||||
}
|
||||
|
||||
// LDS staging tiles. Per-block allocations sized to the descale
|
||||
// working set ((kM0 + kN0) fp32 values = 1024 B for the d128 tile).
|
||||
__shared__ float lds_q_descale[kM0];
|
||||
__shared__ float lds_k_descale[kN0];
|
||||
|
||||
const index_t tid_in_block =
|
||||
static_cast<index_t>(threadIdx.x + threadIdx.y * blockDim.x +
|
||||
threadIdx.z * blockDim.x * blockDim.y);
|
||||
const index_t threads_per_block =
|
||||
static_cast<index_t>(blockDim.x * blockDim.y * blockDim.z);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Q-row descales (kM0 entries).
|
||||
for(index_t off = tid_in_block; off < kM0; off += threads_per_block)
|
||||
{
|
||||
lds_q_descale[off] = q_descale_per_token_ptr[
|
||||
(q_row_base + off) * stride_q_descale_token +
|
||||
qo_head * nhead_stride_q_descale];
|
||||
}
|
||||
|
||||
// K-col descales (kN0 entries).
|
||||
// Fast path (kPagesPerTile == 1): k_page folds to tile_k_pages[0]
|
||||
// so the inner address is a single stride product over k_slot.
|
||||
// Cross-page path: k_page switches every kPageBlockSize columns.
|
||||
for(index_t off = tid_in_block; off < kN0; off += threads_per_block)
|
||||
{
|
||||
const index_t k_page = tile_k_pages[
|
||||
(kPagesPerTile == 1) ? index_t{0}
|
||||
: (off >> kLog2PageBlockSize)];
|
||||
const index_t k_slot = off & kPageSlotMask;
|
||||
lds_k_descale[off] = k_descale_ptr[
|
||||
k_page * nblock_stride_k_descale_page +
|
||||
kv_head * nhead_stride_k_descale +
|
||||
k_slot * stride_k_descale_token];
|
||||
}
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
@@ -1162,22 +1241,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t i = tile_idx.at(number<0>{});
|
||||
const index_t j = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const float qd = q_descale_per_token_ptr[
|
||||
(q_row_base + i) * stride_q_descale_token +
|
||||
qo_head * nhead_stride_q_descale];
|
||||
// Per-column page + slot. For kPagesPerTile==1 the
|
||||
// selector folds to 0 at compile time.
|
||||
const index_t k_page = tile_k_pages[
|
||||
(kPagesPerTile == 1) ? index_t{0}
|
||||
: (j >> kLog2PageBlockSize)];
|
||||
const index_t k_slot = j & kPageSlotMask;
|
||||
const float kd = k_descale_ptr[
|
||||
k_page * nblock_stride_k_descale_page +
|
||||
kv_head * nhead_stride_k_descale +
|
||||
k_slot * stride_k_descale_token];
|
||||
s_acc(i_j_idx) *= qd * kd;
|
||||
s_acc(i_j_idx) *= lds_q_descale[i] * lds_k_descale[j];
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
|
||||
Reference in New Issue
Block a user