CK-UA: lazy per-issue SRD rebase for the int32-overflow K/V load path

Adds `async_load_raw_lazy_rebase` (+ free-function wrapper
`async_load_tile_raw_lazy_rebase`) to `tile_scatter_gather`, and wires
the unified-attention pipeline's overflow branch to it instead of
`async_load_tile_raw_long`. The fast non-overflow short path is
untouched.

Idea: keep using the cheap `buffer_load_dword_lds` (wave-uniform 4 GB
SRD) for the >4 GB cache pool case, but at each issue check whether the
wave-uniform anchor (the first active lane's page offset, read via
`amd_wave_read_first_lane`) has left the half-int32 element window
`[cur_anchor_, cur_anchor_ + kRebaseThreshold)`. If it has, shift the
SRD base pointer to `p_data_orig_ + wave_anchor`, reinit the buffer
resource, and set `cur_anchor_ = wave_anchor`. The per-lane voffset is
then `lane_page_offset - cur_anchor_`, which fits in int32 under the
precondition below.
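
A minimal host-side sketch of the per-issue check, for illustration only.
`LazyRebaseModel` and its members are hypothetical names, the SRD reinit
is reduced to a counter, and the asserts stand in for the int32 voffset
limit; it models the anchor arithmetic, not the device code.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Host-side model of the lazy rebase; all names here are illustrative.
struct LazyRebaseModel
{
    std::int64_t cur_anchor = 0; // current SRD shift, in elements
    int rebase_count        = 0; // how many times the SRD base was moved

    // Half of the element window addressable by an int32 byte offset.
    static constexpr std::int64_t threshold(std::int64_t elem_size)
    {
        return (std::numeric_limits<std::int32_t>::max() / elem_size) / 2;
    }

    // wave_anchor: the first active lane's absolute element offset.
    // lane_offset: this lane's absolute element offset.
    // Returns the lane's element offset relative to the (possibly moved) base.
    std::int32_t issue(std::int64_t wave_anchor,
                       std::int64_t lane_offset,
                       std::int64_t elem_size)
    {
        const std::int64_t rel = wave_anchor - cur_anchor;
        if(rel < 0 || rel >= threshold(elem_size))
        {
            cur_anchor = wave_anchor; // the real code also reinits the SRD here
            ++rebase_count;
        }
        const std::int64_t voffset = lane_offset - cur_anchor;
        // The byte-scaled offset must fit in int32 in both directions.
        assert(voffset * elem_size <= std::numeric_limits<std::int32_t>::max());
        assert(voffset * elem_size >= std::numeric_limits<std::int32_t>::min());
        return static_cast<std::int32_t>(voffset);
    }
};
```

With a 2-byte element the threshold is 536870911 elements, so an anchor
that jumps by 3e9 elements triggers exactly one rebase, while small
forward steps reuse the current base.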

State added to `tile_scatter_gather`:
  - `p_data_orig_`      : original SRD base pointer (write-once)
  - `buffer_size_orig_` : original SRD size in elements (write-once)
  - `cur_anchor_`       : current wave-uniform SRD shift (in elements),
                          only ever assigned from
                          amd_wave_read_first_lane, so it stays in SGPRs.

Capture is done by a sister method `init_raw_lazy_rebase()` (used by the
pipeline when `cache_ptr_int32_overflow_possible` is true); on the
non-overflow path the existing `init_raw()` is used so the helper state
is write-never and DCE-eligible.

Correctness precondition: within a single issue every lane of the wave
must map to the same physical page block (WaveSpanInN <= runtime
page_size). Under this precondition the per-lane spread around the
wave-uniform anchor stays inside a half-INT32 element window. When the
precondition does not hold, `async_load_tile_raw_long` is the correct
fallback.
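
A rough sketch of the bound behind this precondition, assuming the
per-lane spread is at most `page_size * row_stride` elements. The
function and parameter names are hypothetical and not part of the
pipeline.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Illustrative check: does the worst-case per-lane byte offset still fit in
// int32 when the anchor sits at the very end of the half window?
// page_size (rows per page), row_stride (elements per row) and elem_size
// (bytes per element) are hypothetical parameter names.
inline bool spread_fits_half_window(std::int64_t page_size,
                                    std::int64_t row_stride,
                                    std::int64_t elem_size)
{
    assert(elem_size > 0);
    const std::int64_t int32_max   = std::numeric_limits<std::int32_t>::max();
    const std::int64_t elem_window = int32_max / elem_size; // elements per int32 byte range
    const std::int64_t threshold   = elem_window / 2;       // rebase threshold, in elements
    const std::int64_t spread      = page_size * row_stride; // assumed per-lane spread bound
    // Worst case: anchor is (threshold - 1) past the base, lane is `spread` further.
    return (threshold - 1 + spread) * elem_size <= int32_max;
}
```

For example, `spread_fits_half_window(16, 128, 2)` holds, whereas a
hypothetical page of 2^20 rows with a 1024-element stride does not, which
is the kind of case where `_long` would be needed.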

Tested on gfx950 / GPU 2 (no contention), BF16 only:
  * ua-test-scripts/test_unified_attention_ck_correctness.py: 245/245
    BF16/FP16 pass.
  * test_single_shape.py overflow shapes (BF16): correctness passes.

Perf vs `_long` baseline (BF16, overflowing cache, CUDA graph):
  | Shape                              | _long      | _lazy      | delta  |
  | b=1 sq=1 sk=1M d=64 nb=200k        | 2.4149 ms  | 2.7849 ms  | +15.3% |
  | b=8 sq=1 sk=200k d=128 nb=100k     | 1.3762 ms  | 1.4225 ms  |  +3.4% |
  | b=128 sq=1 sk=128k d=128 nb=80k    | 14.0319 ms | 14.4643 ms |  +3.1% |
  | b=32 sq=1 sk=512k d=64 nb=200k     | 7.5211 ms  | 7.5206 ms  |   0.0% |

Verdict: the lazy variant is 0-3.4% slower than `_long` on the
multi-batch decode shapes that dominate real workloads, and ~15% slower
on the single-batch huge-context corner where the rebase rate is
highest. Combined with the WaveSpanInN <= page_size precondition (which
`_long` does not require), `_long` remains the right default. Parked
on a side branch for future experimentation.

Co-authored-by: Cursor <cursoragent@cursor.com>
juuso-oskari
2026-05-18 09:47:53 +00:00
parent 06e1a70e7a
commit 589fe55d48
3 changed files with 258 additions and 22 deletions


@@ -214,6 +214,34 @@ CK_TILE_DEVICE void async_load_tile_raw_long(LdsTileWindow_&& lds_tile,
bool_constant<pre_nop>{});
}
// Variant of async_load_tile_raw that dispatches to
// async_load_raw_lazy_rebase: the fast buffer_load_dword_lds path, but with
// a wave-uniform SRD base pointer that is lazily re-anchored whenever the
// per-issue page offset would otherwise overflow int32 voffsets. Lifts the
// 4 GB cache-pool limit of the regular async_load_tile_raw without paying
// the per-lane 64-bit base cost of async_load_tile_raw_long. Requires the
// tile_window to have been set up with init_raw_lazy_rebase() and the
// WaveSpanInN <= runtime page_size precondition documented on
// async_load_raw_lazy_rebase. The tile_window is passed by non-const
// reference because the rebase mutates its SRD state.
template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void async_load_tile_raw_lazy_rebase(
LdsTileWindow_&& lds_tile,
TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.async_load_raw_lazy_rebase(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
CK_TILE_DEVICE void async_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");


@@ -719,6 +719,170 @@ struct tile_scatter_gather
});
}
// ------------------------------------------------------------------
// Variant of async_load_raw that lazily re-anchors the wave-uniform SRD
// base pointer so per-lane voffsets stay within int32 range even when
// the total cache pool exceeds 4 GB. For every load issue:
//
// 1. read the per-lane absolute page offset (long_index_t, in
// elements of DataType);
// 2. take the first active lane's value (lane 0 when all lanes are
// active) as a wave-uniform anchor candidate via
// amd_wave_read_first_lane();
// 3. if (wave_anchor - cur_anchor_) is outside [0, kRebaseThreshold)
// shift the SRD base pointer to p_data_orig_ + wave_anchor and
// reinit the buffer resource; update cur_anchor_ accordingly;
// 4. issue the buffer_load with voffset = (lane_page_offset -
// cur_anchor_), which is guaranteed to fit in int32 (after the
// *sizeof(T) byte scaling inside amd_async_buffer_load_with_oob_raw).
//
// Correctness precondition: within a single issue every lane of the
// wave must map to the same physical page block, i.e.
// WaveSpanInN <= runtime page_size
// Under this precondition the per-lane spread relative to the
// wave-uniform anchor is bounded by page_size * row_stride elements
// (page_size * row_stride * sizeof(T) bytes), which fits comfortably in
// the half-INT32 element window we leave (kRebaseThreshold below). When
// the precondition does not hold use async_load_raw_long instead.
//
// Fast path (no rebase this issue): one wave-read, one 64-bit
// subtract, one compare-branch. The branch is wave-uniform, so there is
// no lane divergence, and the rebase rate is low, so the rebase body is
// rarely executed.
//
// This method is non-const because it mutates bottom_tensor_view_
// (rebase) and cur_anchor_ (anchor tracking). Use after
// init_raw_lazy_rebase().
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw_lazy_rebase(
LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
m0_set_with_memory(amd_wave_read_first_lane(m0_init_value));
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// The buffer-load builtin scales the element offset by sizeof(DataType)
// and feeds the result to a 32-bit voffset. To keep the byte offset
// within INT32_MAX *for any active lane in the wave*, leave a margin
// of half the element window for per-lane spread relative to lane-0.
constexpr long_index_t kInt32ElemWindow =
static_cast<long_index_t>(INT32_MAX) / static_cast<long_index_t>(sizeof(DataType));
constexpr long_index_t kRebaseThreshold = kInt32ElemWindow / 2;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = get_gather_index(idx_ys_start);
// Per-lane absolute page offset (in elements of DataType).
const long_index_t lane_page_offset =
static_cast<long_index_t>(page_idx_[idx_gather]);
// Wave-uniform anchor candidate: the first active lane's value
// (lane 0 when all lanes are active). Promoted to SGPRs by the
// readfirstlane.
const long_index_t wave_anchor = amd_wave_read_first_lane(lane_page_offset);
// Lazy rebase: only when the wave-uniform anchor has drifted
// outside the current int32 voffset window around cur_anchor_.
const long_index_t rel = wave_anchor - cur_anchor_;
if(rel < 0 || rel >= kRebaseThreshold)
{
cur_anchor_ = wave_anchor;
bottom_tensor_view_.buf_.p_data_ = p_data_orig_ + cur_anchor_;
using BufSizeT =
remove_cvref_t<decltype(bottom_tensor_view_.buf_.buffer_size_)>;
bottom_tensor_view_.buf_.buffer_size_ =
static_cast<BufSizeT>(buffer_size_orig_ - cur_anchor_);
bottom_tensor_view_.init_raw();
}
// Per-lane voffset relative to (possibly new) cur_anchor_.
// Fits in int32 by construction (kRebaseThreshold + spread).
const index_t lane_voffset =
static_cast<index_t>(lane_page_offset - cur_anchor_);
// read from bottom tensor
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, lane_voffset, 0, pre_nop_);
}
else
{
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem,
bottom_tensor_thread_coord,
lane_voffset,
valids_[idx_gather],
0,
pre_nop_);
}
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
m0_inc_with_memory(size_per_issue);
}
});
});
}
// TODO: fix with swizzle
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
@@ -1275,6 +1439,21 @@ struct tile_scatter_gather
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// Companion to init_raw(): capture the original SRD base / size so that
// async_load_raw_lazy_rebase() can shift the wave-uniform base pointer
// on demand and later recompute the buffer resource (init_raw) without
// losing the underlying pool layout. Reset the anchor to 0 (no shift).
// Call this once per window instead of init_raw() when the per-issue
// page offsets may exceed INT32_MAX (i.e. when the cache pool size in
// bytes can overflow int32 voffsets).
CK_TILE_HOST_DEVICE void init_raw_lazy_rebase()
{
p_data_orig_ = bottom_tensor_view_.buf_.p_data_;
buffer_size_orig_ = static_cast<long_index_t>(bottom_tensor_view_.buf_.buffer_size_);
cur_anchor_ = 0;
bottom_tensor_view_.init_raw();
}
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
@@ -1302,6 +1481,20 @@ struct tile_scatter_gather
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord>,
std::byte>
pre_computed_warp_coords_;
// State used by async_load_raw_lazy_rebase(). Populated by
// init_raw_lazy_rebase(); ignored by all other load paths.
// p_data_orig_ : original SRD base pointer (never mutated post-init)
// buffer_size_orig_ : original SRD size in elements of DataType
// cur_anchor_ : current wave-uniform SRD shift (in elements,
// relative to p_data_orig_); kept in SGPRs as the
// value is only ever assigned from
// amd_wave_read_first_lane(...). When non-zero,
// bottom_tensor_view_.buf_.p_data_ ==
// p_data_orig_ + cur_anchor_.
typename BottomTensorView::buffer_view::type* p_data_orig_ = nullptr;
long_index_t buffer_size_orig_ = 0;
long_index_t cur_anchor_ = 0;
};
// TODO: use strategy


@@ -829,7 +829,15 @@ struct UnifiedAttentionPipeline
 {0, 0},
 k_dist,
 k_page_offsets);
-k_dram_window.init_raw();
+// Use the lazy-rebase-aware init when overflow is possible so the
+// rebase path has the original SRD base/size captured. The fast
+// path is unaffected: init_raw_lazy_rebase() ends by calling
+// init_raw() so the short load path is still valid until the
+// first rebase fires.
+if(cache_ptr_int32_overflow_possible)
+k_dram_window.init_raw_lazy_rebase();
+else
+k_dram_window.init_raw();
 auto v_dram_window =
 make_tile_scatter_gather(v_view,
@@ -837,7 +845,10 @@ struct UnifiedAttentionPipeline
 {0, 0},
 v_dist,
 v_page_offsets);
-v_dram_window.init_raw();
+if(cache_ptr_int32_overflow_possible)
+v_dram_window.init_raw_lazy_rebase();
+else
+v_dram_window.init_raw();
 // prefetch K tile
 constexpr index_t k0_loops = 1;
@@ -940,27 +951,29 @@ struct UnifiedAttentionPipeline
 //
 // Two load paths, dispatched on the runtime overflow flag:
 // - false: `async_load_tile_raw` → `buffer_load_dword_lds` with a
-// wave-uniform 4 GB-capped SRD. Faster, but per-lane voffsets
-// are int32 so the path is only correct while
+// wave-uniform 4 GB-capped SRD. Fastest path; only correct when
 // `num_blocks * page_size * row_stride * sizeof(T) ≤ INT32_MAX`.
-// - true: `async_load_tile_raw_long` → `global_load_lds_dwordx*`
-// with per-lane 64-bit base pointers, lifting the 4 GB limit
-// at the cost of lower throughput.
+// - true: `async_load_tile_raw_lazy_rebase` → still
+// `buffer_load_dword_lds`, but with a wave-uniform SRD base
+// pointer that is lazily re-anchored at each issue whenever
+// the per-lane page offset would otherwise overflow the int32
+// voffset. Lifts the 4 GB cache-pool limit without paying the
+// per-lane 64-bit base cost of `async_load_tile_raw_long`.
+// Precondition: WaveSpanInN ≤ runtime page_size (so within a
+// single issue every lane of the wave maps to the same physical
+// page block and the per-lane spread relative to the
+// wave-uniform anchor stays inside a half-INT32 element window).
+// If the precondition fails, swap this back to
+// `async_load_tile_raw_long` (per-lane 64-bit `global_load_lds`).
 // The branch is on a wave-uniform value, so no execution divergence.
 //
-// We tried a third "per-issue SRD rebase" path
-// (`async_load_tile_raw_rebased`: buffer_load_dword_lds with a
-// per-issue SRD whose 48-bit base absorbs the wave-uniform page
-// offset, valid when WaveSpanInN ≤ runtime page_size). It was
-// correct on every big-cache decode shape tested but came out at
-// best tied with the long path and at worst ~6% slower (e.g.
-// b=1 sk=1M d=64: 2.46 ms vs 2.32 ms; b=8 sk=200k d=128: 2.12 ms
-// vs 2.02 ms — see git log for the full table). These workloads
-// are compute / softmax bound, not K/V-load bandwidth bound, so
-// the buffer_load vs global_load_lds throughput edge never
-// materialises, while per-issue SRD construction adds real SGPR
-// pressure. The rebased helper has been removed to keep the
-// dispatch (and emitted kernel size) minimal.
+// History: an earlier "per-issue SRD rebase" path (rebase on every
+// issue regardless of whether overflow was imminent) was tested and
+// came out at best tied with the long path and at worst ~6% slower
+// because per-issue SRD construction adds real SGPR pressure on
+// compute/softmax-bound shapes. The current `_lazy_rebase` only
+// rebases when the wave anchor drifts outside the current int32
+// voffset window, keeping the fast path register-cheap.
 constexpr index_t KWaveSpanInN =
 (KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] - 1) *
 KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}] +
@@ -969,7 +982,8 @@ struct UnifiedAttentionPipeline
 auto K_mem_load = [&](auto k_lds_write_idx) {
 if(cache_ptr_int32_overflow_possible)
-async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window);
+async_load_tile_raw_lazy_rebase(k_lds_window_store(k_lds_write_idx),
+k_dram_window);
 else
 async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
 k_block_idx++;
@@ -979,7 +993,8 @@ struct UnifiedAttentionPipeline
 auto V_mem_load = [&](auto v_lds_write_idx) {
 if(cache_ptr_int32_overflow_possible)
-async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window);
+async_load_tile_raw_lazy_rebase(v_lds_window_store(v_lds_write_idx),
+v_dram_window);
 else
 async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
 v_block_idx++;