mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
Lift kPageBlockSize <= page_size constraint in CK-UA pipeline
Refactor the K/V DRAM access in the unified-attention pipeline to use
tile_scatter_gather with a unified per-(thread, Y0-iter) page-offset
formula:
logical_token = tile_idx * kPageBlockSize + thread_N_pos + i * Y0_step_N
logical_page = logical_token / page_size
within_page = logical_token % page_size
phys_page = block_tables[block_table_offset + logical_page]
page_offsets[i] = (phys_page * page_size + within_page) * row_stride
The page indirection now lives entirely in page_offsets, refreshed via
update_page_idx() between iters. The per-iter SRD rebase
(set_bottom_tensor_view_data_ptr + init_raw) and the use_ptr_rebase
overflow heuristic are gone.
Effects:
- The assertion kv_page_size_in_blocks >= 1 (i.e. kPageBlockSize <=
page_size) in the kernel is dropped. Tiles may now span multiple
cache pages, as long as Y0_step_N (= N1*N2 from the K/V tile dist)
divides page_size so that a wave-wide load never straddles a page.
- Pipeline arg renamed kv_page_size_in_blocks -> page_size (PageSize
in tokens). Kernel passes kargs.page_size through directly.
- Validated correctness vs Triton on bf16 / d=64 / decode_s with
block_size in {16, 32, 64}; max abs diff 1.22e-04 in all cases.
Perf is on par with the prior pass-1 scaffolding (~3.6 ms on the
131072-context shape).
TODO(overflow): page_offsets are index_t; caches whose
num_blocks * page_size * row_stride exceeds INT32_MAX will wrap.
A future change should plumb long_index_t through the scatter-gather
load path or compute a per-batch min-page shift in a pre-pass.
TODO(unsupported regime): page_size < Y0_step_N (a wave crosses a page
mid-iter) needs per-lane VGPR SRDs and is not implemented.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -307,7 +307,7 @@ struct UnifiedAttentionKernel
|
||||
const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len);
|
||||
|
||||
index_t _max_seq_prefix_len = amd_wave_read_first_lane(
|
||||
(context_len + q_block_local_idx * kBlockQ + (kBlockM - 1) + 1));
|
||||
(context_len + q_block_local_idx * kBlockQ + (kBlockQ - 1) + 1));
|
||||
|
||||
if(seq_len < _max_seq_prefix_len)
|
||||
{
|
||||
@@ -454,8 +454,12 @@ struct UnifiedAttentionKernel
|
||||
return FmhaMask{cur_batch_query_len, seq_len};
|
||||
}();
|
||||
|
||||
const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize;
|
||||
assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size
|
||||
// Pass-2: the pipeline now uses a unified per-(thread, Y0-iter) page
|
||||
// offset formula and accepts page_size in tokens directly. The earlier
|
||||
// `kPageBlockSize <= page_size` constraint (which required at least one
|
||||
// kernel tile to fit in a cache page) is gone — tiles may span multiple
|
||||
// pages as long as the inner-N step (Y0_step_N from the K/V tile dist)
|
||||
// divides page_size cleanly.
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
return UnifiedAttentionPipeline{}(q_dram_window,
|
||||
@@ -465,7 +469,7 @@ struct UnifiedAttentionKernel
|
||||
num_blocks_start,
|
||||
kargs.block_tables_ptr,
|
||||
block_table_offset,
|
||||
kv_page_size_in_blocks,
|
||||
kargs.page_size,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
|
||||
@@ -181,7 +181,7 @@ struct UnifiedAttentionPipeline
|
||||
const index_t num_blocks_start,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
const index_t kv_page_size_in_blocks,
|
||||
const index_t page_size, // PageSize in tokens (cache rows per page)
|
||||
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
@@ -355,62 +355,109 @@ struct UnifiedAttentionPipeline
|
||||
}
|
||||
}
|
||||
|
||||
index_t i_total_loops = num_blocks_start;
|
||||
const index_t PageSize = kv_page_size_in_blocks * kPageBlockSize;
|
||||
index_t i_total_loops = num_blocks_start;
|
||||
const ck_tile::index_t* block_tables_ptr_ =
|
||||
reinterpret_cast<const ck_tile::index_t*>(block_tables_ptr);
|
||||
assert(k_block_idx == v_block_idx); // because of the following line
|
||||
block_table_offset += num_blocks_start;
|
||||
index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx];
|
||||
|
||||
// When row strides are provided, use pointer rebasing to avoid int32 overflow
|
||||
// in tensor_coordinate::get_offset() for large KV pools (>131K blocks for d64/GQA-8).
|
||||
// When strides are 0 (legacy callers), use the original set_window_origin approach.
|
||||
// Use pointer rebasing to avoid int32 overflow in tensor_coordinate for large KV pools.
|
||||
// Only enabled when row strides are provided (from kernel) and for hdim <= 64 configs.
|
||||
// hdim=128 configs have different buffer_view internals that cause issues with rebasing.
|
||||
const bool use_ptr_rebase = (k_row_stride > 0 && v_row_stride > 0 && kHeadDim <= 64);
|
||||
// Pass-2: unified page-offset formula. The kPageBlockSize <= page_size
|
||||
// constraint is gone. For every (thread, Y0-iter) pair we compute:
|
||||
//
|
||||
// logical_token = tile_idx * kPageBlockSize
|
||||
// + thread_N_pos // lane/warp partition
|
||||
// + i * Y0_step_N // per-Y0-iter advance
|
||||
// logical_page = logical_token / page_size // index into block_tables
|
||||
// within_page = logical_token % page_size // row inside the page
|
||||
// phys_page = block_tables[block_table_offset + logical_page]
|
||||
// page_offsets[i] = (phys_page * page_size + within_page) * row_stride
|
||||
//
|
||||
// The page indirection moves entirely into page_offsets, so the per-iter
|
||||
// SRD rebase (set_bottom_tensor_view_data_ptr + init_raw) is dropped —
|
||||
// we just call update_page_idx() to refresh offsets between tiles. This
|
||||
// works for any (kPageBlockSize, page_size) pair where Y0_step_N (= the
|
||||
// inner N stride from the dist encoding, N1 * N2) divides page_size, so
|
||||
// a single wave-wide load instruction never straddles a page boundary.
|
||||
// If page_size < Y0_step_N, per-lane VGPR SRDs would be required and we
|
||||
// don't currently support that.
|
||||
//
|
||||
// TODO(overflow): page_offsets are index_t (int32). For caches whose
|
||||
// num_blocks * page_size * row_stride exceeds INT32_MAX, the offsets
|
||||
// wrap and reads return wrong data. The previous pass had a one-shot
|
||||
// base-pointer shift heuristic for this case (`use_ptr_rebase`); it has
|
||||
// been removed here because it does not interact well with the unified
|
||||
// formula when block_tables are non-monotonic (a far-away page produces
|
||||
// a large negative relative offset that the HW OOB check clamps to 0).
|
||||
// A robust fix would either plumb long_index_t through the gather load
|
||||
// path or compute a per-batch min-page shift in a pre-pass.
|
||||
const auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
|
||||
const auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
|
||||
using KDstrType = decltype(k_dist);
|
||||
using VDstrType = decltype(v_dist);
|
||||
constexpr index_t KNRepeat =
|
||||
KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
|
||||
constexpr index_t VNRepeat =
|
||||
VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
|
||||
constexpr index_t KY0_step_N =
|
||||
KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] *
|
||||
KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}];
|
||||
constexpr index_t VY0_step_N =
|
||||
VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] *
|
||||
VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}];
|
||||
|
||||
const auto k_thread_coord = k_dist.calculate_index();
|
||||
const auto v_thread_coord = v_dist.calculate_index();
|
||||
const index_t k_thread_n_pos = k_thread_coord[number<0>{}];
|
||||
const index_t v_thread_n_pos = v_thread_coord[number<0>{}];
|
||||
|
||||
statically_indexed_array<index_t, KNRepeat> k_page_offsets;
|
||||
statically_indexed_array<index_t, VNRepeat> v_page_offsets;
|
||||
|
||||
auto refresh_k_offsets = [&](index_t k_tile_idx) {
|
||||
static_for<0, KNRepeat, 1>{}([&](auto i) {
|
||||
const index_t logical_token = k_tile_idx * kPageBlockSize + k_thread_n_pos +
|
||||
static_cast<index_t>(i.value) * KY0_step_N;
|
||||
const index_t logical_page = logical_token / page_size;
|
||||
const index_t within_page = logical_token - logical_page * page_size;
|
||||
const index_t phys_page =
|
||||
block_tables_ptr_[block_table_offset + logical_page];
|
||||
k_page_offsets(i) =
|
||||
(phys_page * page_size + within_page) * k_row_stride;
|
||||
});
|
||||
};
|
||||
auto refresh_v_offsets = [&](index_t v_tile_idx) {
|
||||
static_for<0, VNRepeat, 1>{}([&](auto i) {
|
||||
const index_t logical_token = v_tile_idx * kPageBlockSize + v_thread_n_pos +
|
||||
static_cast<index_t>(i.value) * VY0_step_N;
|
||||
const index_t logical_page = logical_token / page_size;
|
||||
const index_t within_page = logical_token - logical_page * page_size;
|
||||
const index_t phys_page =
|
||||
block_tables_ptr_[block_table_offset + logical_page];
|
||||
v_page_offsets(i) =
|
||||
(phys_page * page_size + within_page) * v_row_stride;
|
||||
});
|
||||
};
|
||||
|
||||
refresh_k_offsets(k_block_idx);
|
||||
refresh_v_offsets(v_block_idx);
|
||||
|
||||
auto k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
|
||||
auto v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
|
||||
auto* k_base_ptr = k_view.buf_.p_data_;
|
||||
auto* v_base_ptr = v_view.buf_.p_data_;
|
||||
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
long_index_t k_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * k_row_stride;
|
||||
k_view.buf_.p_data_ = k_base_ptr + k_off;
|
||||
auto new_k = k_view.buf_.buffer_size_ - k_off;
|
||||
k_view.buf_.buffer_size_ = new_k > 0 ? new_k : kPageBlockSize * kHeadDim;
|
||||
|
||||
long_index_t v_off =
|
||||
static_cast<long_index_t>(kv_blk_idx_initial) * PageSize * v_row_stride;
|
||||
v_view.buf_.p_data_ = v_base_ptr + v_off;
|
||||
auto new_v = v_view.buf_.buffer_size_ - v_off;
|
||||
v_view.buf_.buffer_size_ = new_v > 0 ? new_v : kPageBlockSize * kHeadDim;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Legacy path: use original view with absolute window origin
|
||||
k_view = k_dram_block_window_tmp.get_bottom_tensor_view();
|
||||
v_view = v_dram_block_window_tmp.get_bottom_tensor_view();
|
||||
}
|
||||
|
||||
const index_t init_origin = use_ptr_rebase ? 0 : kv_blk_idx_initial * PageSize;
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_view,
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{init_origin, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
make_tile_scatter_gather(k_view,
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0},
|
||||
k_dist,
|
||||
k_page_offsets);
|
||||
k_dram_window.init_raw();
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_view,
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{init_origin, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
make_tile_scatter_gather(v_view,
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0},
|
||||
v_dist,
|
||||
v_page_offsets);
|
||||
v_dram_window.init_raw();
|
||||
|
||||
// prefetch K tile
|
||||
@@ -508,60 +555,21 @@ struct UnifiedAttentionPipeline
|
||||
// only for block 0 and thread
|
||||
if(blockIdx.x == 0 && threadIdx.x == 0) {}
|
||||
|
||||
auto rebase_window = [](auto& window, auto* base_ptr, long_index_t elem_offset,
|
||||
auto buf_size_orig) {
|
||||
window.bottom_tensor_view_.buf_.p_data_ = base_ptr + elem_offset;
|
||||
auto new_size = buf_size_orig - elem_offset;
|
||||
window.bottom_tensor_view_.buf_.buffer_size_ = new_size > 0 ? new_size : kPageBlockSize * kHeadDim;
|
||||
window.init_raw();
|
||||
window.set_window_origin({0, 0});
|
||||
};
|
||||
|
||||
const auto k_buf_size_orig = k_dram_window.bottom_tensor_view_.buf_.buffer_size_;
|
||||
const auto v_buf_size_orig = v_dram_window.bottom_tensor_view_.buf_.buffer_size_;
|
||||
|
||||
// Pass-2: page indirection lives in page_offsets, not in the SRD. We
|
||||
// refresh the per-iter offsets table and push it to the window via
|
||||
// update_page_idx(); the SRD itself stays put (no init_raw per iter).
|
||||
auto K_mem_load = [&](auto k_lds_write_idx) {
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
k_block_idx++;
|
||||
|
||||
index_t k_page_blk_idx =
|
||||
block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)];
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
long_index_t k_row =
|
||||
static_cast<long_index_t>(k_page_blk_idx) * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
rebase_window(k_dram_window, k_base_ptr, k_row * k_row_stride, k_buf_size_orig);
|
||||
}
|
||||
else
|
||||
{
|
||||
k_dram_window.set_window_origin(
|
||||
{k_page_blk_idx * PageSize +
|
||||
(k_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
}
|
||||
refresh_k_offsets(k_block_idx);
|
||||
k_dram_window.update_page_idx(k_page_offsets);
|
||||
};
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
v_block_idx++;
|
||||
|
||||
index_t v_page_blk_idx =
|
||||
block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)];
|
||||
if(use_ptr_rebase)
|
||||
{
|
||||
long_index_t v_row =
|
||||
static_cast<long_index_t>(v_page_blk_idx) * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize;
|
||||
rebase_window(v_dram_window, v_base_ptr, v_row * v_row_stride, v_buf_size_orig);
|
||||
}
|
||||
else
|
||||
{
|
||||
v_dram_window.set_window_origin(
|
||||
{v_page_blk_idx * PageSize +
|
||||
(v_block_idx % kv_page_size_in_blocks) * kPageBlockSize,
|
||||
0});
|
||||
}
|
||||
refresh_v_offsets(v_block_idx);
|
||||
v_dram_window.update_page_idx(v_page_offsets);
|
||||
};
|
||||
|
||||
auto K_lds_load = [&](auto k_lds_read_idx) {
|
||||
@@ -1191,7 +1199,7 @@ struct UnifiedAttentionPipeline
|
||||
const index_t num_blocks_start,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
const index_t kv_page_size_in_blocks,
|
||||
const index_t page_size, // PageSize in tokens (cache rows per page)
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
@@ -1210,7 +1218,7 @@ struct UnifiedAttentionPipeline
|
||||
num_blocks_start,
|
||||
block_tables_ptr,
|
||||
block_table_offset,
|
||||
kv_page_size_in_blocks,
|
||||
page_size,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
|
||||
Reference in New Issue
Block a user