composable_kernel/include
juuso-oskari 589fe55d48 CK-UA: lazy per-issue SRD rebase for the int32-overflow K/V load path
Adds `async_load_raw_lazy_rebase` (+ free-function wrapper
`async_load_tile_raw_lazy_rebase`) to `tile_scatter_gather`, and wires
the unified-attention pipeline's overflow branch to it instead of
`async_load_tile_raw_long`. The fast non-overflow short path is
untouched.

Idea: keep using the cheap `buffer_load_dword_lds` (wave-uniform 4 GB
SRD) for the >4 GB cache pool case, but at each issue check whether the
wave-uniform anchor (lane-0's page offset, extracted via
`amd_wave_read_first_lane`) has drifted outside the current int32
voffset window around `cur_anchor_`. If it has, shift the SRD base
pointer to `p_data_orig_ + wave_anchor`, reinit the buffer resource,
and update `cur_anchor_`. The per-lane voffset is then
`lane_page_offset - cur_anchor_`, which fits in int32 by construction.
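
The rebase check above can be modelled on the host as a minimal sketch. This is not the kernel code: `LazyRebaseModel`, `issue`, `kHalfWindow` and `rebase_count` are hypothetical names, the first vector element stands in for the value `amd_wave_read_first_lane` would return, and the 2^30-element drift threshold is an assumed reading of the "int32 voffset window".

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Host-side model only. Everything except the name cur_anchor_ is hypothetical.
struct LazyRebaseModel
{
    // Assumed drift threshold: half of the int32 range, in elements.
    static constexpr std::int64_t kHalfWindow = std::int64_t{1} << 30;

    std::int64_t cur_anchor_ = 0; // current wave-uniform SRD shift, in elements
    int rebase_count         = 0; // number of times the modelled SRD base moved

    // One "issue": takes the 64-bit page offset of every lane and returns the
    // 32-bit voffsets relative to the (possibly updated) anchor.
    std::vector<std::int32_t> issue(const std::vector<std::int64_t>& lane_page_offsets)
    {
        assert(!lane_page_offsets.empty());

        // Lane 0's offset plays the role of the wave-uniform anchor.
        const std::int64_t wave_anchor = lane_page_offsets.front();
        const std::int64_t drift       = wave_anchor - cur_anchor_;

        // Rebase lazily: only when the anchor has left the current window.
        if(drift < -kHalfWindow || drift >= kHalfWindow)
        {
            cur_anchor_ = wave_anchor; // real code would also reinit the SRD here
            ++rebase_count;
        }

        std::vector<std::int32_t> voffsets;
        voffsets.reserve(lane_page_offsets.size());
        for(const std::int64_t off : lane_page_offsets)
        {
            const std::int64_t rel = off - cur_anchor_;
            // Holds when the per-lane spread stays within the half window.
            assert(rel >= std::numeric_limits<std::int32_t>::min() &&
                   rel <= std::numeric_limits<std::int32_t>::max());
            voffsets.push_back(static_cast<std::int32_t>(rel));
        }
        return voffsets;
    }
};
```

In this model, an issue whose lane-0 offset is 5,000,000,000 elements triggers one rebase, and later issues that stay near that offset reuse the same anchor without rebasing again.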

State added to `tile_scatter_gather`:
  - `p_data_orig_`      : original SRD base pointer (write-once)
  - `buffer_size_orig_` : original SRD size in elements (write-once)
  - `cur_anchor_`       : current wave-uniform SRD shift (in elements),
                          only ever assigned from
                          amd_wave_read_first_lane, so it stays in SGPRs.

Capture is done by a sister method `init_raw_lazy_rebase()` (used by the
pipeline when `cache_ptr_int32_overflow_possible` is true); on the
non-overflow path the existing `init_raw()` is used so the helper state
is write-never and DCE-eligible.
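
As a rough illustration of the captured state, the sketch below models the three fields and a one-time capture step on the host. `RebaseStateModel`, `init_raw_lazy_rebase_model` and `rebased_base` are hypothetical names; the signature of the real `init_raw_lazy_rebase()` is not shown in this message and may differ.

```cpp
#include <cassert>
#include <cstdint>

// Host-side model of the state listed above. Only the three field names come
// from the commit message; the type and its methods are hypothetical.
template <typename T>
struct RebaseStateModel
{
    const T* p_data_orig_          = nullptr; // original base pointer (write-once)
    std::int64_t buffer_size_orig_ = 0;       // original size in elements (write-once)
    std::int64_t cur_anchor_       = 0;       // current shift in elements

    // Models the capture step: records the original base and size exactly once.
    void init_raw_lazy_rebase_model(const T* p_data, std::int64_t buffer_size)
    {
        assert(p_data_orig_ == nullptr); // enforce the write-once property
        p_data_orig_      = p_data;
        buffer_size_orig_ = buffer_size;
        cur_anchor_       = 0;
    }

    // Base pointer the modelled SRD would use after a rebase.
    const T* rebased_base() const { return p_data_orig_ + cur_anchor_; }
};
```

Because the original pointer and size are never overwritten, every rebase is computed from the same starting point rather than accumulating on top of earlier shifts.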

Correctness precondition: within a single issue every lane of the wave
must map to the same physical page block (WaveSpanInN <= runtime
page_size). Under this precondition the per-lane spread around the
wave-uniform anchor stays inside a half-INT32 element window. When the
precondition does not hold, `async_load_tile_raw_long` is the correct
fallback.
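
The precondition can be read as a simple path-selection rule, sketched below. `KvLoadPath` and `select_kv_load_path` are hypothetical names used only for illustration; the sketch shows when the lazy path is valid, not which path the pipeline picks by default.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical labels for the three load paths named in this message.
enum class KvLoadPath
{
    Short,      // fast non-overflow path
    LazyRebase, // async_load_tile_raw_lazy_rebase
    Long        // async_load_tile_raw_long
};

// Returns the cheapest path that is still correct under the stated rules:
// no overflow -> short path; overflow with WaveSpanInN <= page_size -> lazy
// rebase is valid; otherwise the long path is the correct fallback.
inline KvLoadPath select_kv_load_path(bool overflow_possible,
                                      std::int64_t wave_span_in_n,
                                      std::int64_t page_size)
{
    assert(wave_span_in_n > 0 && page_size > 0);
    if(!overflow_possible)
    {
        return KvLoadPath::Short;
    }
    return wave_span_in_n <= page_size ? KvLoadPath::LazyRebase : KvLoadPath::Long;
}
```

For example, with an overflowing cache, a wave span of 16 and a page size of 64 the lazy path is valid, while a wave span of 64 and a page size of 16 requires the long path.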

Tested on gfx950 / GPU 2 (no contention), BF16 only:
  * ua-test-scripts/test_unified_attention_ck_correctness.py: 245/245
    BF16/FP16 pass.
  * test_single_shape.py overflow shapes (BF16): correctness passes.

Perf vs `_long` baseline (BF16, overflowing cache, CUDA graph):
  | Shape                              | _long      | _lazy      | delta  |
  | b=1 sq=1 sk=1M d=64 nb=200k        | 2.4149 ms  | 2.7849 ms  | +15.3% |
  | b=8 sq=1 sk=200k d=128 nb=100k     | 1.3762 ms  | 1.4225 ms  |  +3.4% |
  | b=128 sq=1 sk=128k d=128 nb=80k    | 14.0319 ms | 14.4643 ms |  +3.1% |
  | b=32 sq=1 sk=512k d=64 nb=200k     | 7.5211 ms  | 7.5206 ms  |   0.0% |

Verdict: the lazy variant is within 3.4% of `_long` on the multi-batch
decode shapes that dominate real workloads, and ~15% slower on the
single-batch huge-context corner where the rebase rate is highest.
Combined with the WaveSpanInN <= page_size precondition (which `_long`
does not require), `_long` remains the right default. This change is
parked on a side branch for future experimentation.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-20 10:24:12 +00:00