From e36693c4dcefbcf8904d23d303f823c397ac371b Mon Sep 17 00:00:00 2001
From: Damien Lejeune
Date: Fri, 8 May 2026 14:01:15 +0000
Subject: [PATCH] Fix perf regression by splitting num_blocks_start arithmetic
 into SWA/non-SWA paths

---
 .../pipeline/unified_attention_pipeline.hpp | 108 ++++++++++++------
 1 file changed, 72 insertions(+), 36 deletions(-)

diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
index 6df628a118..a7092a7a92 100644
--- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
+++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
@@ -361,21 +361,43 @@ struct UnifiedAttentionPipeline
             reinterpret_cast<const index_t*>(block_tables_ptr);
         assert(k_block_idx == v_block_idx); // because of the following line
         // num_blocks_start is in kPageBlockSize sub-block units, but block_table_offset
-        // (and block_tables_ptr) are in PageSize-page units. Convert correctly so that
-        // a sub-block-aligned start position lands on the right page AND the right
-        // sub-block within that page. This matters for SWA (tile_lo > 0) and for
-        // split-KV when kv_page_size_in_blocks > 1.
-        const index_t page_advance = num_blocks_start / kv_page_size_in_blocks;
-        const index_t init_sub_block_offset = num_blocks_start % kv_page_size_in_blocks;
-        block_table_offset += page_advance;
-        // k_block_idx counts sub-blocks relative to the first iterated page. Starting it
-        // at init_sub_block_offset lets the existing K_mem_load math (k_block_idx /
-        // kv_page_size_in_blocks for the page index, k_block_idx % kv_page_size_in_blocks
-        // for the within-page sub-block) keep working unchanged.
-        k_block_idx = init_sub_block_offset;
-        v_block_idx = init_sub_block_offset;
-        index_t kv_blk_idx_initial =
-            block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
+        // (and block_tables_ptr) are in PageSize-page units. SWA can produce a
+        // non-page-aligned start (the per-Q-tile lower bound from mask.GetTileRangeAlongX
+        // is sub-block-aligned), so we must split it into a page advance plus a
+        // within-page sub-block offset. Non-SWA callers always pass a page-aligned
+        // num_blocks_start (== 0 for non-split-KV; == i_split * blocks_per_split for
+        // split-KV with page-aligned splits), so we keep the original single-add fast
+        // path for them to avoid emitting runtime integer divisions on the hot decode
+        // path; the divisions cost ~1% per kernel invocation and were the entire
+        // source of the post-SWA decode-bench regression.
+        index_t init_sub_block_offset = 0;
+        if constexpr(FmhaMask::IsLocal)
+        {
+            const index_t page_advance = num_blocks_start / kv_page_size_in_blocks;
+            init_sub_block_offset = num_blocks_start % kv_page_size_in_blocks;
+            block_table_offset += page_advance;
+            // k_block_idx counts sub-blocks relative to the first iterated page.
+            // Starting it at init_sub_block_offset lets the existing K_mem_load math
+            // (k_block_idx / kv_page_size_in_blocks for page index, k_block_idx %
+            // kv_page_size_in_blocks for within-page sub-block) keep working unchanged.
+            k_block_idx = init_sub_block_offset;
+            v_block_idx = init_sub_block_offset;
+        }
+        else
+        {
+            // Original behavior: assume num_blocks_start is page-aligned. If a future
+            // caller violates this with kv_page_size_in_blocks > 1, the index below
+            // points at the wrong page; the assert catches it in debug builds.
+            assert(kv_page_size_in_blocks == 1 ||
+                   num_blocks_start % kv_page_size_in_blocks == 0);
+            block_table_offset += num_blocks_start;
+        }
+        // After both branches above, k_block_idx is in [0, kv_page_size_in_blocks),
+        // so k_block_idx / kv_page_size_in_blocks == 0. The initial page lookup is
+        // therefore just block_table_offset; this avoids an extra runtime division
+        // here on the non-SWA path (the compiler can't prove the bound from k_block_idx
+        // alone).
+        index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset];
 
         // When row strides are provided, use pointer rebasing to avoid int32 overflow
         // in tensor_coordinate::get_offset() for large KV pools (>131K blocks for d64/GQA-8).
@@ -390,25 +412,27 @@ struct UnifiedAttentionPipeline
             auto* k_base_ptr = k_view.buf_.p_data_;
             auto* v_base_ptr = v_view.buf_.p_data_;
 
-        // Within-page byte offset (in K/V rows) for the very first iterated sub-block.
-        // Non-zero whenever num_blocks_start is not a multiple of kv_page_size_in_blocks.
-        const index_t initial_intra_page_row_offset =
-            init_sub_block_offset * kPageBlockSize;
-
+        // Within-page row offset for the very first iterated sub-block. Only declared
+        // on the SWA path: keeping the variable out of the non-SWA codegen guarantees
+        // bit-identical pre-SWA fast-path arithmetic for k_off/v_off/init_origin and
+        // avoids relying on the optimizer to fold a 0 across the long_index_t mults.
        if(use_ptr_rebase)
        {
            long_index_t k_off =
-                (static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
-                 static_cast<long_index_t>(initial_intra_page_row_offset)) *
-                k_row_stride;
+                static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
+            long_index_t v_off =
+                static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
+            if constexpr(FmhaMask::IsLocal)
+            {
+                const long_index_t intra =
+                    static_cast<long_index_t>(init_sub_block_offset) * kPageBlockSize;
+                k_off += intra * k_row_stride;
+                v_off += intra * v_row_stride;
+            }
             k_view.buf_.p_data_ = k_base_ptr + k_off;
             auto new_k = k_view.buf_.buffer_size_ - k_off;
             k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
 
-            long_index_t v_off =
-                (static_cast<long_index_t>(kv_blk_idx_initial) * PageSize +
-                 static_cast<long_index_t>(initial_intra_page_row_offset)) *
-                v_row_stride;
             v_view.buf_.p_data_ = v_base_ptr + v_off;
             auto new_v = v_view.buf_.buffer_size_ - v_off;
             v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
@@ -420,10 +444,14 @@ struct UnifiedAttentionPipeline
             v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
         }
 
-        const index_t init_origin = use_ptr_rebase
-                                        ? 0
-                                        : kv_blk_idx_initial * PageSize +
-                                              initial_intra_page_row_offset;
+        // Match the pre-SWA single-ternary form for the non-SWA path (no lambda; keeps
+        // the optimizer's job trivial). SWA adds the within-page offset on top.
+        index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
+        if constexpr(FmhaMask::IsLocal)
+        {
+            if(!use_ptr_rebase)
+                init_origin += init_sub_block_offset * kPageBlockSize;
+        }
 
         auto k_dram_window =
             make_tile_window(k_view,
@@ -1179,14 +1207,22 @@ struct UnifiedAttentionPipeline
         // (= num_total_loop - num_blocks_start), not by num_total_loop alone. Keying
         // on num_total_loop matches when num_blocks_start is even but reads
         // uninitialised sp(1)/V[1] when it is odd, which silently corrupts o_acc.
-        // This previously only mattered for split-KV with an odd split boundary; SWA
-        // exposes it whenever the per-Q-tile lower-bound clip is odd.
- const index_t num_iters = num_total_loop - num_blocks_start; - if(num_iters % 2) + // SWA needs this because its per-Q-tile lower-bound clip can be odd. Non-SWA + // callers always pass an even num_blocks_start (== 0 for non-split-KV; even + // multiples for split-KV with kBlockN-aligned splits), so for those the + // simpler num_total_loop parity is correct AND avoids an extra subtract on + // the hot exit path. Gating restores bit-identical pre-SWA codegen for the + // non-SWA fast path (the entire reason the post-SWA decode bench regressed). + index_t parity_n; + if constexpr(FmhaMask::IsLocal) + parity_n = num_total_loop - num_blocks_start; + else + parity_n = num_total_loop; + if(parity_n % 2) { fmha_post_process(number<1>{}); } - if(!(num_iters % 2)) + if(!(parity_n % 2)) { fmha_post_process(number<0>{}); }
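
Reviewer note: below is a minimal, self-contained sketch of the sub-block/page
split that the first hunk introduces for the SWA path. Everything in it is
illustrative (plain host C++, int standing in for index_t, made-up
kPageBlockSize and kv_page_size_in_blocks values); only the / and % arithmetic
and the reconstruction invariant mirror the kernel change.

#include <cassert>
#include <cstdio>

int main()
{
    // Illustrative stand-ins for the kernel's compile-time constants.
    const int kPageBlockSize         = 16; // K/V rows per sub-block
    const int kv_page_size_in_blocks = 4;  // sub-blocks per KV page

    // An SWA lower-bound clip in sub-block units, deliberately not page-aligned.
    const int num_blocks_start = 7;

    // The split performed in the IsLocal branch of the first hunk.
    const int page_advance          = num_blocks_start / kv_page_size_in_blocks; // 1
    const int init_sub_block_offset = num_blocks_start % kv_page_size_in_blocks; // 3

    // Invariant the rest of the pipeline relies on: nothing is lost by the split.
    assert(page_advance * kv_page_size_in_blocks + init_sub_block_offset == num_blocks_start);

    // Row offset of the first iterated sub-block inside its page; this is the
    // SWA-only "intra" term added on top of the page base in the second hunk.
    const int intra_page_row_offset = init_sub_block_offset * kPageBlockSize; // 48

    std::printf("page_advance=%d sub_block=%d intra_rows=%d\n",
                page_advance, init_sub_block_offset, intra_page_row_offset);
    return 0;
}

Because the split is exact, starting k_block_idx at init_sub_block_offset lets
the existing k_block_idx / kv_page_size_in_blocks and % math downstream keep
addressing the same pages and rows it always did.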
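A quick numeric check of the pre-existing pointer-rebase comment in the second
hunk (int32 overflow above ~131K blocks). The PageSize and row-stride values
below are assumptions, chosen only because their product reproduces the quoted
threshold; the real values come from the tile configuration.

#include <cstdint>
#include <cstdio>

int main()
{
    using long_index_t = int64_t; // the kernel's wide offset type

    const long_index_t PageSize     = 256; // rows per KV page (assumed)
    const long_index_t k_row_stride = 64;  // elements per row, d64 (assumed)

    const int32_t kv_blk_idx = 131073; // just past the quoted ~131K-block limit

    const long_index_t k_off =
        static_cast<long_index_t>(kv_blk_idx) * PageSize * k_row_stride;

    // 131073 * 256 * 64 = 2147500032 > INT32_MAX (2147483647): the same product
    // in 32-bit arithmetic would wrap, which is what pointer rebasing plus
    // long_index_t offsets avoid.
    std::printf("k_off=%lld fits_in_int32=%d\n",
                static_cast<long long>(k_off), k_off <= INT32_MAX);
    return 0;
}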
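Finally, a small sketch of the epilogue-parity argument behind the last hunk:
with an odd SWA lower-bound clip, the parity of num_total_loop alone disagrees
with the parity of the iterations actually executed, which is what selects the
number<0> / number<1> fmha_post_process epilogue. Values are illustrative and no
kernel types are used.

#include <cstdio>

int main()
{
    const int num_total_loop   = 9; // exclusive upper bound, in kBlockN tiles
    const int num_blocks_start = 3; // odd lower bound, as SWA can produce

    const int iters_run = num_total_loop - num_blocks_start; // 6 iterations

    // Pre-SWA shortcut: parity of the upper bound alone.
    const bool shortcut_odd = (num_total_loop % 2) != 0; // true
    // Correct form: parity of the iterations actually executed.
    const bool actual_odd = (iters_run % 2) != 0;        // false

    // With an odd start the two disagree, which is the case the IsLocal branch
    // of the patch handles; for even starts they always agree, so the non-SWA
    // path can keep the cheaper num_total_loop form.
    std::printf("shortcut_odd=%d actual_odd=%d\n", shortcut_odd, actual_odd);
    return 0;
}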