From 1f694214346d02b15ca39fd96ceb190e45afc94d Mon Sep 17 00:00:00 2001 From: juuso-oskari Date: Fri, 15 May 2026 09:00:43 +0000 Subject: [PATCH] CK-UA: dispatch K/V async load on cache_ptr_int32_overflow_possible The shared-SRD buffer_load_dword_lds path that K_mem_load / V_mem_load use wraps the per-lane voffset (int32 bytes) once num_blocks * page_size * row_stride * sizeof(T) > INT32_MAX, silently returning wrong data on large paged-KV pools (e.g. >4 GB caches). Add a second path, async_load_tile_raw_long, that issues the same load via __builtin_amdgcn_global_load_lds with per-lane 64-bit base pointers, lifting both 4 GB limits (SRD size + voffset). Per-issue LDS pointers are computed explicitly because the intrinsic sets m0 itself, so the old m0_set / m0_inc bookkeeping doesn't apply. The path also redirects any lane_elem_off outside the live buffer range to element 0 so the load cannot fault. Unlike the original SRD's hardware OOB behaviour this returns live data rather than zeros, so it relies on the attention mask discarding those lanes. Dispatch is a wave-uniform runtime branch on a new cache_ptr_int32_overflow_possible flag plumbed from unified_attention_args through MakeKargs into the pipeline operator(). Small caches keep the original buffer_load throughput; only the (rare) >4 GB cache pays the global_load_lds cost. k_page_offsets / v_page_offsets are widened to long_index_t. The original buffer_load path implicitly narrows back to int32 when forwarding through async_get_vectorized_elements_raw, which is intentional and safe whenever the overflow flag is false. For diagnostics, also derive a constexpr KWaveSpanInN = (LaneGroups - 1) * NumWarps + 1 inside the pipeline; when this exceeds page_size a single buffer_load spans multiple random pages, so the per-issue SRD-rebase optimisation (not implemented yet) would not apply even on a sub-4 GB cache. Informational only today. 
Test: ua-test-scripts correctness sweep (245/245 pass), plus test_single_shape.py -b 32 -sq 8192 -sk 120000 -hq 64 -hk 8 -d 64 \ --num-blocks 1200000 --block-size 16 --test which previously returned wrong data due to the int32 wrap and now passes with max abs diff 1.22e-04 vs Triton. Co-authored-by: Cursor --- .../unified_attention.hpp | 8 + .../unified_attention_impl.hpp | 3 +- .../core/arch/amd_buffer_addressing.hpp | 59 +++++++ .../arch/amd_buffer_addressing_builtins.hpp | 61 +++++++ include/ck_tile/core/tensor/load_tile.hpp | 21 +++ .../core/tensor/tile_scatter_gather.hpp | 151 ++++++++++++++++++ .../kernel/unified_attention_kernel.hpp | 13 +- .../pipeline/unified_attention_pipeline.hpp | 68 ++++++-- 8 files changed, 371 insertions(+), 13 deletions(-) diff --git a/example/ck_tile/42_unified_attention/unified_attention.hpp b/example/ck_tile/42_unified_attention/unified_attention.hpp index 023fc3be4e..e5c1015b50 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention.hpp @@ -68,6 +68,14 @@ struct unified_attention_args index_t num_seqs; // number of batches for q index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown) + // Set to true when the K/V cache is large enough that an int32 byte + // offset into it can overflow (i.e. when + // num_blocks * page_size * num_kv_heads * head_dim * sizeof(T) > INT32_MAX + // ). When true, the pipeline routes K/V async loads through + // `global_load_lds` (per-lane 64-bit base ptr); when false, it uses the + // faster `buffer_load_dword_lds` path with a shared 4 GB-capped SRD. + bool cache_ptr_int32_overflow_possible = false; + // KV-segment parallelism (split-KV). When num_splits == 1, the kernel // writes to o_ptr as usual. 
When num_splits > 1, the kernel is launched // with a 3D grid whose z-dim is num_splits — each CTA computes its own diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 8f736dfe01..f4b3b2ab72 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -328,7 +328,8 @@ float unified_attention_kernel_launch(const unified_attention_args& args, args.split_stride_lse_acc, args.split_stride_o_acc, args.nhead_stride_lse_acc, - args.nhead_stride_o_acc); + args.nhead_stride_o_acc, + args.cache_ptr_int32_overflow_possible); dim3 grids; if constexpr(UseDecodeGrid) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index a32f26dadf..7295c7ab56 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -2861,6 +2861,65 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, bool_constant{}); } +// ============================================================================= +// global_load_lds path — direct DRAM->LDS load via per-lane 64-bit base pointer. +// +// Equivalent of `amd_async_buffer_load_with_oob_raw` but bypasses the SRD +// (`int32x4_t` resource descriptor) entirely: +// - SRD's `size` field is uint32_t (max ~4 GB pool). Caches above that wrap. +// - `buffer_load_*` voffset is 32-bit. Per-lane offsets above 4 GB wrap. +// Replacing the underlying HW instruction with `global_load_lds` (per-lane +// 64-bit VGPR-pair base + 13-bit signed immediate offset) lifts both limits. +// Required for paged-KV caches whose `num_blocks * page_size * row_stride * +// sizeof(T)` exceeds INT32_MAX (e.g. very-long-context decode pools). +// +// Caveats: +// - Loses the SRD's free OOB clamp. 
Caller must ensure the per-lane pointer +// is valid (in our pipeline use, the page_table lookup guarantees this). +// - gfx9.4+ / gfx950 only — uses `__builtin_amdgcn_global_load_lds`. +// Older arches would need a `global_load + ds_write` fallback. +// ============================================================================= +template +CK_TILE_DEVICE void amd_async_global_load_lds_raw(T* smem, + const T* base_ptr_64, + bool_constant = {}) +{ + constexpr index_t bytes = sizeof(T) * N; + + static_assert(bytes == 4 || bytes == 12 || bytes == 16, + "global_load_lds: only dword / dwordx3 / dwordx4 supported on gfx950"); + static_assert(-4096 <= byte_offset_imm && byte_offset_imm <= 4095, + "global_load_lds: byte_offset_imm must fit in 13-bit signed"); + + // C-style cast injects the address-space attribute the intrinsic expects + // (addrspace(1) for global, addrspace(3) for LDS) without losing const. +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + const __attribute__((address_space(1))) void* gptr = + (const __attribute__((address_space(1))) void*)base_ptr_64; + __attribute__((address_space(3))) void* lptr = + (__attribute__((address_space(3))) void*)smem; +#pragma clang diagnostic pop + + if constexpr(pre_nop) + asm volatile("s_nop 4\n" ::: "memory"); + + // Front-end requires `size`, `offset` and `aux` to be ImmArg / integer + // literals. A switch on the constexpr `bytes` value lets each branch + // pass the literal directly. 
+ constexpr int kCoherence = static_cast(coherence); + if constexpr(bytes == 16) + __builtin_amdgcn_global_load_lds(gptr, lptr, 16, byte_offset_imm, kCoherence); + else if constexpr(bytes == 12) + __builtin_amdgcn_global_load_lds(gptr, lptr, 12, byte_offset_imm, kCoherence); + else /* bytes == 4 */ + __builtin_amdgcn_global_load_lds(gptr, lptr, 4, byte_offset_imm, kCoherence); +} + // This version support buffer resource as input arg template {}); } +// ============================================================================= +// global_load_lds path — direct DRAM->LDS load via per-lane 64-bit base pointer. +// +// Equivalent of `amd_async_buffer_load_with_oob_raw` but bypasses the SRD +// (`int32x4_t` / `__amdgpu_buffer_rsrc_t` resource descriptor) entirely. The +// buffer_load path has two 32-bit limits at the HW boundary: +// - SRD `size` field is uint32_t (max ~4 GB pool). Caches above that wrap. +// - `buffer_load_*` voffset is 32-bit. Per-lane offsets above 4 GB wrap. +// Replacing the underlying HW instruction with `global_load_lds` (per-lane +// 64-bit VGPR-pair base + 13-bit signed immediate offset) lifts both limits. +// Required for paged-KV caches whose +// `num_blocks * page_size * row_stride * sizeof(T)` exceeds INT32_MAX (e.g. +// very-long-context decode pools). +// +// Caveats: +// - Loses the SRD's free OOB clamp. Caller must ensure the per-lane pointer +// is valid (in our pipeline use, the page_table lookup guarantees this). +// - gfx9.4+ / gfx950 only — uses `__builtin_amdgcn_global_load_lds`. Older +// arches would need a `global_load + ds_write` fallback. 
+// ============================================================================= +template +CK_TILE_DEVICE void amd_async_global_load_lds_raw(T* smem, + const T* base_ptr_64, + bool_constant = {}) +{ + constexpr index_t bytes = sizeof(T) * N; + + static_assert(bytes == 4 || bytes == 12 || bytes == 16, + "global_load_lds: only dword / dwordx3 / dwordx4 supported on gfx950"); + static_assert(-4096 <= byte_offset_imm && byte_offset_imm <= 4095, + "global_load_lds: byte_offset_imm must fit in 13-bit signed"); + + // C-style cast injects the address-space attribute the intrinsic expects + // (addrspace(1) for global, addrspace(3) for LDS) without losing const. +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + const __attribute__((address_space(1))) void* gptr = + (const __attribute__((address_space(1))) void*)base_ptr_64; + __attribute__((address_space(3))) void* lptr = + (__attribute__((address_space(3))) void*)smem; +#pragma clang diagnostic pop + + if constexpr(pre_nop) + asm volatile("s_nop 4\n" ::: "memory"); + + // Front-end requires `size`, `offset` and `aux` to be ImmArg / integer + // literals. A switch on the constexpr `bytes` value lets each branch + // pass the literal directly. + constexpr int kCoherence = static_cast(coherence); + if constexpr(bytes == 16) + __builtin_amdgcn_global_load_lds(gptr, lptr, 16, byte_offset_imm, kCoherence); + else if constexpr(bytes == 12) + __builtin_amdgcn_global_load_lds(gptr, lptr, 12, byte_offset_imm, kCoherence); + else /* bytes == 4 */ + __builtin_amdgcn_global_load_lds(gptr, lptr, 4, byte_offset_imm, kCoherence); +} + // This version support buffer resource as input arg template {}); } +// Variant of async_load_tile_raw that dispatches to async_load_raw_long +// (global_load_lds path with per-lane 64-bit base pointers). Only valid for +// tile_scatter_gather windows whose PageIdxArray element type supports +// 64-bit values (e.g. long_index_t). 
+template +CK_TILE_DEVICE void async_load_tile_raw_long(LdsTileWindow_&& lds_tile, + const TileWindow_& tile_window, + number = {}, + bool_constant = {}, + bool_constant = {}) +{ + tile_window.async_load_raw_long(lds_tile, + number{}, + bool_constant{}, + bool_constant{}); +} + CK_TILE_DEVICE void async_load_fence(index_t cnt = 0) { asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index aa29345892..2730310e20 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" @@ -1044,6 +1045,156 @@ struct tile_scatter_gather } } + // --------------------------------------------------------------------- + // async_load_raw_long: variant of async_load_raw that issues the per-tile + // gather load via `amd_async_global_load_lds_raw` (i.e. AMDGCN + // `global_load_lds_dwordx*`) rather than `buffer_load_dword_lds`. + // + // Same iteration structure and SFC walk as async_load_raw; the HW load + // instruction is swapped and each issue gets its own LDS pointer. The page + // indirection is folded into a per-lane 64-bit base pointer, lifting + // both 4 GB limits in the buffer_load path (SRD `size` field is uint32_t, + // per-lane voffset is int32). PageIdxArray's element type can therefore + // be `long_index_t` (caller's responsibility). + // + // Why not per-issue SRD rebase? In the K/V tile distributions emitted + // by CK-UA today, a single wave-wide buffer_load_dword* spans + // (LaneGroups) different N-positions, which for the prefill configs + // (NumWarps≥2) can map to several different pages within one issue. 
+ // For paged-KV caches > 4 GB, those pages can be ≫ 4 GB apart in the + // global K buffer, exceeding the 32-bit voffset / 32-bit SRD-size + // range. Only a per-lane 64-bit base pointer (i.e. global_load_lds) + // can address all those lanes from a single instruction. + // + // OOB note: the SRD's hardware OOB clamp is unavailable here. `page_idx_` + // should only reference live pages; lanes whose offset still falls + // outside the buffer are redirected to element 0 (live data, not zeros). + // --------------------------------------------------------------------- + template + CK_TILE_DEVICE auto async_load_raw_long(LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}, + bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + using LdsDataType = typename LdsTileWindow::DataType; + + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + + // The per-tile LDS layout in elements (not bytes). The new + // `global_load_lds_*` path differs from the `buffer_load_dword_lds` + // path here: the LLVM intrinsic implicitly sets `m0` from its + // `lptr` argument every call, so the manual `m0_set / m0_inc` + // bookkeeping used by `async_load_raw` would be silently + // overwritten. Instead, we compute the per-issue LDS element offset + // and add it to the LDS base pointer on each call — the compiler + // emits a fresh `s_mov_b32 m0, ...` per load with the right value. 
+ const index_t elems_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})); + + const index_t elems_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) - + elems_per_buf; + + const index_t elems_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) - + elems_per_buf; + + using Traits = load_store_traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + // bf16/fp16/etc. element-typed global ptr base for this tile-window. + const DataType* base_data_ptr = bottom_tensor_view_.get_buffer_view().p_data_; + // Element count in the underlying buffer — used to clamp per-lane + // pointers that go past the live range, mimicking the SRD's OOB + // semantics on the original `buffer_load_dword_lds` path. + const long_index_t buf_elems = static_cast( + bottom_tensor_view_.get_buffer_view().buffer_size_); + LdsDataType* lds_base = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; + + // Wave / warp-group offset into LDS, computed once. 
+ const index_t lds_wave_elems = elems_per_buf + elems_per_wave * get_warp_id(); + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); + + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = get_gather_index(idx_ys_start); + // page_idx_ element type can be long_index_t — the pointer + // arithmetic below stays 64-bit by promotion. + const auto page_offset = page_idx_[idx_gather]; + + // Per-lane 64-bit GLOBAL base pointer. coord.get_offset() is + // the within-bottom-tensor element offset (intra-tile, + // int32-safe by construction); page_offset is the + // page-indirected element offset (potentially > INT32_MAX). + // Pointer arithmetic on DataType* advances by sizeof(DataType) + // and uses 64-bit ptrdiff_t internally. + const long_index_t lane_elem_off = + static_cast(bottom_tensor_thread_coord.get_offset()) + + static_cast(page_offset); + // Clamp to in-buffer range to keep `global_load_lds` from + // faulting on tail-padded pages (the original buffer_load + // SRD silently returned 0 for the same OOB voffsets). The + // attention mask zeroes the contribution from these lanes + // at softmax, so the value read here is irrelevant. + constexpr index_t bytes_per_load_ = sizeof(vector_t); + constexpr index_t elems_per_load_ = bytes_per_load_ / sizeof(DataType); + const bool in_range = + (lane_elem_off >= 0) && + (lane_elem_off + elems_per_load_ <= buf_elems); + const long_index_t safe_off = in_range ? lane_elem_off : 0; + const DataType* per_lane_ptr = base_data_ptr + safe_off; + + // Per-issue LDS write target. 
Wave-uniform; intrinsic emits + // `s_mov_b32 m0, ` and the dwordx4 lds-direct write + // lands at m0 + (lane_id * bytes_per_lane). For NumCoord==1 + // (the only case exercised by the UA pipeline today) the + // two formulas coincide; we use the monotonically-increasing + // one so each issue lands in its own LDS slot. + constexpr index_t kIssue = iCoord * NumAccessPerCoord + iCoordAccess; + LdsDataType* lds_ptr = lds_base + lds_wave_elems + elems_per_issue * kIssue; + + amd_async_global_load_lds_raw( + lds_ptr, per_lane_ptr, pre_nop_); + + // move thread coordinate (no m0_inc — see header comment above) + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, + number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; } CK_TILE_DEVICE void update_valids(const ValidArray& new_valids) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 856cba024f..92fe209fcd 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -109,6 +109,10 @@ struct UnifiedAttentionKernel ck_tile::index_t split_stride_o_acc = 0; ck_tile::index_t nhead_stride_lse_acc = 0; ck_tile::index_t nhead_stride_o_acc = 0; + + // Runtime selector for the K/V async-load path in the pipeline. 
See + // `unified_attention_args::cache_ptr_int32_overflow_possible`. + bool cache_ptr_int32_overflow_possible = false; }; using Kargs = UnifiedAttentionVarlenKargs; @@ -150,7 +154,8 @@ struct UnifiedAttentionKernel ck_tile::index_t split_stride_lse_acc = 0, ck_tile::index_t split_stride_o_acc = 0, ck_tile::index_t nhead_stride_lse_acc = 0, - ck_tile::index_t nhead_stride_o_acc = 0) + ck_tile::index_t nhead_stride_o_acc = 0, + bool cache_ptr_int32_overflow_possible = false) { Kargs kargs{{q_ptr, k_ptr, @@ -189,7 +194,8 @@ struct UnifiedAttentionKernel split_stride_lse_acc, split_stride_o_acc, nhead_stride_lse_acc, - nhead_stride_o_acc}; + nhead_stride_o_acc, + cache_ptr_int32_overflow_possible}; return kargs; } @@ -518,7 +524,8 @@ struct UnifiedAttentionKernel smem_ptr, static_cast(kargs.stride_k_cache_1), static_cast(kargs.stride_v_cache_1), - num_queries_per_kv); + num_queries_per_kv, + kargs.cache_ptr_int32_overflow_possible); auto& o_acc_tile = pipeline_result[number<0>{}]; auto& lse_tile = pipeline_result[number<1>{}]; diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 0127939066..ee59d8c8bd 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -194,7 +194,12 @@ struct UnifiedAttentionPipeline // "fall back to the compile-time `kBlockQ` from `UnifiedAttentionShape`" // so existing callers don't have to change. The kernel template passes // the runtime value (from kargs) to remove the static dependency. - const index_t num_queries_per_kv = 0) const + const index_t num_queries_per_kv = 0, + // Caller-supplied flag: set to true when the K/V cache total byte + // size can exceed INT32_MAX. Routes K/V async loads through the + // 64-bit-base `global_load_lds` path (correct but lower throughput). 
+ // False uses the original shared-SRD `buffer_load_dword_lds` path. + const bool cache_ptr_int32_overflow_possible = false) const { using namespace ck_tile; static_assert( @@ -425,8 +430,17 @@ struct UnifiedAttentionPipeline const index_t k_thread_n_pos = k_thread_coord[number<0>{}]; const index_t v_thread_n_pos = v_thread_coord[number<0>{}]; - statically_indexed_array k_page_offsets; - statically_indexed_array v_page_offsets; + // Page offsets are widened to long_index_t so the `_long` load path + // (global_load_lds, per-lane 64-bit base) can address pools whose + // `num_blocks * page_size * row_stride * sizeof(T)` exceeds INT32_MAX. + // Small-domain values (logical_token, logical_page, within_page, + // phys_page) stay int32 — they're bounded by the per-CTA sequence + // and never overflow. The original `async_load_tile_raw` path + // implicitly narrows this back to int32 when it forwards the value + // through `async_get_vectorized_elements_raw` — that's intentional, + // and safe whenever `cache_ptr_int32_overflow_possible == false`. + statically_indexed_array k_page_offsets; + statically_indexed_array v_page_offsets; auto refresh_k_offsets = [&](index_t k_tile_idx) { static_for<0, KNRepeat, 1>{}([&](auto i) { @@ -438,7 +452,8 @@ struct UnifiedAttentionPipeline const index_t phys_page = block_tables_ptr_[block_table_offset + logical_page]; k_page_offsets(i) = - (phys_page * page_size + within_page) * k_row_stride; + (static_cast(phys_page) * page_size + within_page) * + k_row_stride; }); }; auto refresh_v_offsets = [&](index_t v_tile_idx) { @@ -451,7 +466,8 @@ struct UnifiedAttentionPipeline const index_t phys_page = block_tables_ptr_[block_table_offset + logical_page]; v_page_offsets(i) = - (phys_page * page_size + within_page) * v_row_stride; + (static_cast(phys_page) * page_size + within_page) * + v_row_stride; }); }; @@ -575,15 +591,46 @@ struct UnifiedAttentionPipeline // Pass-2: page indirection lives in page_offsets, not in the SRD. 
We // refresh the per-iter offsets table and push it to the window via // update_page_idx(); the SRD itself stays put (no init_raw per iter). + // + // Two load paths, dispatched on the runtime overflow flag: + // - false: `async_load_tile_raw` → `buffer_load_dword_lds` with a + // wave-uniform 4 GB-capped SRD. Faster, but per-lane voffsets + // are int32 so the path is only correct while + // `num_blocks * page_size * row_stride * sizeof(T) ≤ INT32_MAX`. + // - true: `async_load_tile_raw_long` → `global_load_lds_dwordx*` + // with per-lane 64-bit base pointers, lifting the 4 GB limit + // at the cost of lower throughput. + // The branch is on a wave-uniform value, so no execution divergence. + // + // For diagnostic purposes: the wave's N-position span within a + // single buffer_load instruction is `(LaneGroups-1)*NumWarps + 1`. + // When that's > the minimum page_size (≈16) the K-tile distribution + // touches multiple pages per issue, so the small-cache buffer_load + // path still works (per-lane voffsets fit while the cache ≤ INT32_MAX bytes) + // but the per-issue SRD-rebase optimization (not implemented today) + // wouldn't be applicable — only `global_load_lds` works once the + // cache exceeds INT32_MAX bytes. 
+ constexpr index_t KWaveSpanInN = + (KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] - 1) * + KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}] + + 1; + (void)KWaveSpanInN; // currently informational only + auto K_mem_load = [&](auto k_lds_write_idx) { - async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); + if(cache_ptr_int32_overflow_possible) + async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window); + else + async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); k_block_idx++; refresh_k_offsets(k_block_idx); k_dram_window.update_page_idx(k_page_offsets); }; auto V_mem_load = [&](auto v_lds_write_idx) { - async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); + if(cache_ptr_int32_overflow_possible) + async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window); + else + async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); v_block_idx++; refresh_v_offsets(v_block_idx); v_dram_window.update_page_idx(v_page_offsets); @@ -1269,7 +1316,9 @@ struct UnifiedAttentionPipeline long_index_t v_row_stride = 0, // Forwards to the full-args operator() so callers can plumb in a // runtime kBlockQ. See the documentation on that overload. - const index_t num_queries_per_kv = 0) const + const index_t num_queries_per_kv = 0, + // See the doc on the full-args operator(). + const bool cache_ptr_int32_overflow_possible = false) const { using namespace ck_tile; @@ -1292,7 +1341,8 @@ struct UnifiedAttentionPipeline smem_ptr, k_row_stride, v_row_stride, - num_queries_per_kv); + num_queries_per_kv, + cache_ptr_int32_overflow_possible); } };