mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Fixing perf regression by splitting num_blocks_start arithmetic into SWA and non-SWA paths
This commit is contained in:
@@ -361,21 +361,43 @@ struct UnifiedAttentionPipeline
|
||||
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
|
||||
assert(k_block_idx == v_block_idx); // because of the following line
|
||||
// num_blocks_start is in kPageBlockSize sub-block units, but block_table_offset
|
||||
// (and block_tables_ptr) are in PageSize-page units. Convert correctly so that
|
||||
// a sub-block-aligned start position lands on the right page AND the right
|
||||
// sub-block within that page. This matters for SWA (tile_lo > 0) and for
|
||||
// split-KV when kv_page_size_in_blocks > 1.
|
||||
const index_t page_advance = num_blocks_start / kv_page_size_in_blocks;
|
||||
const index_t init_sub_block_offset = num_blocks_start % kv_page_size_in_blocks;
|
||||
block_table_offset += page_advance;
|
||||
// k_block_idx counts sub-blocks relative to the first iterated page. Starting it
|
||||
// at init_sub_block_offset lets the existing K_mem_load math (k_block_idx /
|
||||
// kv_page_size_in_blocks for the page index, k_block_idx % kv_page_size_in_blocks
|
||||
// for the within-page sub-block) keep working unchanged.
|
||||
k_block_idx = init_sub_block_offset;
|
||||
v_block_idx = init_sub_block_offset;
|
||||
index_t kv_blk_idx_initial =
|
||||
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
|
||||
// (and block_tables_ptr) are in PageSize-page units. SWA can produce a
|
||||
// non-page-aligned start (per-Q-tile lower bound from mask.GetTileRangeAlongX
|
||||
// is sub-block-aligned), so we must split it into a page advance + a within-page
|
||||
// sub-block offset. Non-SWA callers always pass a page-aligned num_blocks_start
|
||||
// (== 0 for non-split-KV; == i_split * blocks_per_split for split-KV with
|
||||
// page-aligned splits), so we keep the original single-add fast path for them
|
||||
// to avoid emitting runtime integer divisions on the hot decode path — the
|
||||
// divisions cost ~1% per kernel invocation and were the entire source of the
|
||||
// post-SWA decode-bench regression.
|
||||
index_t init_sub_block_offset = 0;
|
||||
if constexpr(FmhaMask::IsLocal)
|
||||
{
|
||||
const index_t page_advance = num_blocks_start / kv_page_size_in_blocks;
|
||||
init_sub_block_offset = num_blocks_start % kv_page_size_in_blocks;
|
||||
block_table_offset += page_advance;
|
||||
// k_block_idx counts sub-blocks relative to the first iterated page.
|
||||
// Starting it at init_sub_block_offset lets the existing K_mem_load math
|
||||
// (k_block_idx / kv_page_size_in_blocks for page index, k_block_idx %
|
||||
// kv_page_size_in_blocks for within-page sub-block) keep working unchanged.
|
||||
k_block_idx = init_sub_block_offset;
|
||||
v_block_idx = init_sub_block_offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
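// Original behavior: add num_blocks_start (sub-block units) straight onto the
|
||||
// page-unit block_table_offset. NOTE(review): this looks exact only when
|
||||
// kv_page_size_in_blocks == 1 or num_blocks_start == 0; the debug-only assert below
|
||||
// checks page alignment, which does not rule out an aligned non-zero start — confirm.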
|
||||
assert(kv_page_size_in_blocks == 1 ||
|
||||
num_blocks_start % kv_page_size_in_blocks == 0);
|
||||
block_table_offset += num_blocks_start;
|
||||
}
|
||||
// After both branches above k_block_idx is in [0, kv_page_size_in_blocks),
|
||||
// so k_block_idx / kv_page_size_in_blocks == 0. The initial page lookup is
|
||||
// therefore just block_table_offset; this avoids an extra runtime division
|
||||
// here on the non-SWA path (compiler can't prove the bound from k_block_idx
|
||||
// alone).
|
||||
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset];
|
||||
|
||||
// When row strides are provided, use pointer rebasing to avoid int32 overflow
|
||||
// in tensor_coordinate::get_offset() for large KV pools (>131K blocks for d64/GQA-8).
|
||||
@@ -390,25 +412,27 @@ struct UnifiedAttentionPipeline
|
||||
auto* k_base_ptr = k_view.buf_.p_data_;
|
||||
auto* v_base_ptr = v_view.buf_.p_data_;
|
||||
|
||||
// Within-page row offset (in K/V rows) for the very first iterated sub-block.
|
||||
// Non-zero whenever num_blocks_start is not a multiple of kv_page_size_in_blocks.
|
||||
const index_t initial_intra_page_row_offset =
|
||||
init_sub_block_offset * kPageBlockSize;
|
||||
|
||||
// Within-page row offset for the very first iterated sub-block. Only declared
|
||||
// on the SWA path: keeping the variable out of the non-SWA codegen guarantees
|
||||
// bit-identical pre-SWA fast-path arithmetic for k_off/v_off/init_origin and
|
||||
// avoids relying on the optimizer to fold a 0 across the long_index_t mults.
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
long_index_t k_off =
|
||||
(static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
|
||||
static_cast<long_index_t>(initial_intra_page_row_offset)) *
|
||||
k_row_stride;
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
|
||||
long_index_t v_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
|
||||
if constexpr(FmhaMask::IsLocal)
|
||||
{
|
||||
const long_index_t intra =
|
||||
static_cast<long_index_t>(init_sub_block_offset) * kPageBlockSize;
|
||||
k_off += intra * k_row_stride;
|
||||
v_off += intra * v_row_stride;
|
||||
}
|
||||
k_view.buf_.p_data_ = k_base_ptr + k_off;
|
||||
auto new_k = k_view.buf_.buffer_size_ - k_off;
|
||||
k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
|
||||
|
||||
long_index_t v_off =
|
||||
(static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
|
||||
static_cast<long_index_t>(initial_intra_page_row_offset)) *
|
||||
v_row_stride;
|
||||
v_view.buf_.p_data_ = v_base_ptr + v_off;
|
||||
auto new_v = v_view.buf_.buffer_size_ - v_off;
|
||||
v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
|
||||
@@ -420,10 +444,14 @@ struct UnifiedAttentionPipeline
|
||||
v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
|
||||
}
|
||||
|
||||
const index_t init_origin = use_ptr_rebase
|
||||
? 0
|
||||
: kv_blk_idx_initial * PageSize +
|
||||
initial_intra_page_row_offset;
|
||||
// Match the pre-SWA single-ternary form for the non-SWA path (no lambda — keeps
|
||||
// the optimizer's job trivial). SWA adds the within-page offset on top.
|
||||
index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
|
||||
if constexpr(FmhaMask::IsLocal)
|
||||
{
|
||||
if(!use_ptr_rebase)
|
||||
init_origin += init_sub_block_offset * kPageBlockSize;
|
||||
}
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_view,
|
||||
@@ -1179,14 +1207,22 @@ struct UnifiedAttentionPipeline
|
||||
// (= num_total_loop - num_blocks_start), not by num_total_loop alone. Keying
|
||||
// on num_total_loop matches when num_blocks_start is even but reads
|
||||
// uninitialised sp(1)/V[1] when it is odd, which silently corrupts o_acc.
|
||||
// This previously only mattered for split-KV with an odd split boundary; SWA
|
||||
// exposes it whenever the per-Q-tile lower-bound clip is odd.
|
||||
const index_t num_iters = num_total_loop - num_blocks_start;
|
||||
if(num_iters % 2)
|
||||
// SWA needs this because its per-Q-tile lower-bound clip can be odd. Non-SWA
|
||||
// callers always pass an even num_blocks_start (== 0 for non-split-KV; even
|
||||
// multiples for split-KV with kBlockN-aligned splits), so for those the
|
||||
// simpler num_total_loop parity is correct AND avoids an extra subtract on
|
||||
// the hot exit path. Gating restores bit-identical pre-SWA codegen for the
|
||||
// non-SWA fast path (the entire reason the post-SWA decode bench regressed).
|
||||
index_t parity_n;
|
||||
if constexpr(FmhaMask::IsLocal)
|
||||
parity_n = num_total_loop - num_blocks_start;
|
||||
else
|
||||
parity_n = num_total_loop;
|
||||
if(parity_n % 2)
|
||||
{
|
||||
fmha_post_process(number<1>{});
|
||||
}
|
||||
if(!(num_iters % 2))
|
||||
if(!(parity_n % 2))
|
||||
{
|
||||
fmha_post_process(number<0>{});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user