From 473869aba59606092be537f56516d12edab063c1 Mon Sep 17 00:00:00 2001
From: juuso-oskari
Date: Mon, 11 May 2026 10:04:01 +0000
Subject: [PATCH] Lift kPageBlockSize <= page_size constraint in CK-UA pipeline

Refactor the K/V DRAM access in the unified-attention pipeline to use
tile_scatter_gather with a unified per-(thread, Y0-iter) page-offset formula:

    logical_token   = tile_idx * kPageBlockSize + thread_N_pos + i * Y0_step_N
    logical_page    = logical_token / page_size
    within_page     = logical_token % page_size
    phys_page       = block_tables[block_table_offset + logical_page]
    page_offsets[i] = (phys_page * page_size + within_page) * row_stride

The page indirection now lives entirely in page_offsets, refreshed via
update_page_idx() between iters. The per-iter SRD rebase
(set_bottom_tensor_view_data_ptr + init_raw) and the use_ptr_rebase overflow
heuristic are gone.
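As a worked instance of the formula (a standalone sketch; page_offset and the
toy numbers are illustrative, not code from this patch):

    // `int` stands in for ck_tile::index_t (int32) throughout this sketch.
    #include <cassert>

    int page_offset(int tile_idx, int thread_n_pos, int i, int page_block,
                    int y0_step_n, int page_size, int row_stride,
                    const int* block_tables)
    {
        const int logical_token = tile_idx * page_block + thread_n_pos + i * y0_step_n;
        const int logical_page  = logical_token / page_size;
        const int within_page   = logical_token % page_size;
        const int phys_page     = block_tables[logical_page]; // block_table_offset folded out
        return (phys_page * page_size + within_page) * row_stride;
    }

    int main()
    {
        // page_size = 16 tokens, kPageBlockSize = 32 (a tile spans two pages),
        // d = 64 -> row_stride = 64 elements, pages mapped out of order:
        const int block_tables[] = {7, 2, 9, 0};
        // Tile 0, lane at N-position 5, Y0-iter 1, with Y0_step_N = 16:
        // logical_token = 0*32 + 5 + 1*16 = 21 -> page 1 (phys 2), row 5 within it.
        assert(page_offset(0, 5, 1, 32, 16, 16, 64, block_tables) == (2 * 16 + 5) * 64);
        return 0;
    }

Note that the toy config satisfies the new constraint: Y0_step_N (16) divides
page_size (16), so each wave-wide load stays inside one page even though the
tile itself spans two.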
Effects:

- The kernel assertion kv_page_size_in_blocks >= 1 (i.e. kPageBlockSize <=
  page_size) is dropped. Tiles may now span multiple cache pages, as long as
  Y0_step_N (= N1 * N2 from the K/V tile dist) divides page_size, so that a
  wave-wide load never straddles a page.
- The pipeline arg is renamed kv_page_size_in_blocks -> page_size (PageSize
  in tokens); the kernel now passes kargs.page_size through directly.
- Correctness validated against Triton on bf16 / d=64 / decode_s with
  block_size in {16, 32, 64}; max abs diff 1.22e-04 in all cases. Perf is on
  par with the prior pass-1 scaffolding (~3.6 ms on the 131072-context shape).

TODO(overflow): page_offsets are index_t; caches whose num_blocks * page_size
* row_stride exceeds INT32_MAX will wrap. A future change should plumb
long_index_t through the scatter-gather load path or compute a per-batch
min-page shift in a pre-pass.

TODO(unsupported regime): page_size < Y0_step_N (a wave crosses a page
mid-iter) needs per-lane VGPR SRDs and is not implemented.

Co-authored-by: Cursor
---
 .../kernel/unified_attention_kernel.hpp       |  12 +-
 .../pipeline/unified_attention_pipeline.hpp   | 192 +++++++++---------
 2 files changed, 108 insertions(+), 96 deletions(-)

diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
index 087a8872b9..43a9142175 100644
--- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
+++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp
@@ -307,7 +307,7 @@ struct UnifiedAttentionKernel
         const index_t context_len =
             amd_wave_read_first_lane(seq_len - cur_batch_query_len);
         index_t _max_seq_prefix_len = amd_wave_read_first_lane(
-            (context_len + q_block_local_idx * kBlockQ + (kBlockM - 1) + 1));
+            (context_len + q_block_local_idx * kBlockQ + (kBlockQ - 1) + 1));

         if(seq_len < _max_seq_prefix_len)
         {
@@ -454,8 +454,12 @@ struct UnifiedAttentionKernel
             return FmhaMask{cur_batch_query_len, seq_len};
         }();

-        const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize;
-        assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size
+        // Pass-2: the pipeline now uses a unified per-(thread, Y0-iter) page
+        // offset formula and accepts page_size in tokens directly. The earlier
+        // `kPageBlockSize <= page_size` constraint (which required at least one
+        // kernel tile to fit in a cache page) is gone -- tiles may span multiple
+        // pages as long as the inner-N step (Y0_step_N from the K/V tile dist)
+        // divides page_size cleanly.

         auto o_acc_tile = [&]() {
             return UnifiedAttentionPipeline{}(q_dram_window,
@@ -465,7 +469,7 @@ struct UnifiedAttentionKernel
                                               num_blocks_start,
                                               kargs.block_tables_ptr,
                                               block_table_offset,
-                                              kv_page_size_in_blocks,
+                                              kargs.page_size,
                                               mask,
                                               kargs.scale_s,
                                               smem_ptr,
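The divisibility requirement stated in the comment above is cheap to check on
the host before launch. A minimal sketch, assuming y0_step_n is N1 * N2 from
the K/V tile distribution encoding (this guard is hypothetical and not part of
the patch; the kernel does not currently assert it):

    // One wave-wide load covers y0_step_n consecutive rows starting at a
    // multiple of y0_step_n, so it stays inside a single cache page exactly
    // when y0_step_n divides page_size. The modulo test also rejects the
    // unsupported page_size < y0_step_n regime.
    inline bool page_layout_supported(int page_size, int y0_step_n)
    {
        return y0_step_n > 0 && page_size % y0_step_n == 0;
    }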
diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
index 29617948df..98cf70914f 100644
--- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
+++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp
@@ -181,7 +181,7 @@ struct UnifiedAttentionPipeline
               const index_t num_blocks_start,
               const void* block_tables_ptr,
               index_t block_table_offset,
-              const index_t kv_page_size_in_blocks,
+              const index_t page_size, // PageSize in tokens (cache rows per page)
               [[maybe_unused]] const SAccElementFunction& s_acc_element_func,
               const PComputeElementFunction& p_compute_element_func,
               const OAccElementFunction& o_acc_element_func,
@@ -355,62 +355,109 @@ struct UnifiedAttentionPipeline
             }
         }

-        index_t i_total_loops          = num_blocks_start;
-        const index_t PageSize         = kv_page_size_in_blocks * kPageBlockSize;
+        index_t i_total_loops = num_blocks_start;
         const ck_tile::index_t* block_tables_ptr_ =
             reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);

         assert(k_block_idx == v_block_idx); // because of the following line
         block_table_offset += num_blocks_start;
-        index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];

-        // When row strides are provided, use pointer rebasing to avoid int32 overflow
-        // in tensor_coordinate::get_offset() for large KV pools (>131K blocks for d64/GQA-8).
-        // When strides are 0 (legacy callers), use the original set_window_origin approach.
-        // Use pointer rebasing to avoid int32 overflow in tensor_coordinate for large KV pools.
-        // Only enabled when row strides are provided (from kernel) and for hdim <= 64 configs.
-        // hdim=128 configs have different buffer_view internals that cause issues with rebasing.
-        const bool use_ptr_rebase = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64);
+        // Pass-2: unified page-offset formula. The kPageBlockSize <= page_size
+        // constraint is gone. For every (thread, Y0-iter) pair we compute:
+        //
+        //   logical_token = tile_idx * kPageBlockSize
+        //                 + thread_N_pos     // lane/warp partition
+        //                 + i * Y0_step_N    // per-Y0-iter advance
+        //   logical_page  = logical_token / page_size  // index into block_tables
+        //   within_page   = logical_token % page_size  // row inside the page
+        //   phys_page     = block_tables[block_table_offset + logical_page]
+        //   page_offsets[i] = (phys_page * page_size + within_page) * row_stride
+        //
+        // The page indirection moves entirely into page_offsets, so the per-iter
+        // SRD rebase (set_bottom_tensor_view_data_ptr + init_raw) is dropped --
+        // we just call update_page_idx() to refresh offsets between tiles. This
+        // works for any (kPageBlockSize, page_size) pair where Y0_step_N (= the
+        // inner N stride from the dist encoding, N1 * N2) divides page_size, so
+        // a single wave-wide load instruction never straddles a page boundary.
+        // If page_size < Y0_step_N, per-lane VGPR SRDs would be required and we
+        // don't currently support that.
+        //
+        // TODO(overflow): page_offsets are index_t (int32). For caches whose
+        // num_blocks * page_size * row_stride exceeds INT32_MAX, the offsets
+        // wrap and reads return wrong data. The previous pass had a one-shot
+        // base-pointer shift heuristic for this case (`use_ptr_rebase`); it has
+        // been removed here because it does not interact well with the unified
+        // formula when block_tables are non-monotonic (a far-away page produces
+        // a large negative relative offset that the HW OOB check clamps to 0).
+        // A robust fix would either plumb long_index_t through the gather load
+        // path or compute a per-batch min-page shift in a pre-pass.
+        const auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
+        const auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
+        using KDstrType   = decltype(k_dist);
+        using VDstrType   = decltype(v_dist);
+        constexpr index_t KNRepeat =
+            KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
+        constexpr index_t VNRepeat =
+            VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
+        constexpr index_t KY0_step_N =
+            KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] *
+            KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}];
+        constexpr index_t VY0_step_N =
+            VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] *
+            VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}];
+
+        const auto k_thread_coord    = k_dist.calculate_index();
+        const auto v_thread_coord    = v_dist.calculate_index();
+        const index_t k_thread_n_pos = k_thread_coord[number<0>{}];
+        const index_t v_thread_n_pos = v_thread_coord[number<0>{}];
+
+        statically_indexed_array<index_t, KNRepeat> k_page_offsets;
+        statically_indexed_array<index_t, VNRepeat> v_page_offsets;
+
+        auto refresh_k_offsets = [&](index_t k_tile_idx) {
+            static_for<0, KNRepeat, 1>{}([&](auto i) {
+                const index_t logical_token = k_tile_idx * kPageBlockSize + k_thread_n_pos +
+                                              static_cast<index_t>(i.value) * KY0_step_N;
+                const index_t logical_page = logical_token / page_size;
+                const index_t within_page  = logical_token - logical_page * page_size;
+                const index_t phys_page =
+                    block_tables_ptr_[block_table_offset + logical_page];
+                k_page_offsets(i) = (phys_page * page_size + within_page) * k_row_stride;
+            });
+        };
+        auto refresh_v_offsets = [&](index_t v_tile_idx) {
+            static_for<0, VNRepeat, 1>{}([&](auto i) {
+                const index_t logical_token = v_tile_idx * kPageBlockSize + v_thread_n_pos +
+                                              static_cast<index_t>(i.value) * VY0_step_N;
+                const index_t logical_page = logical_token / page_size;
+                const index_t within_page  = logical_token - logical_page * page_size;
+                const index_t phys_page =
+                    block_tables_ptr_[block_table_offset + logical_page];
+                v_page_offsets(i) = (phys_page * page_size + within_page) * v_row_stride;
+            });
+        };
+
+        refresh_k_offsets(k_block_idx);
+        refresh_v_offsets(v_block_idx);

         auto k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
         auto v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
-        auto* k_base_ptr = k_view.buf_.p_data_;
-        auto* v_base_ptr = v_view.buf_.p_data_;
-
-        if(use_ptr_rebase)
-        {
-            long_index_t k_off =
-                static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
-            k_view.buf_.p_data_ = k_base_ptr + k_off;
-            auto new_k          = k_view.buf_.buffer_size_ - k_off;
-            k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
-
-            long_index_t v_off =
-                static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
-            v_view.buf_.p_data_ = v_base_ptr + v_off;
-            auto new_v          = v_view.buf_.buffer_size_ - v_off;
-            v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
-        }
-        else
-        {
-            // Legacy path: use original view with absolute window origin
-            k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
-            v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
-        }
-
-        const index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;

         auto k_dram_window =
-            make_tile_window(k_view,
-                             k_dram_block_window_tmp.get_window_lengths(),
-                             {init_origin, 0},
-                             Policy::template MakeKDramTileDistribution<Problem>());
+            make_tile_scatter_gather(k_view,
+                                     k_dram_block_window_tmp.get_window_lengths(),
+                                     {0, 0},
+                                     k_dist,
+                                     k_page_offsets);
         k_dram_window.init_raw();

         auto v_dram_window =
-            make_tile_window(v_view,
-                             v_dram_block_window_tmp.get_window_lengths(),
-                             {init_origin, 0},
-                             Policy::template MakeVDramTileDistribution<Problem>());
+            make_tile_scatter_gather(v_view,
+                                     v_dram_block_window_tmp.get_window_lengths(),
+                                     {0, 0},
+                                     v_dist,
+                                     v_page_offsets);
         v_dram_window.init_raw();

         // prefetch K tile
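To put a number on the TODO(overflow) in the hunk above: with int32
page_offsets, the largest offset the formula can produce is roughly
num_blocks * page_size * row_stride. A back-of-envelope check (illustrative
helper, not part of the patch):

    #include <cstdint>

    // True when every element offset a cache of this shape can produce still
    // fits in int32 (offsets are in elements, matching the formula above).
    inline bool offsets_fit_int32(std::int64_t num_blocks, std::int64_t page_size,
                                  std::int64_t row_stride)
    {
        return num_blocks * page_size * row_stride <= INT32_MAX;
    }
    // e.g. with page_size = 16 and row_stride = 1024 elements (one plausible
    // d=64 multi-KV-head layout), 2^31 / (16 * 1024) = 131072 pages before
    // offsets wrap -- the ">131K blocks" regime the removed use_ptr_rebase
    // comment pointed at.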
@@ -508,60 +555,21 @@ struct UnifiedAttentionPipeline
         // only for block 0 and thread
         if(blockIdx.x == 0 && threadIdx.x == 0) {}

-        auto rebase_window = [](auto& window, auto* base_ptr, long_index_t elem_offset,
-                                auto buf_size_orig) {
-            window.bottom_tensor_view_.buf_.p_data_ = base_ptr + elem_offset;
-            auto new_size = buf_size_orig - elem_offset;
-            window.bottom_tensor_view_.buf_.buffer_size_ =
-                new_size > 0 ? new_size : kPageBlockSize * kHeadDim;
-            window.init_raw();
-            window.set_window_origin({0, 0});
-        };
-
-        const auto k_buf_size_orig = k_dram_window.bottom_tensor_view_.buf_.buffer_size_;
-        const auto v_buf_size_orig = v_dram_window.bottom_tensor_view_.buf_.buffer_size_;
-
+        // Pass-2: page indirection lives in page_offsets, not in the SRD. We
+        // refresh the per-iter offsets table and push it to the window via
+        // update_page_idx(); the SRD itself stays put (no init_raw per iter).
         auto K_mem_load = [&](auto k_lds_write_idx) {
             async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
             k_block_idx++;
-
-            index_t k_page_blk_idx =
-                block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
-            if(use_ptr_rebase)
-            {
-                long_index_t k_row =
-                    static_cast<long_index_t>(k_page_blk_idx) * PageSize +
-                    (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
-                rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig);
-            }
-            else
-            {
-                k_dram_window.set_window_origin(
-                    {k_page_blk_idx * PageSize +
-                         (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
-                     0});
-            }
+            refresh_k_offsets(k_block_idx);
+            k_dram_window.update_page_idx(k_page_offsets);
         };

         auto V_mem_load = [&](auto v_lds_write_idx) {
             async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
             v_block_idx++;
-
-            index_t v_page_blk_idx =
-                block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
-            if(use_ptr_rebase)
-            {
-                long_index_t v_row =
-                    static_cast<long_index_t>(v_page_blk_idx) * PageSize +
-                    (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
-                rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig);
-            }
-            else
-            {
-                v_dram_window.set_window_origin(
-                    {v_page_blk_idx * PageSize +
-                         (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
-                     0});
-            }
+            refresh_v_offsets(v_block_idx);
+            v_dram_window.update_page_idx(v_page_offsets);
         };

         auto K_lds_load = [&](auto k_lds_read_idx) {
@@ -1191,7 +1199,7 @@ struct UnifiedAttentionPipeline
               const index_t num_blocks_start,
               const void* block_tables_ptr,
               index_t block_table_offset,
-              const index_t kv_page_size_in_blocks,
+              const index_t page_size, // PageSize in tokens (cache rows per page)
               FmhaMask mask,
               float scale_s,
               void* smem_ptr,
@@ -1210,7 +1218,7 @@ struct UnifiedAttentionPipeline
                          num_blocks_start,
                          block_tables_ptr,
                          block_table_offset,
-                         kv_page_size_in_blocks,
+                         page_size,
                          identity{},
                          identity{},
                          identity{},
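For readers unfamiliar with the tile_scatter_gather flow, here is a compact
host-side mock of the K_mem_load pattern above (all types and numbers are
illustrative; Window merely stands in for CK-Tile's scatter-gather window):

    #include <array>
    #include <cstdint>
    #include <vector>

    struct Window {
        std::array<std::int32_t, 4> page_offsets{};
        void update_page_idx(const std::array<std::int32_t, 4>& offs)
        {
            page_offsets = offs; // real impl rewrites per-lane gather offsets
        }
    };

    int main()
    {
        const std::vector<std::int32_t> block_tables = {3, 1, 2, 0};
        const std::int32_t kPageBlockSize = 16, page_size = 16, row_stride = 64;
        const std::int32_t thread_n_pos = 0, y0_step_n = 4; // 4 Y0 iters per tile

        std::array<std::int32_t, 4> offs{};
        auto refresh = [&](std::int32_t tile_idx) {
            for(std::int32_t i = 0; i < 4; ++i)
            {
                const std::int32_t tok  = tile_idx * kPageBlockSize + thread_n_pos + i * y0_step_n;
                const std::int32_t page = tok / page_size;
                offs[i] = (block_tables[page] * page_size + tok % page_size) * row_stride;
            }
        };

        Window w;
        refresh(0);
        w.update_page_idx(offs);
        for(std::int32_t tile = 0; tile + 1 < 4; ++tile)
        {
            // Issue async loads for `tile` here, then retarget the window at
            // tile + 1 by refreshing the offset table only -- no SRD rebase:
            refresh(tile + 1);
            w.update_page_idx(offs);
        }
    }

The design point this illustrates: once the physical-page lookup is folded
into the per-lane offsets, moving to the next tile is a pure table refresh,
which is why the per-iter init_raw and buffer resizing could be deleted.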