mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Adds `async_load_raw_lazy_rebase` (+ free-function wrapper
`async_load_tile_raw_lazy_rebase`) to `tile_scatter_gather`, and wires
the unified-attention pipeline's overflow branch to it instead of
`async_load_tile_raw_long`. The fast non-overflow short path is
untouched.
Idea: keep using the cheap `buffer_load_dword_lds` (wave-uniform 4 GB
SRD) for the >4 GB cache pool case, but at each issue check whether the
wave-uniform anchor (lane-0's page offset, extracted via
`amd_wave_read_first_lane`) has drifted outside the current int32
voffset window around `cur_anchor_`. If it has, shift the SRD base
pointer to `p_data_orig_ + wave_anchor`, reinit the buffer resource,
and update `cur_anchor_`. The per-lane voffset is then
`lane_page_offset - cur_anchor_`, which fits in int32 by construction.
State added to `tile_scatter_gather`:
- `p_data_orig_` : original SRD base pointer (write-once)
- `buffer_size_orig_` : original SRD size in elements (write-once)
- `cur_anchor_` : current wave-uniform SRD shift (in elements),
  only ever assigned from `amd_wave_read_first_lane`, so it
  stays in SGPRs.
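
The rebase check described above can be modelled on the host. The sketch below is not the CK implementation: `LazyRebaseModel`, `issue_lazy_rebase`, and the `kHalfWindow` threshold of 2^30 elements are hypothetical names and values chosen for illustration, the first vector element stands in for `amd_wave_read_first_lane`, and the SRD base is tracked as an element index rather than a real pointer.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Assumed drift threshold (in elements); the real window size is not
// stated in the commit message.
constexpr std::int64_t kHalfWindow = std::int64_t{1} << 30;

// Host-side stand-in for the state added to tile_scatter_gather.
struct LazyRebaseModel
{
    std::int64_t cur_anchor   = 0; // current wave-uniform shift, in elements
    std::int64_t srd_base     = 0; // models p_data_orig_ + cur_anchor_ as an index
    int          rebase_count = 0; // how many times the base was shifted
};

// Models one issue: rebase if the wave anchor drifted, then return the
// per-lane offsets relative to the (possibly new) anchor.
inline std::vector<std::int32_t>
issue_lazy_rebase(LazyRebaseModel& m, const std::vector<std::int64_t>& lane_page_offsets)
{
    // Lane 0's offset plays the role of the wave-uniform anchor.
    const std::int64_t wave_anchor = lane_page_offsets.front();
    const std::int64_t drift       = wave_anchor - m.cur_anchor;

    if(drift > kHalfWindow || drift < -kHalfWindow)
    {
        // Shift the base to the new anchor (the real code reinitialises
        // the buffer resource at this point).
        m.cur_anchor = wave_anchor;
        m.srd_base   = wave_anchor;
        ++m.rebase_count;
    }

    std::vector<std::int32_t> voffsets;
    voffsets.reserve(lane_page_offsets.size());
    for(const std::int64_t off : lane_page_offsets)
    {
        const std::int64_t rel = off - m.cur_anchor;
        // Holds as long as the per-lane spread stays below kHalfWindow.
        assert(rel >= std::numeric_limits<std::int32_t>::min() &&
               rel <= std::numeric_limits<std::int32_t>::max());
        voffsets.push_back(static_cast<std::int32_t>(rel));
    }
    return voffsets;
}
```

In this model a first issue at an offset beyond the threshold triggers one rebase, and later issues that stay near the anchor reuse it, which is why the rebase rate depends on how far consecutive issues jump.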
Capture is done by a sister method `init_raw_lazy_rebase()` (used by the
pipeline when `cache_ptr_int32_overflow_possible` is true); on the
non-overflow path the existing `init_raw()` is used so the helper state
is write-never and DCE-eligible.
Correctness precondition: within a single issue every lane of the wave
must map to the same physical page block (WaveSpanInN <= runtime
page_size). Under this precondition the per-lane spread around the
wave-uniform anchor stays inside a half-INT32 element window. When the
precondition does not hold, `async_load_tile_raw_long` is the correct
fallback.
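
The precondition reads as a simple dispatch rule. A minimal sketch, assuming the check is made once per launch; `OverflowLoadPath` and `pick_overflow_path` are hypothetical names and do not appear in the pipeline:

```cpp
#include <cassert>
#include <cstdint>

// Which overflow-capable load variant to use (illustrative only).
enum class OverflowLoadPath
{
    LazyRebase, // async_load_tile_raw_lazy_rebase
    Long        // async_load_tile_raw_long
};

// Lazy rebase is only valid when the wave's span along N fits in one page.
constexpr OverflowLoadPath pick_overflow_path(std::int32_t wave_span_in_n,
                                              std::int32_t page_size)
{
    return wave_span_in_n <= page_size ? OverflowLoadPath::LazyRebase
                                       : OverflowLoadPath::Long;
}
```

For example, a wave span of 16 with a page size of 64 selects the lazy path, while a span of 128 with the same page size falls back to `_long`.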
Tested on gfx950 / GPU 2 (no contention), BF16 only:
* ua-test-scripts/test_unified_attention_ck_correctness.py: 245/245
BF16/FP16 pass.
* test_single_shape.py overflow shapes (BF16): correctness passes.
Perf vs `_long` baseline (BF16, overflowing cache, CUDA graph):
| Shape                           | _long      | _lazy      | delta  |
|---------------------------------|------------|------------|--------|
| b=1 sq=1 sk=1M d=64 nb=200k     | 2.4149 ms  | 2.7849 ms  | +15.3% |
| b=8 sq=1 sk=200k d=128 nb=100k  | 1.3762 ms  | 1.4225 ms  | +3.4%  |
| b=128 sq=1 sk=128k d=128 nb=80k | 14.0319 ms | 14.4643 ms | +3.1%  |
| b=32 sq=1 sk=512k d=64 nb=200k  | 7.5211 ms  | 7.5206 ms  | 0.0%   |
Verdict: the lazy variant is within about 3% of `_long` (0.0% to +3.4%
slower) on the multi-batch decode shapes that dominate real workloads,
and ~15% slower on the single-batch huge-context corner where the
rebase rate is highest. Combined with the WaveSpanInN <= page_size
precondition (which `_long` does not require), `_long` remains the
right default. This change is parked on a side branch for future
experimentation.
Co-authored-by: Cursor <cursoragent@cursor.com>