Fix perf regression by splitting num_blocks_start arithmetic into SWA and non-SWA paths

Damien Lejeune
2026-05-08 14:01:15 +00:00
parent f438cef286
commit e36693c4dc


@@ -361,21 +361,43 @@ struct UnifiedAttentionPipeline
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
assert(k_block_idx == v_block_idx); // because of the following line
// num_blocks_start is in kPageBlockSize sub-block units, but block_table_offset
// (and block_tables_ptr) are in PageSize-page units. Convert correctly so that
// a sub-block-aligned start position lands on the right page AND the right
// sub-block within that page. This matters for SWA (tile_lo > 0) and for
// split-KV when kv_page_size_in_blocks > 1.
const index_t page_advance = num_blocks_start / kv_page_size_in_blocks;
const index_t init_sub_block_offset = num_blocks_start % kv_page_size_in_blocks;
block_table_offset += page_advance;
// k_block_idx counts sub-blocks relative to the first iterated page. Starting it
// at init_sub_block_offset lets the existing K_mem_load math (k_block_idx /
// kv_page_size_in_blocks for the page index, k_block_idx % kv_page_size_in_blocks
// for the within-page sub-block) keep working unchanged.
k_block_idx = init_sub_block_offset;
v_block_idx = init_sub_block_offset;
index_t kv_blk_idx_initial =
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
// (and block_tables_ptr) are in PageSize-page units. SWA can produce a
// non-page-aligned start (per-Q-tile lower bound from mask.GetTileRangeAlongX
// is sub-block-aligned), so we must split it into a page advance + a within-page
// sub-block offset. Non-SWA callers always pass a page-aligned num_blocks_start
// (== 0 for non-split-KV; == i_split * blocks_per_split for split-KV with
// page-aligned splits), so we keep the original single-add fast path for them
// to avoid emitting runtime integer divisions on the hot decode path — the
// divisions cost ~1% per kernel invocation and were the entire source of the
// post-SWA decode-bench regression.
index_t init_sub_block_offset = 0;
if constexpr(FmhaMask::IsLocal)
{
const index_t page_advance = num_blocks_start / kv_page_size_in_blocks;
init_sub_block_offset = num_blocks_start % kv_page_size_in_blocks;
block_table_offset += page_advance;
// k_block_idx counts sub-blocks relative to the first iterated page.
// Starting it at init_sub_block_offset lets the existing K_mem_load math
// (k_block_idx / kv_page_size_in_blocks for page index, k_block_idx %
// kv_page_size_in_blocks for within-page sub-block) keep working unchanged.
k_block_idx = init_sub_block_offset;
v_block_idx = init_sub_block_offset;
}
else
{
// Original behavior: assume num_blocks_start is page-aligned. If a future
// caller violates this with kv_page_size_in_blocks > 1 the index below
// mis-points; the assert catches it in debug builds.
assert(kv_page_size_in_blocks == 1 ||
num_blocks_start % kv_page_size_in_blocks == 0);
block_table_offset += num_blocks_start;
}
// After both branches above k_block_idx is in [0, kv_page_size_in_blocks),
// so k_block_idx / kv_page_size_in_blocks == 0. The initial page lookup is
// therefore just block_table_offset; this avoids an extra runtime division
// here on the non-SWA path (compiler can't prove the bound from k_block_idx
// alone).
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset];
// When row strides are provided, use pointer rebasing to avoid int32 overflow
// in tensor_coordinate::get_offset() for large KV pools (>131K blocks for d64/GQA-8).
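
A minimal standalone sketch of the page/sub-block split performed in the SWA branch above. The identifiers mirror the diff, but the concrete sizes (kv_page_size_in_blocks = 4, num_blocks_start = 10) are assumptions chosen only to make the arithmetic visible; the point is that after the split, the unchanged div/mod load math lands on the right page and the right sub-block.

// Standalone sketch (not part of the kernel): sub-block -> page split with
// made-up sizes. kv_page_size_in_blocks and num_blocks_start values are assumptions.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main()
{
    using index_t = std::int32_t;

    const index_t kv_page_size_in_blocks = 4;  // assumed sub-blocks per KV page
    const index_t num_blocks_start       = 10; // assumed SWA lower bound, in sub-block units

    // The split performed on the SWA path.
    const index_t page_advance          = num_blocks_start / kv_page_size_in_blocks; // 2
    const index_t init_sub_block_offset = num_blocks_start % kv_page_size_in_blocks; // 2

    // Seed the iteration state the way the SWA branch does.
    index_t block_table_offset = 0; // assume this sequence's table starts at slot 0
    block_table_offset += page_advance;
    index_t k_block_idx = init_sub_block_offset;

    // The unchanged K_mem_load math (div for page, mod for within-page sub-block)
    // then recovers the correct position for the first load.
    const index_t page_of_first_load =
        block_table_offset + k_block_idx / kv_page_size_in_blocks;
    const index_t sub_block_of_first_load = k_block_idx % kv_page_size_in_blocks;

    // Sub-block 10 lives in page 2 (which holds sub-blocks 8..11) at slot 2.
    assert(page_of_first_load == 2 && sub_block_of_first_load == 2);
    std::printf("first load: page %d, sub-block %d\n",
                static_cast<int>(page_of_first_load),
                static_cast<int>(sub_block_of_first_load));
    return 0;
}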
@@ -390,25 +412,27 @@ struct UnifiedAttentionPipeline
auto* k_base_ptr = k_view.buf_.p_data_;
auto* v_base_ptr = v_view.buf_.p_data_;
// Within-page byte offset (in K/V rows) for the very first iterated sub-block.
// Non-zero whenever num_blocks_start is not a multiple of kv_page_size_in_blocks.
const index_t initial_intra_page_row_offset =
init_sub_block_offset * kPageBlockSize;
// Within-page row offset for the very first iterated sub-block. Only declared
// on the SWA path: keeping the variable out of the non-SWA codegen guarantees
// bit-identical pre-SWA fast-path arithmetic for k_off/v_off/init_origin and
// avoids relying on the optimizer to fold a 0 across the long_index_t mults.
if(use_ptr_rebase)
{
long_index_t k_off =
(static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
static_cast<long_index_t>(initial_intra_page_row_offset)) *
k_row_stride;
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
long_index_t v_off =
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
if constexpr(FmhaMask::IsLocal)
{
const long_index_t intra =
static_cast<long_index_t>(init_sub_block_offset) * kPageBlockSize;
k_off += intra * k_row_stride;
v_off += intra * v_row_stride;
}
k_view.buf_.p_data_ = k_base_ptr + k_off;
auto new_k = k_view.buf_.buffer_size_ - k_off;
k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
long_index_t v_off =
(static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
static_cast<long_index_t>(initial_intra_page_row_offset)) *
v_row_stride;
v_view.buf_.p_data_ = v_base_ptr + v_off;
auto new_v = v_view.buf_.buffer_size_ - v_off;
v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
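
A minimal standalone sketch of why k_off/v_off are widened to long_index_t before the multiplies. PageSize = 256 rows per page and k_row_stride = 64 elements per row are assumptions, picked only so the roughly 131K-page threshold mentioned in the rebase comment is where a 32-bit element offset would wrap; the actual kernel configuration may differ.

// Standalone sketch (not kernel code): widen before multiplying so the
// element offset never wraps. All concrete sizes below are assumptions.
#include <cstdint>
#include <cstdio>
#include <limits>

int main()
{
    using index_t      = std::int32_t;
    using long_index_t = std::int64_t;

    const index_t PageSize           = 256;    // assumed rows per KV page
    const index_t k_row_stride       = 64;     // assumed elements per K row (head dim 64)
    const index_t kv_blk_idx_initial = 140000; // assumed page index in a large KV pool

    // Same form as the rebased path: cast first, then multiply in 64-bit.
    const long_index_t k_off =
        static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;

    // SWA folds the within-page sub-block rows in on top, still in 64-bit.
    const index_t init_sub_block_offset = 2;  // assumed
    const index_t kPageBlockSize        = 16; // assumed rows per sub-block
    const long_index_t k_off_swa =
        k_off +
        static_cast<long_index_t>(init_sub_block_offset) * kPageBlockSize * k_row_stride;

    // 140000 * 256 * 64 is about 2.29e9, past the int32 range.
    const bool fits_in_32 = k_off <= std::numeric_limits<index_t>::max();
    std::printf("k_off = %lld (fits in int32: %s), k_off_swa = %lld\n",
                static_cast<long long>(k_off), fits_in_32 ? "yes" : "no",
                static_cast<long long>(k_off_swa));
    return 0;
}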
@@ -420,10 +444,14 @@ struct UnifiedAttentionPipeline
v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
}
const index_t init_origin = use_ptr_rebase
? 0
: kv_blk_idx_initial * PageSize +
initial_intra_page_row_offset;
// Match the pre-SWA single-ternary form for the non-SWA path (no lambda — keeps
// the optimizer's job trivial). SWA adds the within-page offset on top.
index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
if constexpr(FmhaMask::IsLocal)
{
if(!use_ptr_rebase)
init_origin += init_sub_block_offset * kPageBlockSize;
}
auto k_dram_window =
make_tile_window(k_view,
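
A minimal standalone sketch showing that the two init_origin forms, 0 under pointer rebase versus page rows plus SWA rows otherwise, address the same first element. All concrete sizes are assumptions; only the equivalence matters.

// Standalone sketch (not kernel code): rebased and non-rebased origins agree.
#include <cassert>
#include <cstdint>

int main()
{
    using index_t      = std::int32_t;
    using long_index_t = std::int64_t;

    const index_t PageSize              = 256; // assumed rows per KV page
    const index_t kPageBlockSize        = 16;  // assumed rows per sub-block
    const index_t kv_blk_idx_initial    = 7;   // assumed first page index
    const index_t init_sub_block_offset = 3;   // assumed SWA within-page sub-block
    const long_index_t k_row_stride     = 64;  // assumed elements per K row

    // use_ptr_rebase: the full row offset moves into the base pointer (tracked
    // here as an element offset) and init_origin stays 0.
    const long_index_t k_off =
        (static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
         static_cast<long_index_t>(init_sub_block_offset) * kPageBlockSize) *
        k_row_stride;
    const index_t origin_rebased = 0;

    // !use_ptr_rebase: the pointer stays put and init_origin carries both parts.
    const index_t origin_plain =
        kv_blk_idx_initial * PageSize + init_sub_block_offset * kPageBlockSize;

    // Element address of the first loaded row is identical either way.
    const long_index_t addr_rebased = k_off + origin_rebased * k_row_stride;
    const long_index_t addr_plain =
        static_cast<long_index_t>(origin_plain) * k_row_stride;
    assert(addr_rebased == addr_plain);
    return 0;
}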
@@ -1179,14 +1207,22 @@ struct UnifiedAttentionPipeline
// (= num_total_loop - num_blocks_start), not by num_total_loop alone. Keying
// on num_total_loop matches when num_blocks_start is even but reads
// uninitialised sp(1)/V[1] when it is odd, which silently corrupts o_acc.
// This previously only mattered for split-KV with an odd split boundary; SWA
// exposes it whenever the per-Q-tile lower-bound clip is odd.
const index_t num_iters = num_total_loop - num_blocks_start;
if(num_iters % 2)
// SWA needs this because its per-Q-tile lower-bound clip can be odd. Non-SWA
// callers always pass an even num_blocks_start (== 0 for non-split-KV; even
// multiples for split-KV with kBlockN-aligned splits), so for those the
// simpler num_total_loop parity is correct AND avoids an extra subtract on
// the hot exit path. Gating restores bit-identical pre-SWA codegen for the
// non-SWA fast path (the entire reason the post-SWA decode bench regressed).
index_t parity_n;
if constexpr(FmhaMask::IsLocal)
parity_n = num_total_loop - num_blocks_start;
else
parity_n = num_total_loop;
if(parity_n % 2)
{
fmha_post_process(number<1>{});
}
if(!(num_iters % 2))
if(!(parity_n % 2))
{
fmha_post_process(number<0>{});
}
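
A minimal standalone sketch of the parity argument above: which post-process variant must run depends on how many main-loop iterations actually executed, and keying on num_total_loop alone only agrees with that when num_blocks_start is even. The values are made up; only the parity relationship is the point.

// Standalone sketch (not kernel code): parity of executed iterations.
#include <cassert>
#include <cstdint>

int main()
{
    using index_t = std::int32_t;

    const index_t num_total_loop   = 9; // assumed total KV sub-blocks for this tile
    const index_t num_blocks_start = 3; // assumed odd SWA lower-bound clip

    // Iterations actually run; the sp(0)/sp(1) and V[0]/V[1] double buffering
    // follows this count, so the drain path must key on its parity.
    const index_t num_iters = num_total_loop - num_blocks_start; // 6, even

    // Keying on num_total_loop alone gets the parity wrong whenever
    // num_blocks_start is odd, which is exactly the SWA case.
    assert(num_iters % 2 != num_total_loop % 2);

    // Non-SWA callers pass an even num_blocks_start, so both keys agree and the
    // cheaper num_total_loop form (no subtract) is safe on that path.
    const index_t even_start = 4; // assumed page-/kBlockN-aligned split start
    assert((num_total_loop - even_start) % 2 == num_total_loop % 2);
    return 0;
}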