mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
CK-UA: dispatch K/V async load on cache_ptr_int32_overflow_possible
The shared-SRD buffer_load_dword_lds path that K_mem_load / V_mem_load use
wraps the per-lane voffset (int32 bytes) once
num_blocks * page_size * row_stride * sizeof(T) > INT32_MAX,
silently returning wrong data on large paged-KV pools (e.g. >4 GB caches).
Add a second path, async_load_tile_raw_long, that issues the same load via
__builtin_amdgcn_global_load_lds with per-lane 64-bit base pointers, lifting
both 4 GB limits (SRD size + voffset). Per-issue LDS pointers are computed
explicitly because the intrinsic sets m0 itself, so the old m0_set / m0_inc
bookkeeping doesn't apply. The path also redirects any lane_elem_off outside
the live buffer range to offset 0, a software stand-in for the original SRD's
hardware OOB handling (which returned 0 for such offsets instead).
Dispatch is a wave-uniform runtime branch on a new
cache_ptr_int32_overflow_possible flag plumbed from
unified_attention_args through MakeKargs into the pipeline operator().
Small caches keep the original buffer_load throughput; only the (rare)
caches that set the flag (total byte size > INT32_MAX) pay the
global_load_lds cost.
k_page_offsets / v_page_offsets are widened to long_index_t. The original
buffer_load path implicitly narrows back to int32 when forwarding through
async_get_vectorized_elements_raw, which is intentional and safe whenever
the overflow flag is false.
For diagnostics, also derive a constexpr KWaveSpanInN =
(LaneGroups - 1) * NumWarps + 1 inside the pipeline; when this exceeds
page_size a single buffer_load spans multiple random pages, so the
per-issue SRD-rebase optimisation (not implemented yet) would not apply
even on a sub-4 GB cache. Informational only today.
Test: ua-test-scripts correctness sweep (245/245 pass), plus
test_single_shape.py -b 32 -sq 8192 -sk 120000 -hq 64 -hk 8 -d 64 \
--num-blocks 1200000 --block-size 16 --test
which previously returned wrong data due to the int32 wrap and now passes
with max abs diff 1.22e-04 vs Triton.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -68,6 +68,14 @@ struct unified_attention_args
|
||||
index_t num_seqs; // number of batches for q
|
||||
index_t max_seqlen_q = 0; // max query length across all batches (0 = unknown)
|
||||
|
||||
// Set to true when the K/V cache is large enough that an int32 byte
|
||||
// offset into it can overflow (i.e. when
|
||||
// num_blocks * page_size * num_kv_heads * head_dim * sizeof(T) > INT32_MAX
|
||||
// ). When true, the pipeline routes K/V async loads through
|
||||
// `global_load_lds` (per-lane 64-bit base ptr); when false, it uses the
|
||||
// faster `buffer_load_dword_lds` path with a shared 4 GB-capped SRD.
|
||||
bool cache_ptr_int32_overflow_possible = false;
|
||||
|
||||
// KV-segment parallelism (split-KV). When num_splits == 1, the kernel
|
||||
// writes to o_ptr as usual. When num_splits > 1, the kernel is launched
|
||||
// with a 3D grid whose z-dim is num_splits — each CTA computes its own
|
||||
|
||||
@@ -328,7 +328,8 @@ float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc);
|
||||
args.nhead_stride_o_acc,
|
||||
args.cache_ptr_int32_overflow_possible);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(UseDecodeGrid)
|
||||
|
||||
@@ -2861,6 +2861,65 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// global_load_lds path — direct DRAM->LDS load via per-lane 64-bit base pointer.
|
||||
//
|
||||
// Equivalent of `amd_async_buffer_load_with_oob_raw` but bypasses the SRD
|
||||
// (`int32x4_t` resource descriptor) entirely:
|
||||
// - SRD's `size` field is uint32_t (max ~4 GB pool). Caches above that wrap.
|
||||
// - `buffer_load_*` voffset is 32-bit. Per-lane offsets above 4 GB wrap.
|
||||
// Replacing the underlying HW instruction with `global_load_lds` (per-lane
|
||||
// 64-bit VGPR-pair base + 13-bit signed immediate offset) lifts both limits.
|
||||
// Required for paged-KV caches whose `num_blocks * page_size * row_stride *
|
||||
// sizeof(T)` exceeds INT32_MAX (e.g. very-long-context decode pools).
|
||||
//
|
||||
// Caveats:
|
||||
// - Loses the SRD's free OOB clamp. Caller must ensure the per-lane pointer
|
||||
// is valid (in our pipeline use, the page_table lookup guarantees this).
|
||||
// - gfx9.4+ / gfx950 only — uses `__builtin_amdgcn_global_load_lds`.
|
||||
// Older arches would need a `global_load + ds_write` fallback.
|
||||
// =============================================================================
|
||||
template <typename T,
|
||||
index_t N,
|
||||
index_t byte_offset_imm = 0, // 13-bit signed
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_async_global_load_lds_raw(T* smem,
|
||||
const T* base_ptr_64,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"global_load_lds: only dword / dwordx3 / dwordx4 supported on gfx950");
|
||||
static_assert(-4096 <= byte_offset_imm && byte_offset_imm <= 4095,
|
||||
"global_load_lds: byte_offset_imm must fit in 13-bit signed");
|
||||
|
||||
// C-style cast injects the address-space attribute the intrinsic expects
|
||||
// (addrspace(1) for global, addrspace(3) for LDS) without losing const.
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
const __attribute__((address_space(1))) void* gptr =
|
||||
(const __attribute__((address_space(1))) void*)base_ptr_64;
|
||||
__attribute__((address_space(3))) void* lptr =
|
||||
(__attribute__((address_space(3))) void*)smem;
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n" ::: "memory");
|
||||
|
||||
// Front-end requires `size`, `offset` and `aux` to be ImmArg / integer
|
||||
// literals. An `if constexpr` chain on the constexpr `bytes` value lets each branch
|
||||
// pass the literal directly.
|
||||
constexpr int kCoherence = static_cast<int>(coherence);
|
||||
if constexpr(bytes == 16)
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 16, byte_offset_imm, kCoherence);
|
||||
else if constexpr(bytes == 12)
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 12, byte_offset_imm, kCoherence);
|
||||
else /* bytes == 4 */
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 4, byte_offset_imm, kCoherence);
|
||||
}
|
||||
|
||||
// This version supports a buffer resource as an input arg
|
||||
template <typename T,
|
||||
index_t N,
|
||||
|
||||
@@ -2686,6 +2686,67 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// global_load_lds path — direct DRAM->LDS load via per-lane 64-bit base pointer.
|
||||
//
|
||||
// Equivalent of `amd_async_buffer_load_with_oob_raw` but bypasses the SRD
|
||||
// (`int32x4_t` / `__amdgpu_buffer_rsrc_t` resource descriptor) entirely. The
|
||||
// buffer_load path has two 32-bit limits at the HW boundary:
|
||||
// - SRD `size` field is uint32_t (max ~4 GB pool). Caches above that wrap.
|
||||
// - `buffer_load_*` voffset is 32-bit. Per-lane offsets above 4 GB wrap.
|
||||
// Replacing the underlying HW instruction with `global_load_lds` (per-lane
|
||||
// 64-bit VGPR-pair base + 13-bit signed immediate offset) lifts both limits.
|
||||
// Required for paged-KV caches whose
|
||||
// `num_blocks * page_size * row_stride * sizeof(T)` exceeds INT32_MAX (e.g.
|
||||
// very-long-context decode pools).
|
||||
//
|
||||
// Caveats:
|
||||
// - Loses the SRD's free OOB clamp. Caller must ensure the per-lane pointer
|
||||
// is valid (in our pipeline use, the page_table lookup guarantees this).
|
||||
// - gfx9.4+ / gfx950 only — uses `__builtin_amdgcn_global_load_lds`. Older
|
||||
// arches would need a `global_load + ds_write` fallback.
|
||||
// =============================================================================
|
||||
template <typename T,
|
||||
index_t N,
|
||||
index_t byte_offset_imm = 0, // 13-bit signed
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_async_global_load_lds_raw(T* smem,
|
||||
const T* base_ptr_64,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
|
||||
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
|
||||
"global_load_lds: only dword / dwordx3 / dwordx4 supported on gfx950");
|
||||
static_assert(-4096 <= byte_offset_imm && byte_offset_imm <= 4095,
|
||||
"global_load_lds: byte_offset_imm must fit in 13-bit signed");
|
||||
|
||||
// C-style cast injects the address-space attribute the intrinsic expects
|
||||
// (addrspace(1) for global, addrspace(3) for LDS) without losing const.
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
const __attribute__((address_space(1))) void* gptr =
|
||||
(const __attribute__((address_space(1))) void*)base_ptr_64;
|
||||
__attribute__((address_space(3))) void* lptr =
|
||||
(__attribute__((address_space(3))) void*)smem;
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
if constexpr(pre_nop)
|
||||
asm volatile("s_nop 4\n" ::: "memory");
|
||||
|
||||
// Front-end requires `size`, `offset` and `aux` to be ImmArg / integer
|
||||
// literals. An `if constexpr` chain on the constexpr `bytes` value lets each branch
|
||||
// pass the literal directly.
|
||||
constexpr int kCoherence = static_cast<int>(coherence);
|
||||
if constexpr(bytes == 16)
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 16, byte_offset_imm, kCoherence);
|
||||
else if constexpr(bytes == 12)
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 12, byte_offset_imm, kCoherence);
|
||||
else /* bytes == 4 */
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 4, byte_offset_imm, kCoherence);
|
||||
}
|
||||
|
||||
// This version supports a buffer resource as an input arg
|
||||
template <typename T,
|
||||
index_t N,
|
||||
|
||||
@@ -193,6 +193,27 @@ CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// Variant of async_load_tile_raw that dispatches to async_load_raw_long
|
||||
// (global_load_lds path with per-lane 64-bit base pointers). Only valid for
|
||||
// tile_scatter_gather windows whose PageIdxArray element type supports
|
||||
// 64-bit values (e.g. long_index_t).
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_load_tile_raw_long(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.async_load_raw_long(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void async_load_fence(index_t cnt = 0)
|
||||
{
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
@@ -1044,6 +1045,156 @@ struct tile_scatter_gather
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// async_load_raw_long: variant of async_load_raw that issues the per-tile
|
||||
// gather load via `amd_async_global_load_lds_raw` (i.e. AMDGCN
|
||||
// `global_load_lds_dwordx*`) rather than `buffer_load_dword_lds`.
|
||||
//
|
||||
// Same iteration structure and SFC walk as async_load_raw. What changes is
|
||||
// the HW load instruction and the LDS addressing: each issue gets an
// explicit LDS pointer instead of the m0_set / m0_inc bookkeeping. The page
|
||||
// indirection is folded into a per-lane 64-bit base pointer, lifting
|
||||
// both 4 GB limits in the buffer_load path (SRD `size` field is uint32_t,
|
||||
// per-lane voffset is int32). PageIdxArray's element type can therefore
|
||||
// be `long_index_t` (caller's responsibility).
|
||||
//
|
||||
// Why not per-issue SRD rebase? In the K/V tile distributions emitted
|
||||
// by CK-UA today, a single wave-wide buffer_load_dword* spans
|
||||
// (LaneGroups) different N-positions, which for the prefill configs
|
||||
// (NumWarps≥2) can map to several different pages within one issue.
|
||||
// For paged-KV caches > 4 GB, those pages can be ≫ 4 GB apart in the
|
||||
// global K buffer, exceeding the 32-bit voffset / 32-bit SRD-size
|
||||
// range. Only a per-lane 64-bit base pointer (i.e. global_load_lds)
|
||||
// can address all those lanes from a single instruction.
|
||||
//
|
||||
// OOB note: this path drops the SRD's hardware OOB clamp (out-of-range lane
// offsets are redirected to element 0 instead of reading back 0). Caller must
|
||||
// ensure `page_idx_` only references live pages (true in the paged-KV
|
||||
// use-case where block_tables are populated from a valid allocator).
|
||||
// ---------------------------------------------------------------------
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_raw_long(LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
// The per-tile LDS layout in elements (not bytes). The new
|
||||
// `global_load_lds_*` path differs from the `buffer_load_dword_lds`
|
||||
// path here: the LLVM intrinsic implicitly sets `m0` from its
|
||||
// `lptr` argument every call, so the manual `m0_set / m0_inc`
|
||||
// bookkeeping used by `async_load_raw` would be silently
|
||||
// overwritten. Instead, we compute the per-issue LDS element offset
|
||||
// and add it to the LDS base pointer on each call — the compiler
|
||||
// emits a fresh `s_mov_b32 m0, ...` per load with the right value.
|
||||
const index_t elems_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{}));
|
||||
|
||||
const index_t elems_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
|
||||
elems_per_buf;
|
||||
|
||||
const index_t elems_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
|
||||
elems_per_buf;
|
||||
|
||||
using Traits = load_store_traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
// bf16/fp16/etc. element-typed global ptr base for this tile-window.
|
||||
const DataType* base_data_ptr = bottom_tensor_view_.get_buffer_view().p_data_;
|
||||
// Element count in the underlying buffer — used to clamp per-lane
|
||||
// pointers that go past the live range, mimicking the SRD's OOB
|
||||
// semantics on the original `buffer_load_dword_lds` path.
|
||||
const long_index_t buf_elems = static_cast<long_index_t>(
|
||||
bottom_tensor_view_.get_buffer_view().buffer_size_);
|
||||
LdsDataType* lds_base = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
|
||||
|
||||
// Wave / warp-group offset into LDS, computed once.
|
||||
const index_t lds_wave_elems = elems_per_buf + elems_per_wave * get_warp_id();
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
constexpr auto pre_nop_ = [&]() {
|
||||
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
|
||||
return bool_constant<true>{};
|
||||
else
|
||||
return bool_constant<false>{};
|
||||
}();
|
||||
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = get_gather_index(idx_ys_start);
|
||||
// page_idx_ element type can be long_index_t — the pointer
|
||||
// arithmetic below stays 64-bit by promotion.
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// Per-lane 64-bit GLOBAL base pointer. coord.get_offset() is
|
||||
// the within-bottom-tensor element offset (intra-tile,
|
||||
// int32-safe by construction); page_offset is the
|
||||
// page-indirected element offset (potentially > INT32_MAX).
|
||||
// Pointer arithmetic on DataType* advances by sizeof(DataType)
|
||||
// and uses 64-bit ptrdiff_t internally.
|
||||
const long_index_t lane_elem_off =
|
||||
static_cast<long_index_t>(bottom_tensor_thread_coord.get_offset()) +
|
||||
static_cast<long_index_t>(page_offset);
|
||||
// Clamp to in-buffer range to keep `global_load_lds` from
|
||||
// faulting on tail-padded pages (the original buffer_load
|
||||
// SRD silently returned 0 for the same OOB voffsets). The
|
||||
// attention mask zeroes the contribution from these lanes
|
||||
// at softmax, so the value read here is irrelevant.
|
||||
constexpr index_t bytes_per_load_ = sizeof(vector_t);
|
||||
constexpr index_t elems_per_load_ = bytes_per_load_ / sizeof(DataType);
|
||||
const bool in_range =
|
||||
(lane_elem_off >= 0) &&
|
||||
(lane_elem_off + elems_per_load_ <= buf_elems);
|
||||
const long_index_t safe_off = in_range ? lane_elem_off : 0;
|
||||
const DataType* per_lane_ptr = base_data_ptr + safe_off;
|
||||
|
||||
// Per-issue LDS write target. Wave-uniform; intrinsic emits
|
||||
// `s_mov_b32 m0, <this>` and the dwordx4 lds-direct write
|
||||
// lands at m0 + (lane_id * bytes_per_lane). For NumCoord==1
|
||||
// (the only case exercised by the UA pipeline today) kIssue
|
||||
// reduces to iCoordAccess; the combined, monotonically-increasing
|
||||
// index is used so each issue lands in its own LDS slot.
|
||||
constexpr index_t kIssue = iCoord * NumAccessPerCoord + iCoordAccess;
|
||||
LdsDataType* lds_ptr = lds_base + lds_wave_elems + elems_per_issue * kIssue;
|
||||
|
||||
amd_async_global_load_lds_raw<DataType, elems_per_load_, /*byte_offset_imm=*/0>(
|
||||
lds_ptr, per_lane_ptr, pre_nop_);
|
||||
|
||||
// move thread coordinate (no m0_inc — see header comment above)
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; },
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
forward_step_scatter);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
|
||||
|
||||
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
|
||||
|
||||
@@ -109,6 +109,10 @@ struct UnifiedAttentionKernel
|
||||
ck_tile::index_t split_stride_o_acc = 0;
|
||||
ck_tile::index_t nhead_stride_lse_acc = 0;
|
||||
ck_tile::index_t nhead_stride_o_acc = 0;
|
||||
|
||||
// Runtime selector for the K/V async-load path in the pipeline. See
|
||||
// `unified_attention_args::cache_ptr_int32_overflow_possible`.
|
||||
bool cache_ptr_int32_overflow_possible = false;
|
||||
};
|
||||
|
||||
using Kargs = UnifiedAttentionVarlenKargs;
|
||||
@@ -150,7 +154,8 @@ struct UnifiedAttentionKernel
|
||||
ck_tile::index_t split_stride_lse_acc = 0,
|
||||
ck_tile::index_t split_stride_o_acc = 0,
|
||||
ck_tile::index_t nhead_stride_lse_acc = 0,
|
||||
ck_tile::index_t nhead_stride_o_acc = 0)
|
||||
ck_tile::index_t nhead_stride_o_acc = 0,
|
||||
bool cache_ptr_int32_overflow_possible = false)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -189,7 +194,8 @@ struct UnifiedAttentionKernel
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc};
|
||||
nhead_stride_o_acc,
|
||||
cache_ptr_int32_overflow_possible};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -518,7 +524,8 @@ struct UnifiedAttentionKernel
|
||||
smem_ptr,
|
||||
static_cast<long_index_t>(kargs.stride_k_cache_1),
|
||||
static_cast<long_index_t>(kargs.stride_v_cache_1),
|
||||
num_queries_per_kv);
|
||||
num_queries_per_kv,
|
||||
kargs.cache_ptr_int32_overflow_possible);
|
||||
auto& o_acc_tile = pipeline_result[number<0>{}];
|
||||
auto& lse_tile = pipeline_result[number<1>{}];
|
||||
|
||||
|
||||
@@ -194,7 +194,12 @@ struct UnifiedAttentionPipeline
|
||||
// "fall back to the compile-time `kBlockQ` from `UnifiedAttentionShape`"
|
||||
// so existing callers don't have to change. The kernel template passes
|
||||
// the runtime value (from kargs) to remove the static dependency.
|
||||
const index_t num_queries_per_kv = 0) const
|
||||
const index_t num_queries_per_kv = 0,
|
||||
// Caller-supplied flag: set to true when the K/V cache total byte
|
||||
// size can exceed INT32_MAX. Routes K/V async loads through the
|
||||
// 64-bit-base `global_load_lds` path (correct but lower throughput).
|
||||
// False uses the original shared-SRD `buffer_load_dword_lds` path.
|
||||
const bool cache_ptr_int32_overflow_possible = false) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
static_assert(
|
||||
@@ -425,8 +430,17 @@ struct UnifiedAttentionPipeline
|
||||
const index_t k_thread_n_pos = k_thread_coord[number<0>{}];
|
||||
const index_t v_thread_n_pos = v_thread_coord[number<0>{}];
|
||||
|
||||
statically_indexed_array<index_t, KNRepeat> k_page_offsets;
|
||||
statically_indexed_array<index_t, VNRepeat> v_page_offsets;
|
||||
// Page offsets are widened to long_index_t so the `_long` load path
|
||||
// (global_load_lds, per-lane 64-bit base) can address pools whose
|
||||
// `num_blocks * page_size * row_stride * sizeof(T)` exceeds INT32_MAX.
|
||||
// Small-domain values (logical_token, logical_page, within_page,
|
||||
// phys_page) stay int32 — they're bounded by the per-CTA sequence
|
||||
// and never overflow. The original `async_load_tile_raw` path
|
||||
// implicitly narrows this back to int32 when it forwards the value
|
||||
// through `async_get_vectorized_elements_raw` — that's intentional,
|
||||
// and safe whenever `cache_ptr_int32_overflow_possible == false`.
|
||||
statically_indexed_array<long_index_t, KNRepeat> k_page_offsets;
|
||||
statically_indexed_array<long_index_t, VNRepeat> v_page_offsets;
|
||||
|
||||
auto refresh_k_offsets = [&](index_t k_tile_idx) {
|
||||
static_for<0, KNRepeat, 1>{}([&](auto i) {
|
||||
@@ -438,7 +452,8 @@ struct UnifiedAttentionPipeline
|
||||
const index_t phys_page =
|
||||
block_tables_ptr_[block_table_offset + logical_page];
|
||||
k_page_offsets(i) =
|
||||
(phys_page * page_size + within_page) * k_row_stride;
|
||||
(static_cast<long_index_t>(phys_page) * page_size + within_page) *
|
||||
k_row_stride;
|
||||
});
|
||||
};
|
||||
auto refresh_v_offsets = [&](index_t v_tile_idx) {
|
||||
@@ -451,7 +466,8 @@ struct UnifiedAttentionPipeline
|
||||
const index_t phys_page =
|
||||
block_tables_ptr_[block_table_offset + logical_page];
|
||||
v_page_offsets(i) =
|
||||
(phys_page * page_size + within_page) * v_row_stride;
|
||||
(static_cast<long_index_t>(phys_page) * page_size + within_page) *
|
||||
v_row_stride;
|
||||
});
|
||||
};
|
||||
|
||||
@@ -575,15 +591,46 @@ struct UnifiedAttentionPipeline
|
||||
// Pass-2: page indirection lives in page_offsets, not in the SRD. We
|
||||
// refresh the per-iter offsets table and push it to the window via
|
||||
// update_page_idx(); the SRD itself stays put (no init_raw per iter).
|
||||
//
|
||||
// Two load paths, dispatched on the runtime overflow flag:
|
||||
// - false: `async_load_tile_raw` → `buffer_load_dword_lds` with a
|
||||
// wave-uniform 4 GB-capped SRD. Faster, but per-lane voffsets
|
||||
// are int32 so the path is only correct while
|
||||
// `num_blocks * page_size * row_stride * sizeof(T) ≤ INT32_MAX`.
|
||||
// - true: `async_load_tile_raw_long` → `global_load_lds_dwordx*`
|
||||
// with per-lane 64-bit base pointers, lifting the 4 GB limit
|
||||
// at the cost of lower throughput.
|
||||
// The branch is on a wave-uniform value, so no execution divergence.
|
||||
//
|
||||
// For diagnostic purposes: the wave's N-position span within a
|
||||
// single buffer_load instruction is `(LaneGroups-1)*NumWarps + 1`.
|
||||
// When that's > the minimum page_size (≈16) the K-tile distribution
|
||||
// touches multiple pages per issue. The small-cache buffer_load
|
||||
// path still works then (per-lane voffsets fit while the cache ≤ 4 GB),
|
||||
// but the per-issue SRD-rebase optimization (not implemented today)
|
||||
// wouldn't be applicable — only `global_load_lds` works once the
|
||||
// cache exceeds 4 GB.
|
||||
constexpr index_t KWaveSpanInN =
|
||||
(KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] - 1) *
|
||||
KDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}] +
|
||||
1;
|
||||
(void)KWaveSpanInN; // currently informational only
|
||||
|
||||
auto K_mem_load = [&](auto k_lds_write_idx) {
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
if(cache_ptr_int32_overflow_possible)
|
||||
async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
else
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
k_block_idx++;
|
||||
refresh_k_offsets(k_block_idx);
|
||||
k_dram_window.update_page_idx(k_page_offsets);
|
||||
};
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
if(cache_ptr_int32_overflow_possible)
|
||||
async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
else
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
v_block_idx++;
|
||||
refresh_v_offsets(v_block_idx);
|
||||
v_dram_window.update_page_idx(v_page_offsets);
|
||||
@@ -1269,7 +1316,9 @@ struct UnifiedAttentionPipeline
|
||||
long_index_t v_row_stride = 0,
|
||||
// Forwards to the full-args operator() so callers can plumb in a
|
||||
// runtime kBlockQ. See the documentation on that overload.
|
||||
const index_t num_queries_per_kv = 0) const
|
||||
const index_t num_queries_per_kv = 0,
|
||||
// See the doc on the full-args operator().
|
||||
const bool cache_ptr_int32_overflow_possible = false) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
@@ -1292,7 +1341,8 @@ struct UnifiedAttentionPipeline
|
||||
smem_ptr,
|
||||
k_row_stride,
|
||||
v_row_stride,
|
||||
num_queries_per_kv);
|
||||
num_queries_per_kv,
|
||||
cache_ptr_int32_overflow_possible);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user