mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
CK-UA: enable Tier-2 LDS page-table cache on decode + fix split-KV bulk-load OOB
Three coupled changes to the Tier-2 LDS-resident page-table cache:
1. Drop the `kBlockSize >= 8 * warp_size` gate from both the runtime
kScalarPromote{K,V}PageIdx predicates and the static
GetPageTableLdsBytes() allocator. The original conservative gate
excluded TinyDecode (kBlockSize == 64); the trade-off has since
flipped now that bf16 m16 doubles the per-tile iter count (and
thus the per-tile page-table refresh count) via the halved kBlockN
change. Enabling Tier-2 on TinyDecode eliminates the per-iter
`s_waitcnt vmcnt(0)` drains that the per-lane block_tables_ptr_
vector loads were forcing.
2. Fix a silent corruption bug in GetPageTableLdsBytes(): the LDS-
allocation gate carried the old hedge while operator()'s runtime
gate had already dropped it, so on TinyDecode the bulk-load
path wrote into LDS regions belonging to the K/V double buffers
above it. Both gates now share the same constexpr predicate.
3. Split-KV bulk-load correction. refresh_*_offsets indexes
block_tables_lds by absolute page index (= block_table_offset-
relative), so on splits 1+ where the CTA's split_token_offset > 0
the original bulk load only covered pages
[0, num_pages_for_split) and read OOB. We now load
[block_table_offset, block_table_offset + split_end_page) to
cover every absolute page the CTA can index.
Also: add explicit `s_waitcnt_lgkmcnt<0>()` after the bulk-load. On
multi-warp tiers the s_barrier carries the LDS-write drain
implicitly; on single-warp TinyDecode LLVM elides s_barrier entirely
and the refresh path reads stale LDS without the explicit drain.
Validated: correctness sweep across bf16/fp16/fp8 × {decode, prefill}
× b in {1,32,128,256}, sk up to 128k. Decode perf: ~1.18x geomean vs
Triton on long-context d=128 GQA8 (was 1.5x+ pre-fix).
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -166,11 +166,14 @@ struct UnifiedAttentionPipeline
|
||||
VDist::DstrEncode::hs_lengthss_[ck_tile::number<0>{}][ck_tile::number<2>{}];
|
||||
constexpr ck_tile::index_t kPageSizeCap =
|
||||
kHasCePageSize ? kPageSize : ck_tile::index_t{16};
|
||||
// Gate kept in lock-step with kScalarPromote{K,V}PageIdx in
|
||||
// operator(): both decide whether a given kernel instance needs the
|
||||
// Tier-2 LDS-resident page-table cache, and any divergence means the
|
||||
// runtime path writes/reads at an offset for which no LDS was
|
||||
// reserved (silently corrupting the K/V double-buffers above it).
|
||||
constexpr bool kHasTier0K =
|
||||
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
|
||||
(KNRepeat >= 2) && (KY0_step_N <= kPageSizeCap);
|
||||
constexpr bool kHasTier0V =
|
||||
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
|
||||
(VNRepeat >= 2) && (VY0_step_N <= kPageSizeCap);
|
||||
if constexpr (kHasTier0K || kHasTier0V)
|
||||
return kPageTableLdsEntries * sizeof(ck_tile::index_t);
|
||||
@@ -679,13 +682,19 @@ struct UnifiedAttentionPipeline
|
||||
// (newly ON; small win)
|
||||
constexpr index_t kKPageSizeCap = kHasCePageSize ? kPageSize : index_t{16};
|
||||
constexpr index_t kVPageSizeCap = kHasCePageSize ? kPageSize : index_t{16};
|
||||
// EXPERIMENT 2026-05: relax the 8-warp gate. The original measurement
|
||||
// showed 1-4 warp decode regressed 3-8%, but that was *before* the
|
||||
// halved-kBlockN bf16 change which (a) doubles the iter count, hence
|
||||
// doubles the per-tile page-table refresh count, and (b) lifts decode
|
||||
// occupancy from 1 CTA/CU to 3 CTAs/CU, multiplying the per-CU
|
||||
// contention on the per-lane block_tables_ptr_ vector loads. Both
|
||||
// shift the trade-off in favour of the Tier-2 LDS-cached path.
|
||||
constexpr bool kScalarPromoteKPageIdx =
|
||||
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
|
||||
(KNRepeat >= 2) && (KY0_step_N <= kKPageSizeCap);
|
||||
constexpr bool kScalarPromoteVPageIdx =
|
||||
(Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) &&
|
||||
(VNRepeat >= 2) && (VY0_step_N <= kVPageSizeCap);
|
||||
|
||||
|
||||
// Tier 2 — LDS-resident page-table cache.
|
||||
//
|
||||
// After Tier 0 the per-K/V-tile cost of resolving phys_page is a
|
||||
@@ -721,22 +730,40 @@ struct UnifiedAttentionPipeline
|
||||
auto block_tables_lds = reinterpret_cast<int32_t*>(
|
||||
static_cast<char*>(smem_ptr) + kPageTableLdsOffset);
|
||||
|
||||
// Split-KV correction: refresh_*_offsets indexes block_tables_lds by
|
||||
// i_base_page = (split_token_offset + …) / page_size, which is the
|
||||
// *absolute* page index within the batch (NOT relative to this split).
|
||||
// For prefill split_token_offset == 0 so the absolute and relative
|
||||
// indices coincide; for split-KV decode (i_total_loops starts at
|
||||
// num_blocks_start > 0 on splits 1+), they diverge and the original
|
||||
// load that only touched pages [0, num_pages_for_split) read OOB.
|
||||
// We bulk-load pages [block_table_offset, block_table_offset +
|
||||
// split_end_page) so every absolute page index this CTA can produce
|
||||
// is covered, at the cost of a tiny (one-shot, kernel-entry) bulk
|
||||
// load for the early portion of the batch we skip past on splits 1+.
|
||||
const index_t split_end_page = static_cast<index_t>(
|
||||
(static_cast<long_index_t>(num_total_loop) * kPageBlockSize + page_size - 1) /
|
||||
page_size);
|
||||
if constexpr (kUsePageTableLds)
|
||||
{
|
||||
// Number of page-table entries this CTA touches. block_tables_lds
|
||||
// is indexed by `logical_page` (= block_table_offset-relative),
|
||||
// matching the index used in refresh_*_offsets below.
|
||||
const long_index_t end_token =
|
||||
static_cast<long_index_t>(num_total_loop) * kPageBlockSize;
|
||||
const index_t num_pages_for_cta =
|
||||
static_cast<index_t>((end_token + page_size - 1) / page_size);
|
||||
assert(num_pages_for_cta <= kPageTableLdsEntries);
|
||||
assert(split_end_page <= kPageTableLdsEntries);
|
||||
|
||||
const index_t tid = get_thread_local_1d_id();
|
||||
for (index_t i = tid; i < num_pages_for_cta; i += Problem::kBlockSize)
|
||||
for (index_t i = tid; i < split_end_page; i += Problem::kBlockSize)
|
||||
{
|
||||
block_tables_lds[i] = block_tables_ptr_[block_table_offset + i];
|
||||
}
|
||||
// Each thread writes a strided subset of block_tables_lds[] and
|
||||
// subsequent refresh_*_offsets reads at i_base_page may be served
|
||||
// by a *different* lane's write (cross-lane LDS access). The
|
||||
// s_barrier below handles cross-wave ordering, but on single-warp
|
||||
// CTAs (TinyDecode, kBlockSize == warp_size) LLVM elides s_barrier
|
||||
// entirely — and with it the implicit lgkmcnt(0) drain that
|
||||
// commits this wave's ds_writes. Without an explicit drain the
|
||||
// refresh path then reads stale LDS. Adding `s_waitcnt lgkmcnt(0)`
|
||||
// is a no-op on the multi-warp tiers (the s_barrier carries it
|
||||
// implicitly) and load-bearing for single-warp tiers.
|
||||
s_waitcnt_lgkmcnt<0>();
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user