CK-UA: enable Tier-2 LDS page-table cache on decode + fix split-KV bulk-load OOB

Three coupled changes to the Tier-2 LDS-resident page-table cache:

1. Drop the `kBlockSize >= 8 * warp_size` gate from both the runtime
   kScalarPromote{K,V}PageIdx predicates and the static
   GetPageTableLdsBytes() allocator. The original conservative gate
   excluded TinyDecode (kBlockSize == 64); the trade-off has since
   flipped now that bf16 m16 doubles the per-tile iter count (and
   thus the per-tile page-table refresh count) via the halved kBlockN
   change. Enabling Tier-2 on TinyDecode eliminates the per-iter
   `s_waitcnt vmcnt(0)` drains that the per-lane block_tables_ptr_
   vector loads were forcing.

2. Fix a silent corruption bug in GetPageTableLdsBytes(): the LDS-
   allocation gate carried the old hedge while operator()'s runtime
   gate had already dropped it, so on TinyDecode the bulk-load
   path wrote into LDS regions belonging to the K/V double buffers
   above it. Both gates now share the same constexpr predicate.

3. Split-KV bulk-load correction. refresh_*_offsets indexes
   block_tables_lds by absolute page index (= block_table_offset-
   relative), so on splits 1+ where the CTA's split_token_offset > 0
   the original bulk load only covered pages
   [0, num_pages_for_split) and read OOB. We now load
   [block_table_offset, block_table_offset + split_end_page) to
   cover every absolute page the CTA can index.

Also: add explicit `s_waitcnt_lgkmcnt<0>()` after the bulk-load. On
multi-warp tiers the s_barrier carries the LDS-write drain
implicitly; on single-warp TinyDecode LLVM elides s_barrier entirely
and the refresh path reads stale LDS without the explicit drain.

Validated: correctness sweep across bf16/fp16/fp8 × {decode, prefill}
× b in {1,32,128,256}, sk up to 128k. Decode perf: ~1.18x geomean vs
Triton on long-context d=128 GQA8 (was 1.5x+ pre-fix).

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-05-26 08:21:10 +00:00
parent 310efc556f
commit badc807025

View File

@@ -166,11 +166,14 @@ struct UnifiedAttentionPipeline
VDist::DstrEncode::hs_lengthss_[ck_tile::number<0>{}][ck_tile::number<2>{}];
constexpr ck_tile::index_t kPageSizeCap =
kHasCePageSize ? kPageSize : ck_tile::index_t{16};
// Gate kept in lock-step with kScalarPromote{K,V}PageIdx in
// operator(): both decide whether a given kernel instance needs the
// Tier-2 LDS-resident page-table cache, and any divergence means the
// runtime path writes/reads at an offset for which no LDS was
// reserved (silently corrupting the K/V double-buffers above it).
constexpr bool kHasTier0K =
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
(KNRepeat >= 2) && (KY0_step_N <= kPageSizeCap);
constexpr bool kHasTier0V =
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
(VNRepeat >= 2) && (VY0_step_N <= kPageSizeCap);
if constexpr (kHasTier0K || kHasTier0V)
return kPageTableLdsEntries * sizeof(ck_tile::index_t);
@@ -679,13 +682,19 @@ struct UnifiedAttentionPipeline
// (newly ON; small win)
constexpr index_t kKPageSizeCap = kHasCePageSize ? kPageSize : index_t{16};
constexpr index_t kVPageSizeCap = kHasCePageSize ? kPageSize : index_t{16};
// EXPERIMENT 2026-05: relax the 8-warp gate. The original measurement
// showed 1-4 warp decode regressed 3-8%, but that was *before* the
// halved-kBlockN bf16 change which (a) doubles the iter count, hence
// doubles the per-tile page-table refresh count, and (b) lifts decode
// occupancy from 1 CTA/CU to 3 CTAs/CU, multiplying the per-CU
// contention on the per-lane block_tables_ptr_ vector loads. Both
// shift the trade-off in favour of the Tier-2 LDS-cached path.
constexpr bool kScalarPromoteKPageIdx =
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
(KNRepeat >= 2) && (KY0_step_N <= kKPageSizeCap);
constexpr bool kScalarPromoteVPageIdx =
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
(VNRepeat >= 2) && (VY0_step_N <= kVPageSizeCap);
// Tier 2 — LDS-resident page-table cache.
//
// After Tier 0 the per-K/V-tile cost of resolving phys_page is a
@@ -721,22 +730,40 @@ struct UnifiedAttentionPipeline
auto block_tables_lds = reinterpret_cast<int32_t*>(
static_cast<char*>(smem_ptr) + kPageTableLdsOffset);
// Split-KV correction: refresh_*_offsets indexes block_tables_lds by
// i_base_page = (split_token_offset + …) / page_size, which is the
// *absolute* page index within the batch (NOT relative to this split).
// For prefill split_token_offset == 0 so the absolute and relative
// indices coincide; for split-KV decode (i_total_loops starts at
// num_blocks_start > 0 on splits 1+), they diverge and the original
// load that only touched pages [0, num_pages_for_split) read OOB.
// We bulk-load pages [block_table_offset, block_table_offset +
// split_end_page) so every absolute page index this CTA can produce
// is covered, at the cost of a tiny (one-shot, kernel-entry) bulk
// load for the early portion of the batch we skip past on splits 1+.
const index_t split_end_page = static_cast<index_t>(
(static_cast<long_index_t>(num_total_loop) * kPageBlockSize + page_size - 1) /
page_size);
if constexpr (kUsePageTableLds)
{
// Number of page-table entries this CTA touches. block_tables_lds
// is indexed by `logical_page` (= block_table_offset-relative),
// matching the index used in refresh_*_offsets below.
const long_index_t end_token =
static_cast<long_index_t>(num_total_loop) * kPageBlockSize;
const index_t num_pages_for_cta =
static_cast<index_t>((end_token + page_size - 1) / page_size);
assert(num_pages_for_cta <= kPageTableLdsEntries);
assert(split_end_page <= kPageTableLdsEntries);
const index_t tid = get_thread_local_1d_id();
for (index_t i = tid; i < num_pages_for_cta; i += Problem::kBlockSize)
for (index_t i = tid; i < split_end_page; i += Problem::kBlockSize)
{
block_tables_lds[i] = block_tables_ptr_[block_table_offset + i];
}
// Each thread writes a strided subset of block_tables_lds[] and
// subsequent refresh_*_offsets reads at i_base_page may be served
// by a *different* lane's write (cross-lane LDS access). The
// s_barrier below handles cross-wave ordering, but on single-warp
// CTAs (TinyDecode, kBlockSize == warp_size) LLVM elides s_barrier
// entirely — and with it the implicit lgkmcnt(0) drain that
// commits this wave's ds_writes. Without an explicit drain the
// refresh path then reads stale LDS. Adding `s_waitcnt lgkmcnt(0)`
// is a no-op on the multi-warp tiers (the s_barrier carries it
// implicitly) and load-bearing for single-warp tiers.
s_waitcnt_lgkmcnt<0>();
__builtin_amdgcn_s_barrier();
}