mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
CK-UA: lazy per-issue SRD rebase for the int32-overflow K/V load path
Adds `async_load_raw_lazy_rebase` (+ free-function wrapper
`async_load_tile_raw_lazy_rebase`) to `tile_scatter_gather`, and wires
the unified-attention pipeline's overflow branch to it instead of
`async_load_tile_raw_long`. The fast non-overflow short path is
untouched.
Idea: keep using the cheap `buffer_load_dword_lds` (wave-uniform 4 GB
SRD) for the >4 GB cache pool case, but at each issue check whether the
wave-uniform anchor (lane-0's page offset, extracted via
`amd_wave_read_first_lane`) has drifted outside the current int32
voffset window, i.e. outside `[cur_anchor_, cur_anchor_ + kRebaseThreshold)`
where `kRebaseThreshold` is half of the int32 window measured in elements.
If it has, shift the SRD base
pointer to `p_data_orig_ + wave_anchor`, reinit the buffer resource,
and update `cur_anchor_`. The per-lane voffset is then
`lane_page_offset - cur_anchor_`, which fits in int32 as long as the
correctness precondition below holds.
State added to `tile_scatter_gather`:
- `p_data_orig_` : original SRD base pointer (write-once)
- `buffer_size_orig_` : original SRD size in elements (write-once)
- `cur_anchor_` : current wave-uniform SRD shift (in elements);
reset to 0 by `init_raw_lazy_rebase()` and otherwise only ever
assigned from amd_wave_read_first_lane, so it can stay in SGPRs.
Capture is done by a sister method `init_raw_lazy_rebase()` (used by the
pipeline when `cache_ptr_int32_overflow_possible` is true); on the
non-overflow path the existing `init_raw()` is used so the helper state
is write-never and DCE-eligible.
Correctness precondition: within a single issue every lane of the wave
must map to the same physical page block (WaveSpanInN <= runtime
page_size). Under this precondition the per-lane spread around the
wave-uniform anchor stays inside a half-INT32 element window. When the
precondition does not hold, `async_load_tile_raw_long` is the correct
fallback.
Tested on gfx950 / GPU 2 (no contention); the overflow-shape and perf
runs below are BF16 only:
* ua-test-scripts/test_unified_attention_ck_correctness.py: 245/245
BF16/FP16 pass.
* test_single_shape.py overflow shapes (BF16): correctness passes.
Perf vs `_long` baseline (BF16, overflowing cache, CUDA graph):
| Shape | _long | _lazy | delta |
| b=1 sq=1 sk=1M d=64 nb=200k | 2.4149 ms | 2.7849 ms | +15.3% |
| b=8 sq=1 sk=200k d=128 nb=100k | 1.3762 ms | 1.4225 ms | +3.4% |
| b=128 sq=1 sk=128k d=128 nb=80k | 14.0319 ms | 14.4643 ms | +3.1% |
| b=32 sq=1 sk=512k d=64 nb=200k | 7.5211 ms | 7.5206 ms | 0.0% |
Verdict: the lazy variant is roughly perf-neutral with `_long` on the
multi-batch decode shapes that dominate real workloads, and ~15% slower
on the single-batch huge-context corner where the rebase rate is
highest. Combined with the WaveSpanInN <= page_size precondition (which
`_long` does not require), `_long` remains the right default. Parked
on a side branch for future experimentation.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -214,6 +214,34 @@ CK_TILE_DEVICE void async_load_tile_raw_long(LdsTileWindow_&& lds_tile,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// Variant of async_load_tile_raw that dispatches to
|
||||
// async_load_raw_lazy_rebase: the fast buffer_load_dword_lds path, but with
|
||||
// a wave-uniform SRD base pointer that is lazily re-anchored whenever the
|
||||
// per-issue page offset would otherwise overflow int32 voffsets. Lifts the
|
||||
// 4 GB cache-pool limit of the regular async_load_tile_raw without paying
|
||||
// the per-lane 64-bit base cost of async_load_tile_raw_long. Requires the
|
||||
// tile_window to have been set up with init_raw_lazy_rebase() and the
|
||||
// WaveSpanInN <= runtime page_size precondition documented on
|
||||
// async_load_raw_lazy_rebase. The tile_window is passed by non-const
|
||||
// reference because the rebase mutates its SRD state.
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_load_tile_raw_lazy_rebase(
|
||||
LdsTileWindow_&& lds_tile,
|
||||
TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.async_load_raw_lazy_rebase(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void async_load_fence(index_t cnt = 0)
|
||||
{
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
|
||||
@@ -719,6 +719,170 @@ struct tile_scatter_gather
|
||||
});
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Variant of async_load_raw that lazily re-anchors the wave-uniform SRD
|
||||
// base pointer so per-lane voffsets stay within int32 range even when
|
||||
// the total cache pool exceeds 4 GB. For every load issue:
|
||||
//
|
||||
// 1. read the per-lane absolute page offset (long_index_t, in
|
||||
// elements of DataType);
|
||||
// 2. take lane-0's value as a wave-uniform anchor candidate via
|
||||
// amd_wave_read_first_lane();
|
||||
// 3. if (wave_anchor - cur_anchor_) is outside [0, kRebaseThreshold)
|
||||
// shift the SRD base pointer to p_data_orig_ + wave_anchor and
|
||||
// reinit the buffer resource; update cur_anchor_ accordingly;
|
||||
// 4. issue the buffer_load with voffset = (lane_page_offset -
|
||||
// cur_anchor_), which is guaranteed to fit in int32 (after the
|
||||
// *sizeof(T) byte scaling inside amd_async_buffer_load_with_oob_raw).
|
||||
//
|
||||
// Correctness precondition: within a single issue every lane of the
|
||||
// wave must map to the same physical page block, i.e.
|
||||
// WaveSpanInN <= runtime page_size
|
||||
// Under this precondition the per-lane spread relative to the
|
||||
// wave-uniform anchor is bounded by page_size * row_stride * sizeof(T),
|
||||
// which fits comfortably in the half-INT32 element-window we leave
|
||||
// (kRebaseThreshold below). When the precondition does not hold use
|
||||
// async_load_raw_long instead.
|
||||
//
|
||||
// Fast path (no overflow this issue): one wave-read, one 64-bit
|
||||
// subtract, one compare-branch. Branch is wave-uniform; rebase rate is
|
||||
// low so the branch is well predicted by the SIMD scheduler.
|
||||
//
|
||||
// This method is non-const because it mutates bottom_tensor_view_
|
||||
// (rebase) and cur_anchor_ (anchor tracking). Use after
|
||||
// init_raw_lazy_rebase().
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_raw_lazy_rebase(
|
||||
LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
const index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType);
|
||||
|
||||
const index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(amd_wave_read_first_lane(m0_init_value));
|
||||
|
||||
using Traits = load_store_traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
|
||||
|
||||
// The buffer-load builtin scales the element offset by sizeof(DataType)
|
||||
// and feeds the result to a 32-bit voffset. To keep the byte offset
|
||||
// within INT32_MAX *for any active lane in the wave*, leave a margin
|
||||
// of half the element window for per-lane spread relative to lane-0.
|
||||
constexpr long_index_t kInt32ElemWindow =
|
||||
static_cast<long_index_t>(INT32_MAX) / static_cast<long_index_t>(sizeof(DataType));
|
||||
constexpr long_index_t kRebaseThreshold = kInt32ElemWindow / 2;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
constexpr auto pre_nop_ = [&]() {
|
||||
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
|
||||
return bool_constant<true>{};
|
||||
else
|
||||
return bool_constant<false>{};
|
||||
}();
|
||||
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = get_gather_index(idx_ys_start);
|
||||
|
||||
// Per-lane absolute page offset (in elements of DataType).
|
||||
const long_index_t lane_page_offset =
|
||||
static_cast<long_index_t>(page_idx_[idx_gather]);
|
||||
|
||||
// Wave-uniform anchor candidate: lane-0's value (or first
|
||||
// active lane). Promoted to SGPRs by the readfirstlane.
|
||||
const long_index_t wave_anchor = amd_wave_read_first_lane(lane_page_offset);
|
||||
|
||||
// Lazy rebase: only when the wave-uniform anchor has drifted
|
||||
// outside the current int32 voffset window around cur_anchor_.
|
||||
const long_index_t rel = wave_anchor - cur_anchor_;
|
||||
if(rel < 0 || rel >= kRebaseThreshold)
|
||||
{
|
||||
cur_anchor_ = wave_anchor;
|
||||
bottom_tensor_view_.buf_.p_data_ = p_data_orig_ + cur_anchor_;
|
||||
using BufSizeT =
|
||||
remove_cvref_t<decltype(bottom_tensor_view_.buf_.buffer_size_)>;
|
||||
bottom_tensor_view_.buf_.buffer_size_ =
|
||||
static_cast<BufSizeT>(buffer_size_orig_ - cur_anchor_);
|
||||
bottom_tensor_view_.init_raw();
|
||||
}
|
||||
|
||||
// Per-lane voffset relative to (possibly new) cur_anchor_.
|
||||
// Fits in int32 by construction (kRebaseThreshold + spread).
|
||||
const index_t lane_voffset =
|
||||
static_cast<index_t>(lane_page_offset - cur_anchor_);
|
||||
|
||||
// read from bottom tensor
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, lane_voffset, 0, pre_nop_);
|
||||
}
|
||||
else
|
||||
{
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
lane_voffset,
|
||||
valids_[idx_gather],
|
||||
0,
|
||||
pre_nop_);
|
||||
}
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
forward_step_scatter);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
m0_inc_with_memory(size_per_issue);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: fix with swizzle
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
@@ -1275,6 +1439,21 @@ struct tile_scatter_gather
|
||||
|
||||
// Forwards to init_raw() of the bottom tensor view.
CK_TILE_HOST_DEVICE void init_raw()
{
    bottom_tensor_view_.init_raw();
}
|
||||
|
||||
// Companion to init_raw(): capture the original SRD base / size so that
|
||||
// async_load_raw_lazy_rebase() can shift the wave-uniform base pointer
|
||||
// on demand and later recompute the buffer resource (init_raw) without
|
||||
// losing the underlying pool layout. Reset the anchor to 0 (no shift).
|
||||
// Call this once per window instead of init_raw() when the per-issue
|
||||
// page offsets may exceed INT32_MAX (i.e. when the cache pool size in
|
||||
// bytes can overflow int32 voffsets).
|
||||
CK_TILE_HOST_DEVICE void init_raw_lazy_rebase()
|
||||
{
|
||||
p_data_orig_ = bottom_tensor_view_.buf_.p_data_;
|
||||
buffer_size_orig_ = static_cast<long_index_t>(bottom_tensor_view_.buf_.buffer_size_);
|
||||
cur_anchor_ = 0;
|
||||
bottom_tensor_view_.init_raw();
|
||||
}
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
@@ -1302,6 +1481,20 @@ struct tile_scatter_gather
|
||||
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord>,
|
||||
std::byte>
|
||||
pre_computed_warp_coords_;
|
||||
|
||||
// State used by async_load_raw_lazy_rebase(). Populated by
|
||||
// init_raw_lazy_rebase(); ignored by all other load paths.
|
||||
// p_data_orig_ : original SRD base pointer (never mutated post-init)
|
||||
// buffer_size_orig_ : original SRD size in elements of DataType
|
||||
// cur_anchor_ : current wave-uniform SRD shift (in elements,
|
||||
// relative to p_data_orig_). Reset to 0 by
|
||||
// init_raw_lazy_rebase() and otherwise assigned only
|
||||
// from amd_wave_read_first_lane(...), so the value is
|
||||
// uniform across the wave. After a rebase,
|
||||
// bottom_tensor_view_.buf_.p_data_ == p_data_orig_ + cur_anchor_.
|
||||
typename BottomTensorView::buffer_view::type* p_data_orig_ = nullptr;
|
||||
long_index_t buffer_size_orig_ = 0;
|
||||
long_index_t cur_anchor_ = 0;
|
||||
};
|
||||
|
||||
// TODO: use strategy
|
||||
|
||||
@@ -829,7 +829,15 @@ struct UnifiedAttentionPipeline
|
||||
{0, 0},
|
||||
k_dist,
|
||||
k_page_offsets);
|
||||
k_dram_window.init_raw();
|
||||
// Use the lazy-rebase-aware init when overflow is possible so the
|
||||
// rebase path has the original SRD base/size captured. The fast
|
||||
// path is unaffected: init_raw_lazy_rebase() ends by calling
|
||||
// init_raw() so the short load path is still valid until the
|
||||
// first rebase fires.
|
||||
if(cache_ptr_int32_overflow_possible)
|
||||
k_dram_window.init_raw_lazy_rebase();
|
||||
else
|
||||
k_dram_window.init_raw();
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_scatter_gather(v_view,
|
||||
@@ -837,7 +845,10 @@ struct UnifiedAttentionPipeline
|
||||
{0, 0},
|
||||
v_dist,
|
||||
v_page_offsets);
|
||||
v_dram_window.init_raw();
|
||||
if(cache_ptr_int32_overflow_possible)
|
||||
v_dram_window.init_raw_lazy_rebase();
|
||||
else
|
||||
v_dram_window.init_raw();
|
||||
|
||||
// prefetch K tile
|
||||
constexpr index_t k0_loops = 1;
|
||||
@@ -940,27 +951,29 @@ struct UnifiedAttentionPipeline
|
||||
//
|
||||
// Two load paths, dispatched on the runtime overflow flag:
|
||||
// - false: `async_load_tile_raw` → `buffer_load_dword_lds` with a
|
||||
// wave-uniform 4 GB-capped SRD. Faster, but per-lane voffsets
|
||||
// are int32 so the path is only correct while
|
||||
// wave-uniform 4 GB-capped SRD. Fastest path; only correct when
|
||||
// `num_blocks * page_size * row_stride * sizeof(T) ≤ INT32_MAX`.
|
||||
// - true: `async_load_tile_raw_long` → `global_load_lds_dwordx*`
|
||||
// with per-lane 64-bit base pointers, lifting the 4 GB limit
|
||||
// at the cost of lower throughput.
|
||||
// - true: `async_load_tile_raw_lazy_rebase` → still
|
||||
// `buffer_load_dword_lds`, but with a wave-uniform SRD base
|
||||
// pointer that is lazily re-anchored at each issue whenever
|
||||
// the per-lane page offset would otherwise overflow the int32
|
||||
// voffset. Lifts the 4 GB cache-pool limit without paying the
|
||||
// per-lane 64-bit base cost of `async_load_tile_raw_long`.
|
||||
// Precondition: WaveSpanInN ≤ runtime page_size (so within a
|
||||
// single issue every lane of the wave maps to the same physical
|
||||
// page block and the per-lane spread relative to the
|
||||
// wave-uniform anchor stays inside a half-INT32 element window).
|
||||
// If the precondition fails, swap this back to
|
||||
// `async_load_tile_raw_long` (per-lane 64-bit `global_load_lds`).
|
||||
// The branch is on a wave-uniform value, so no execution divergence.
|
||||
//
|
||||
// We tried a third "per-issue SRD rebase" path
|
||||
// (`async_load_tile_raw_rebased`: buffer_load_dword_lds with a
|
||||
// per-issue SRD whose 48-bit base absorbs the wave-uniform page
|
||||
// offset, valid when WaveSpanInN ≤ runtime page_size). It was
|
||||
// correct on every big-cache decode shape tested but came out at
|
||||
// best tied with the long path and at worst ~6% slower (e.g.
|
||||
// b=1 sk=1M d=64: 2.46 ms vs 2.32 ms; b=8 sk=200k d=128: 2.12 ms
|
||||
// vs 2.02 ms — see git log for the full table). These workloads
|
||||
// are compute / softmax bound, not K/V-load bandwidth bound, so
|
||||
// the buffer_load vs global_load_lds throughput edge never
|
||||
// materialises, while per-issue SRD construction adds real SGPR
|
||||
// pressure. The rebased helper has been removed to keep the
|
||||
// dispatch (and emitted kernel size) minimal.
|
||||
// History: an earlier "per-issue SRD rebase" path (rebase on every
|
||||
// issue regardless of whether overflow was imminent) was tested and
|
||||
// came out at best tied with the long path and at worst ~6% slower
|
||||
// because per-issue SRD construction adds real SGPR pressure on
|
||||
// compute/softmax-bound shapes. The current `_lazy_rebase` only
|
||||
// rebases when the wave anchor drifts outside the current int32
|
||||
// voffset window, keeping the fast path register-cheap.
|
||||
constexpr index_t KWaveSpanInN =
|
||||
(KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] - 1) *
|
||||
KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}] +
|
||||
@@ -969,7 +982,8 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
auto K_mem_load = [&](auto k_lds_write_idx) {
|
||||
if(cache_ptr_int32_overflow_possible)
|
||||
async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
async_load_tile_raw_lazy_rebase(k_lds_window_store(k_lds_write_idx),
|
||||
k_dram_window);
|
||||
else
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
k_block_idx++;
|
||||
@@ -979,7 +993,8 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
if(cache_ptr_int32_overflow_possible)
|
||||
async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
async_load_tile_raw_lazy_rebase(v_lds_window_store(v_lds_write_idx),
|
||||
v_dram_window);
|
||||
else
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
v_block_idx++;
|
||||
|
||||
Reference in New Issue
Block a user