diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index a47516a158..82aea3e49d 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -166,11 +166,14 @@ struct UnifiedAttentionPipeline VDist::DstrEncode::hs_lengthss_[ck_tile::number<0>{}][ck_tile::number<2>{}]; constexpr ck_tile::index_t kPageSizeCap = kHasCePageSize ? kPageSize : ck_tile::index_t{16}; + // Gate kept in lock-step with kScalarPromote{K,V}PageIdx in + // operator(): both decide whether a given kernel instance needs the + // Tier-2 LDS-resident page-table cache, and any divergence means the + // runtime path writes/reads at an offset for which no LDS was + // reserved (silently corrupting the K/V double-buffers above it). constexpr bool kHasTier0K = - (Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) && (KNRepeat >= 2) && (KY0_step_N <= kPageSizeCap); constexpr bool kHasTier0V = - (Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) && (VNRepeat >= 2) && (VY0_step_N <= kPageSizeCap); if constexpr (kHasTier0K || kHasTier0V) return kPageTableLdsEntries * sizeof(ck_tile::index_t); @@ -679,13 +682,19 @@ struct UnifiedAttentionPipeline // (newly ON; small win) constexpr index_t kKPageSizeCap = kHasCePageSize ? kPageSize : index_t{16}; constexpr index_t kVPageSizeCap = kHasCePageSize ? kPageSize : index_t{16}; + // EXPERIMENT 2026-05: relax the 8-warp gate. The original measurement + // showed 1-4 warp decode regressed 3-8%, but that was *before* the + // halved-kBlockN bf16 change which (a) doubles the iter count, hence + // doubles the per-tile page-table refresh count, and (b) lifts decode + // occupancy from 1 CTA/CU to 3 CTAs/CU, multiplying the per-CU + // contention on the per-lane block_tables_ptr_ vector loads. Both + // shift the trade-off in favour of the Tier-2 LDS-cached path. constexpr bool kScalarPromoteKPageIdx = - (Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) && (KNRepeat >= 2) && (KY0_step_N <= kKPageSizeCap); constexpr bool kScalarPromoteVPageIdx = - (Problem::kBlockSize >= 8 * ck_tile::get_warp_size()) && (VNRepeat >= 2) && (VY0_step_N <= kVPageSizeCap); + // Tier 2 — LDS-resident page-table cache. // // After Tier 0 the per-K/V-tile cost of resolving phys_page is a @@ -721,22 +730,40 @@ struct UnifiedAttentionPipeline auto block_tables_lds = reinterpret_cast( static_cast(smem_ptr) + kPageTableLdsOffset); + // Split-KV correction: refresh_*_offsets indexes block_tables_lds by + // i_base_page = (split_token_offset + …) / page_size, which is the + // *absolute* page index within the batch (NOT relative to this split). + // For prefill split_token_offset == 0 so the absolute and relative + // indices coincide; for split-KV decode (i_total_loops starts at + // num_blocks_start > 0 on splits 1+), they diverge and the original + // load that only touched pages [0, num_pages_for_split) read OOB. + // We bulk-load pages [block_table_offset, block_table_offset + + // split_end_page) so every absolute page index this CTA can produce + // is covered, at the cost of a tiny (one-shot, kernel-entry) bulk + // load for the early portion of the batch we skip past on splits 1+. + const index_t split_end_page = static_cast( + (static_cast(num_total_loop) * kPageBlockSize + page_size - 1) / + page_size); if constexpr (kUsePageTableLds) { - // Number of page-table entries this CTA touches. block_tables_lds - // is indexed by `logical_page` (= block_table_offset-relative), - // matching the index used in refresh_*_offsets below. - const long_index_t end_token = - static_cast(num_total_loop) * kPageBlockSize; - const index_t num_pages_for_cta = - static_cast((end_token + page_size - 1) / page_size); - assert(num_pages_for_cta <= kPageTableLdsEntries); + assert(split_end_page <= kPageTableLdsEntries); const index_t tid = get_thread_local_1d_id(); - for (index_t i = tid; i < num_pages_for_cta; i += Problem::kBlockSize) + for (index_t i = tid; i < split_end_page; i += Problem::kBlockSize) { block_tables_lds[i] = block_tables_ptr_[block_table_offset + i]; } + // Each thread writes a strided subset of block_tables_lds[] and + // subsequent refresh_*_offsets reads at i_base_page may be served + // by a *different* lane's write (cross-lane LDS access). The + // s_barrier below handles cross-wave ordering, but on single-warp + // CTAs (TinyDecode, kBlockSize == warp_size) LLVM elides s_barrier + // entirely — and with it the implicit lgkmcnt(0) drain that + // commits this wave's ds_writes. Without an explicit drain the + // refresh path then reads stale LDS. Adding `s_waitcnt lgkmcnt(0)` + // is a no-op on the multi-warp tiers (the s_barrier carries it + // implicitly) and load-bearing for single-warp tiers. + s_waitcnt_lgkmcnt<0>(); __builtin_amdgcn_s_barrier(); }