CK-UA: dispatch K/V async load on cache_ptr_int32_overflow_possible

The shared-SRD buffer_load_dword_lds path that K_mem_load / V_mem_load use
wraps the per-lane voffset (int32 bytes) once
  num_blocks * page_size * row_stride * sizeof(T) > INT32_MAX,
i.e. as soon as the pool grows past ~2 GiB, silently returning wrong data on
large paged-KV pools (>4 GB caches additionally exceed the SRD size field).

Add a second path, async_load_tile_raw_long, that issues the same load via
__builtin_amdgcn_global_load_lds with per-lane 64-bit base pointers, lifting
both 4 GB limits (SRD size + voffset). Per-issue LDS pointers are computed
explicitly because the intrinsic sets m0 itself, so the old m0_set / m0_inc
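bookkeeping doesn't apply. The path also bounds-checks lane_elem_off against
the live buffer range and redirects out-of-range lanes to element offset 0 so
they cannot fault. Unlike the original SRD's hardware OOB behaviour (which
returns zeros), such lanes read valid but unrelated data, which the attention
mask is expected to discard.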

Dispatch is a wave-uniform runtime branch on a new
cache_ptr_int32_overflow_possible flag plumbed from
unified_attention_args through MakeKargs into the pipeline operator().
Small caches keep the original buffer_load throughput; only the (rare)
>4 GB cache pays the global_load_lds cost.

k_page_offsets / v_page_offsets are widened to long_index_t. The original
buffer_load path implicitly narrows back to int32 when forwarding through
async_get_vectorized_elements_raw, which is intentional and safe whenever
the overflow flag is false.

For diagnostics, also derive a constexpr KWaveSpanInN =
(LaneGroups - 1) * NumWarps + 1 inside the pipeline; when this exceeds
page_size a single buffer_load spans multiple random pages, so the
per-issue SRD-rebase optimisation (not implemented yet) would not apply
even on a sub-4 GB cache. Informational only today.

Test: ua-test-scripts correctness sweep (245/245 pass), plus
  test_single_shape.py -b 32 -sq 8192 -sk 120000 -hq 64 -hk 8 -d 64 \
      --num-blocks 1200000 --block-size 16 --test
which previously returned wrong data due to the int32 wrap and now passes
with max abs diff 1.22e-04 vs Triton.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-05-15 09:00:43 +00:00
parent d77f0bea63
commit 1f69421434
8 changed files with 371 additions and 13 deletions

View File

@@ -68,6 +68,14 @@ struct unified_attention_args
index_t num_seqs; // number of batches for q
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
// Set to true when the K/V cache is large enough that an int32 byte
// offset into it can overflow (i.e. when
// num_blocks * page_size * num_kv_heads * head_dim * sizeof(T) > INT32_MAX
// ). When true, the pipeline routes K/V async loads through
// `global_load_lds` (per-lane 64-bit base ptr); when false, it uses the
// faster `buffer_load_dword_lds` path with a shared 4 GB-capped SRD.
bool cache_ptr_int32_overflow_possible = false;
// KV-segment parallelism (split-KV). When num_splits == 1, the kernel
// writes to o_ptr as usual. When num_splits > 1, the kernel is launched
// with a 3D grid whose z-dim is num_splits — each CTA computes its own

View File

@@ -328,7 +328,8 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc);
args.nhead_stride_o_acc,
args.cache_ptr_int32_overflow_possible);
dim3 grids;
if constexpr(UseDecodeGrid)

View File

@@ -2861,6 +2861,65 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
bool_constant<pre_nop>{});
}
// =============================================================================
// amd_async_global_load_lds_raw — asynchronous global-memory -> LDS copy issued
// through `__builtin_amdgcn_global_load_lds`.
//
// Counterpart of `amd_async_buffer_load_with_oob_raw` that does not go through
// a buffer resource descriptor (SRD, `int32x4_t`) at all. Per the original
// author's rationale, the SRD route has two 32-bit limits:
//   - the SRD `size` field is uint32_t, so pools above ~4 GB wrap;
//   - the `buffer_load_*` voffset is 32-bit, so per-lane offsets wrap too.
// `global_load_lds` takes a per-lane 64-bit base address plus a 13-bit signed
// immediate, which avoids both. This is what paged-KV caches need once
// `num_blocks * page_size * row_stride * sizeof(T)` exceeds INT32_MAX.
//
// Caveats (from the original author):
//   - No bounds handling is performed here: every lane's `base_ptr_64` must be
//     a valid address. The SRD path's free OOB clamp is gone.
//   - Relies on `__builtin_amdgcn_global_load_lds` (gfx9.4+ / gfx950); older
//     architectures would need a `global_load + ds_write` fallback.
//
// Template parameters:
//   T, N            - element type and element count; sizeof(T) * N must be
//                     4, 12 or 16 bytes (enforced by static_assert).
//   byte_offset_imm - compile-time byte offset forwarded to the intrinsic;
//                     must lie in [-4096, 4095] (enforced by static_assert).
//   coherence       - converted to int and forwarded as the intrinsic's last
//                     argument.
//   pre_nop         - when true, an `s_nop 4` is emitted right before the load.
// Parameters:
//   smem            - LDS destination; cast to address space 3.
//   base_ptr_64     - global source; cast to address space 1.
// =============================================================================
template <typename T,
          index_t N,
          index_t byte_offset_imm             = 0, // 13-bit signed
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool pre_nop                        = false>
CK_TILE_DEVICE void amd_async_global_load_lds_raw(T* smem,
                                                  const T* base_ptr_64,
                                                  bool_constant<pre_nop> = {})
{
    constexpr index_t load_bytes = sizeof(T) * N;
    static_assert(load_bytes == 4 || load_bytes == 12 || load_bytes == 16,
                  "global_load_lds: only dword / dwordx3 / dwordx4 supported on gfx950");
    static_assert(-4096 <= byte_offset_imm && byte_offset_imm <= 4095,
                  "global_load_lds: byte_offset_imm must fit in 13-bit signed");

    // Pointer types carrying the address-space qualifiers the intrinsic takes:
    // addrspace(1) for the global source, addrspace(3) for the LDS destination.
    using global_src_ptr_t = const __attribute__((address_space(1))) void*;
    using lds_dst_ptr_t    = __attribute__((address_space(3))) void*;

    // A C-style cast is used on purpose: it attaches the address space while
    // keeping the const qualifier of the source intact.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    global_src_ptr_t global_src = (global_src_ptr_t)base_ptr_64;
    lds_dst_ptr_t lds_dst       = (lds_dst_ptr_t)smem;
#pragma clang diagnostic pop

    if constexpr(pre_nop)
        asm volatile("s_nop 4\n" ::: "memory");

    // The size operand is written as an integer literal in every branch, as the
    // original author notes the front-end wants immediates for size/offset/aux.
    constexpr int aux_bits = static_cast<int>(coherence);
    if constexpr(load_bytes == 4)
        __builtin_amdgcn_global_load_lds(global_src, lds_dst, 4, byte_offset_imm, aux_bits);
    else if constexpr(load_bytes == 12)
        __builtin_amdgcn_global_load_lds(global_src, lds_dst, 12, byte_offset_imm, aux_bits);
    else /* load_bytes == 16 */
        __builtin_amdgcn_global_load_lds(global_src, lds_dst, 16, byte_offset_imm, aux_bits);
}
// This version support buffer resource as input arg
template <typename T,
index_t N,

View File

@@ -2686,6 +2686,67 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
bool_constant<pre_nop>{});
}
// =============================================================================
// amd_async_global_load_lds_raw — asynchronous global-memory -> LDS copy issued
// through `__builtin_amdgcn_global_load_lds`.
//
// Counterpart of `amd_async_buffer_load_with_oob_raw` that does not go through
// a buffer resource descriptor (SRD, `int32x4_t` / `__amdgpu_buffer_rsrc_t`)
// at all. Per the original author's rationale, the buffer_load route has two
// 32-bit limits at the HW boundary:
//   - the SRD `size` field is uint32_t, so pools above ~4 GB wrap;
//   - the `buffer_load_*` voffset is 32-bit, so per-lane offsets wrap too.
// `global_load_lds` takes a per-lane 64-bit base address plus a 13-bit signed
// immediate, which avoids both. This is what paged-KV caches need once
// `num_blocks * page_size * row_stride * sizeof(T)` exceeds INT32_MAX.
//
// Caveats (from the original author):
//   - No bounds handling is performed here: every lane's `base_ptr_64` must be
//     a valid address. The SRD path's free OOB clamp is gone.
//   - Relies on `__builtin_amdgcn_global_load_lds` (gfx9.4+ / gfx950); older
//     architectures would need a `global_load + ds_write` fallback.
//
// Template parameters:
//   T, N            - element type and element count; sizeof(T) * N must be
//                     4, 12 or 16 bytes (enforced by static_assert).
//   byte_offset_imm - compile-time byte offset forwarded to the intrinsic;
//                     must lie in [-4096, 4095] (enforced by static_assert).
//   coherence       - converted to int and forwarded as the intrinsic's last
//                     argument.
//   pre_nop         - when true, an `s_nop 4` is emitted right before the load.
// Parameters:
//   smem            - LDS destination; cast to address space 3.
//   base_ptr_64     - global source; cast to address space 1.
// =============================================================================
template <typename T,
          index_t N,
          index_t byte_offset_imm             = 0, // 13-bit signed
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool pre_nop                        = false>
CK_TILE_DEVICE void amd_async_global_load_lds_raw(T* smem,
                                                  const T* base_ptr_64,
                                                  bool_constant<pre_nop> = {})
{
    constexpr index_t load_bytes = sizeof(T) * N;
    static_assert(load_bytes == 4 || load_bytes == 12 || load_bytes == 16,
                  "global_load_lds: only dword / dwordx3 / dwordx4 supported on gfx950");
    static_assert(-4096 <= byte_offset_imm && byte_offset_imm <= 4095,
                  "global_load_lds: byte_offset_imm must fit in 13-bit signed");

    // Pointer types carrying the address-space qualifiers the intrinsic takes:
    // addrspace(1) for the global source, addrspace(3) for the LDS destination.
    using global_src_ptr_t = const __attribute__((address_space(1))) void*;
    using lds_dst_ptr_t    = __attribute__((address_space(3))) void*;

    // A C-style cast is used on purpose: it attaches the address space while
    // keeping the const qualifier of the source intact.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    global_src_ptr_t global_src = (global_src_ptr_t)base_ptr_64;
    lds_dst_ptr_t lds_dst       = (lds_dst_ptr_t)smem;
#pragma clang diagnostic pop

    if constexpr(pre_nop)
        asm volatile("s_nop 4\n" ::: "memory");

    // The size operand is written as an integer literal in every branch, as the
    // original author notes the front-end wants immediates for size/offset/aux.
    constexpr int aux_bits = static_cast<int>(coherence);
    if constexpr(load_bytes == 4)
        __builtin_amdgcn_global_load_lds(global_src, lds_dst, 4, byte_offset_imm, aux_bits);
    else if constexpr(load_bytes == 12)
        __builtin_amdgcn_global_load_lds(global_src, lds_dst, 12, byte_offset_imm, aux_bits);
    else /* load_bytes == 16 */
        __builtin_amdgcn_global_load_lds(global_src, lds_dst, 16, byte_offset_imm, aux_bits);
}
// This version support buffer resource as input arg
template <typename T,
index_t N,

View File

@@ -193,6 +193,27 @@ CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_&& lds_tile,
bool_constant<pre_nop>{});
}
// Free-function entry point for the long-addressing async load: forwards the
// LDS destination window and the compile-time tags to
// `tile_window.async_load_raw_long(...)`, so TileWindow_ must provide that
// member. Per the original author's note this is the global_load_lds path with
// per-lane 64-bit base pointers, and is only meaningful for
// tile_scatter_gather windows whose PageIdxArray element type can hold 64-bit
// values (e.g. long_index_t).
//
//   lds_tile    - LDS window the data is written to (passed on unchanged).
//   tile_window - source window whose async_load_raw_long() is invoked.
//   i_access / oob_conditional_check / pre_nop - compile-time tags, forwarded
//                 as number<> / bool_constant<> values.
template <typename LdsTileWindow_,
          typename TileWindow_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE void async_load_tile_raw_long(LdsTileWindow_&& lds_tile,
                                             const TileWindow_& tile_window,
                                             number<i_access>                     = {},
                                             bool_constant<oob_conditional_check> = {},
                                             bool_constant<pre_nop>               = {})
{
    // Materialise the tag objects once, then hand them to the member function.
    constexpr auto access_tag  = number<i_access>{};
    constexpr auto oob_tag     = bool_constant<oob_conditional_check>{};
    constexpr auto pre_nop_tag = bool_constant<pre_nop>{};

    tile_window.async_load_raw_long(lds_tile, access_tag, oob_tag, pre_nop_tag);
}
CK_TILE_DEVICE void async_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
@@ -1044,6 +1045,156 @@ struct tile_scatter_gather
}
}
// ---------------------------------------------------------------------
// async_load_raw_long: gather-load this window into LDS by calling
// `amd_async_global_load_lds_raw` (i.e. AMDGCN `global_load_lds_dwordx*`)
// rather than `buffer_load_dword_lds`.
//
// Relationship to async_load_raw (not restated here; per the original
// author's note the iteration structure and SFC walk are the same). The
// differences visible in this body are:
//   - the HW load instruction is swapped;
//   - there is NO manual m0_set / m0_inc bookkeeping: an explicit LDS
//     destination pointer is computed for every issue instead;
//   - the page indirection (`page_idx_`) is folded into a per-lane 64-bit
//     element offset that is added to the buffer's base pointer. Per the
//     author this lifts both 4 GB limits of the buffer_load path (SRD
//     `size` field is uint32_t, per-lane voffset is int32), so
//     PageIdxArray's element type can be `long_index_t` (caller's
//     responsibility).
//
// Why not per-issue SRD rebase? (author's rationale) In the K/V tile
// distributions emitted by CK-UA today, a single wave-wide
// buffer_load_dword* spans (LaneGroups) different N-positions, which for
// the prefill configs (NumWarps >= 2) can map to several different pages
// within one issue. For paged-KV caches > 4 GB, those pages can be far
// more than 4 GB apart in the global K buffer, exceeding the 32-bit
// voffset / 32-bit SRD-size range. Only a per-lane 64-bit base pointer
// (i.e. global_load_lds) can address all those lanes from a single
// instruction.
//
// OOB behaviour: there is no hardware OOB handling on this path. A
// software check redirects every lane whose load would fall outside
// [0, buffer_size_) to element offset 0, so such lanes read real data
// from the start of the buffer (NOT zeros). Callers must therefore either
// keep `page_idx_` on live pages or discard those lanes' values later.
//
// Template / function parameters:
//   lds_tile               - LDS destination window; must be 3-dimensional
//                            (static_assert below).
//   i_access_unsupport_    - accepted for signature compatibility; not
//                            referenced in the body.
//   oob_conditional_check  - accepted for signature compatibility; not
//                            referenced in the body (the range check below
//                            is applied unconditionally).
//   pre_nop                - when true, only the very first issue
//                            (iCoord == 0, iCoordAccess == 0) is asked to
//                            emit its pre-nop.
// Returns nothing (the `auto` return type deduces to void).
// ---------------------------------------------------------------------
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw_long(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// LDS layout, measured in elements (not bytes), derived from the LDS
// window's tensor descriptor. Per the original author's note, the LLVM
// intrinsic sets `m0` from its LDS-pointer argument on every call, so
// the manual `m0_set / m0_inc` bookkeeping of `async_load_raw` would be
// overwritten; instead the per-issue LDS element offset is computed
// here and added to the LDS base pointer for each load.
//
// Despite its name, `elems_per_buf` is the descriptor offset of the
// origin element (0, 0, 0), i.e. a base offset; the two values below
// it are strides obtained by subtracting that base from the offsets of
// (0, 1, 0) and (1, 0, 0) respectively.
const index_t elems_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
const index_t elems_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
elems_per_buf;
const index_t elems_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
elems_per_buf;
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// Element-typed global base pointer of this tile window's buffer.
const DataType* base_data_ptr = bottom_tensor_view_.get_buffer_view().p_data_;
// Element count of the underlying buffer, widened to 64 bits. Used by
// the range check below that stands in for the OOB handling the SRD
// provided on the `buffer_load_dword_lds` path.
const long_index_t buf_elems = static_cast<long_index_t>(
bottom_tensor_view_.get_buffer_view().buffer_size_);
LdsDataType* lds_base = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
// LDS offset of this warp's region (base offset + warp stride * warp
// id), computed once and reused for every issue.
const index_t lds_wave_elems = elems_per_buf + elems_per_wave * get_warp_id();
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// Working copies of the pre-computed coordinates; they are advanced
// between accesses by the move at the bottom of the inner loop.
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// Request the pre-nop only on the first issue of the whole tile.
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = get_gather_index(idx_ys_start);
// page_idx_ element type can be long_index_t; both operands of
// the sum below are cast to long_index_t, so the arithmetic is
// 64-bit regardless.
const auto page_offset = page_idx_[idx_gather];
// Per-lane 64-bit element offset into the global buffer:
// coord.get_offset() is the within-bottom-tensor element offset
// (per the author, intra-tile and int32-safe by construction);
// page_offset is the page-indirected element offset
// (potentially > INT32_MAX).
const long_index_t lane_elem_off =
static_cast<long_index_t>(bottom_tensor_thread_coord.get_offset()) +
static_cast<long_index_t>(page_offset);
// Range check: the whole vector [lane_elem_off,
// lane_elem_off + elems_per_load_) must lie inside the buffer.
// Lanes that fail are redirected to element offset 0, so they
// load valid-but-unrelated data instead of faulting. Per the
// author, the original buffer_load SRD returned 0 for the same
// OOB voffsets, and the attention mask zeroes the contribution
// of these lanes at softmax, so the value read is irrelevant.
constexpr index_t bytes_per_load_ = sizeof(vector_t);
constexpr index_t elems_per_load_ = bytes_per_load_ / sizeof(DataType);
const bool in_range =
(lane_elem_off >= 0) &&
(lane_elem_off + elems_per_load_ <= buf_elems);
const long_index_t safe_off = in_range ? lane_elem_off : 0;
const DataType* per_lane_ptr = base_data_ptr + safe_off;
// Per-issue LDS write target: warp region + issue stride *
// issue index. kIssue has the same value as iAccess, i.e. it
// counts issues monotonically across all coords, so each issue
// gets its own LDS slot. When NumCoord == 1 (per the author the
// only case exercised by the UA pipeline today) iCoord is
// always 0 and kIssue equals iCoordAccess. Per the author the
// pointer is wave-uniform and the intrinsic places each lane's
// data at m0 + (lane_id * bytes_per_lane).
constexpr index_t kIssue = iCoord * NumAccessPerCoord + iCoordAccess;
LdsDataType* lds_ptr = lds_base + lds_wave_elems + elems_per_issue * kIssue;
amd_async_global_load_lds_raw<DataType, elems_per_load_, /*byte_offset_imm=*/0>(
lds_ptr, per_lane_ptr, pre_nop_);
// Advance the thread coordinates to the next access of this
// coord (skipped after the last one). Gather dimensions get a
// zero step because their position comes from page_idx_. There
// is no m0 increment here: the LDS target is recomputed from
// kIssue on every iteration.
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)

View File

@@ -109,6 +109,10 @@ struct UnifiedAttentionKernel
ck_tile::index_t split_stride_o_acc = 0;
ck_tile::index_t nhead_stride_lse_acc = 0;
ck_tile::index_t nhead_stride_o_acc = 0;
// Runtime selector for the K/V async-load path in the pipeline. See
// `unified_attention_args::cache_ptr_int32_overflow_possible`.
bool cache_ptr_int32_overflow_possible = false;
};
using Kargs = UnifiedAttentionVarlenKargs;
@@ -150,7 +154,8 @@ struct UnifiedAttentionKernel
ck_tile::index_t split_stride_lse_acc = 0,
ck_tile::index_t split_stride_o_acc = 0,
ck_tile::index_t nhead_stride_lse_acc = 0,
ck_tile::index_t nhead_stride_o_acc = 0)
ck_tile::index_t nhead_stride_o_acc = 0,
bool cache_ptr_int32_overflow_possible = false)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -189,7 +194,8 @@ struct UnifiedAttentionKernel
split_stride_lse_acc,
split_stride_o_acc,
nhead_stride_lse_acc,
nhead_stride_o_acc};
nhead_stride_o_acc,
cache_ptr_int32_overflow_possible};
return kargs;
}
@@ -518,7 +524,8 @@ struct UnifiedAttentionKernel
smem_ptr,
static_cast<long_index_t>(kargs.stride_k_cache_1),
static_cast<long_index_t>(kargs.stride_v_cache_1),
num_queries_per_kv);
num_queries_per_kv,
kargs.cache_ptr_int32_overflow_possible);
auto& o_acc_tile = pipeline_result[number<0>{}];
auto& lse_tile = pipeline_result[number<1>{}];

View File

@@ -194,7 +194,12 @@ struct UnifiedAttentionPipeline
// "fall back to the compile-time `kBlockQ` from `UnifiedAttentionShape`"
// so existing callers don't have to change. The kernel template passes
// the runtime value (from kargs) to remove the static dependency.
const index_t num_queries_per_kv = 0) const
const index_t num_queries_per_kv = 0,
// Caller-supplied flag: set to true when the K/V cache total byte
// size can exceed INT32_MAX. Routes K/V async loads through the
// 64-bit-base `global_load_lds` path (correct but lower throughput).
// False uses the original shared-SRD `buffer_load_dword_lds` path.
const bool cache_ptr_int32_overflow_possible = false) const
{
using namespace ck_tile;
static_assert(
@@ -425,8 +430,17 @@ struct UnifiedAttentionPipeline
const index_t k_thread_n_pos = k_thread_coord[number<0>{}];
const index_t v_thread_n_pos = v_thread_coord[number<0>{}];
statically_indexed_array<index_t, KNRepeat> k_page_offsets;
statically_indexed_array<index_t, VNRepeat> v_page_offsets;
// Page offsets are widened to long_index_t so the `_long` load path
// (global_load_lds, per-lane 64-bit base) can address pools whose
// `num_blocks * page_size * row_stride * sizeof(T)` exceeds INT32_MAX.
// Small-domain values (logical_token, logical_page, within_page,
// phys_page) stay int32 — they're bounded by the per-CTA sequence
// and never overflow. The original `async_load_tile_raw` path
// implicitly narrows this back to int32 when it forwards the value
// through `async_get_vectorized_elements_raw` — that's intentional,
// and safe whenever `cache_ptr_int32_overflow_possible == false`.
statically_indexed_array<long_index_t, KNRepeat> k_page_offsets;
statically_indexed_array<long_index_t, VNRepeat> v_page_offsets;
auto refresh_k_offsets = [&](index_t k_tile_idx) {
static_for<0, KNRepeat, 1>{}([&](auto i) {
@@ -438,7 +452,8 @@ struct UnifiedAttentionPipeline
const index_t phys_page =
block_tables_ptr_[block_table_offset + logical_page];
k_page_offsets(i) =
(phys_page * page_size + within_page) * k_row_stride;
(static_cast<long_index_t>(phys_page) * page_size + within_page) *
k_row_stride;
});
};
auto refresh_v_offsets = [&](index_t v_tile_idx) {
@@ -451,7 +466,8 @@ struct UnifiedAttentionPipeline
const index_t phys_page =
block_tables_ptr_[block_table_offset + logical_page];
v_page_offsets(i) =
(phys_page * page_size + within_page) * v_row_stride;
(static_cast<long_index_t>(phys_page) * page_size + within_page) *
v_row_stride;
});
};
@@ -575,15 +591,46 @@ struct UnifiedAttentionPipeline
// Pass-2: page indirection lives in page_offsets, not in the SRD. We
// refresh the per-iter offsets table and push it to the window via
// update_page_idx(); the SRD itself stays put (no init_raw per iter).
//
// Two load paths, dispatched on the runtime overflow flag:
// - false: `async_load_tile_raw` → `buffer_load_dword_lds` with a
// wave-uniform 4 GB-capped SRD. Faster, but per-lane voffsets
// are int32 so the path is only correct while
// `num_blocks * page_size * row_stride * sizeof(T) ≤ INT32_MAX`.
// - true: `async_load_tile_raw_long` → `global_load_lds_dwordx*`
// with per-lane 64-bit base pointers, lifting the 4 GB limit
// at the cost of lower throughput.
// The branch is on a wave-uniform value, so no execution divergence.
//
// For diagnostic purposes: the wave's N-position span within a
// single buffer_load instruction is `(LaneGroups-1)*NumWarps + 1`.
// When that's > the minimum page_size (≈16) the K-tile distribution
// touches multiple pages per issue, so the small-cache buffer_load
// path still works (per-lane voffsets fit while the cache ≤ 4 GB)
// but the per-issue SRD-rebase optimization (not implemented today)
// wouldn't be applicable — only `global_load_lds` works once the
// cache exceeds 4 GB.
constexpr index_t KWaveSpanInN =
(KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] - 1) *
KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}] +
1;
(void)KWaveSpanInN; // currently informational only
auto K_mem_load = [&](auto k_lds_write_idx) {
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
if(cache_ptr_int32_overflow_possible)
async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window);
else
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
k_block_idx++;
refresh_k_offsets(k_block_idx);
k_dram_window.update_page_idx(k_page_offsets);
};
auto V_mem_load = [&](auto v_lds_write_idx) {
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
if(cache_ptr_int32_overflow_possible)
async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window);
else
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
v_block_idx++;
refresh_v_offsets(v_block_idx);
v_dram_window.update_page_idx(v_page_offsets);
@@ -1269,7 +1316,9 @@ struct UnifiedAttentionPipeline
long_index_t v_row_stride = 0,
// Forwards to the full-args operator() so callers can plumb in a
// runtime kBlockQ. See the documentation on that overload.
const index_t num_queries_per_kv = 0) const
const index_t num_queries_per_kv = 0,
// See the doc on the full-args operator().
const bool cache_ptr_int32_overflow_possible = false) const
{
using namespace ck_tile;
@@ -1292,7 +1341,8 @@ struct UnifiedAttentionPipeline
smem_ptr,
k_row_stride,
v_row_stride,
num_queries_per_kv);
num_queries_per_kv,
cache_ptr_int32_overflow_possible);
}
};