From c11722bf3ea6934fe1df9a352e478864960844d2 Mon Sep 17 00:00:00 2001 From: juuso-oskari Date: Tue, 9 Jun 2026 14:35:00 +0000 Subject: [PATCH] CK-UA: decouple K/V DRAM loads per warp-group + early V read (FA4 fp8) Split the cooperative K/V cache loads across the two FA4 warp groups so each group owns exactly one tile's DRAM load and address arithmetic: WG0 loads V, WG1 loads K, and both read from the shared LDS buffers. - kFA4WG0LoadsV / kFA4WG1LoadsK policy flags + GetVLoadNumWarps / GetKLoadNumWarps: the owning group's 4 waves alone fill the tile via 4-warp descriptors; the partner skips the load and reads from LDS. - High-warp-group support for the raw async path: the raw store bakes the absolute warp id into the LDS M0, so WG1 (waves 4-7) needs a base shift (GetKStoreWarpShift / WarpIdShift in MakeKLdsStoreBlockDescriptor) to map back to the 4-warp layout, plus WG-relative (warp % NumWarps) page offsets so the gather token positions are correct. - Stage B: move each tile's V LDS read into the PRECEDING softmax phase so the read latency hides under softmax VALU. Safe because V is now single- group-owned; uses drain-before-barrier (vmcnt<0> then s_barrier) so all 4 cooperating writer waves' slices are published before the read. - Gate per-tile offset refresh per warp-group (WG0 refreshes V, WG1 K), so each wave fetches a block-table page index for one tile instead of both; loop counters stay uniform. Validated 0% mismatch vs GPU reference, causal + non-causal, sq 256..8192. Net latency vs the cooperative baseline: causal ~-3-4.6%, non-causal ~-2-4.7% across sq 2048..16384 (d128 fp8). Co-authored-by: Cursor --- .../pipeline/unified_attention_pipeline.hpp | 206 ++++++++++++------ ...fied_attention_pipeline_default_policy.hpp | 133 +++++++++-- 2 files changed, 248 insertions(+), 91 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 279b1ed329..8f1ddf84d6 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -488,6 +488,19 @@ struct UnifiedAttentionPipeline const index_t warp_group_id = get_warp_id() / 4; + // FA4 "WG0 loads V": warp group 0's 4 waves load the FULL V tile into + // the shared V LDS buffer (V descriptors use VLoadNumWarps == 4 waves); + // warp group 1 skips the V DRAM load and relies on the inter-phase + // barrier for residency. v_load_active gates the async V load issue. + constexpr index_t VLoadNumWarps = Policy::template GetVLoadNumWarps(); + constexpr index_t KLoadNumWarps = Policy::template GetKLoadNumWarps(); + constexpr index_t NumWarpGroups_ = Problem::kBlockSize / Policy::NumThreadPerWarpGroup; + const bool v_load_active = + (!Policy::kFA4WG0LoadsV) || (NumWarpGroups_ != 2) || (warp_group_id == 0); + // Symmetric: warp group 1 alone loads K (WG0 reads from shared LDS). + const bool k_load_active = + (!Policy::kFA4WG1LoadsK) || (NumWarpGroups_ != 2) || (warp_group_id == 1); + // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); @@ -531,34 +544,41 @@ struct UnifiedAttentionPipeline const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + constexpr index_t KStoreWarpShift = Policy::template GetKStoreWarpShift(); auto k_lds_window_store = generate_tuple( [&](auto i_buf) { return make_lds_tile_window( - smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)); + smem_ptr, + Policy::template MakeKLdsStoreBlockDescriptor(i_buf)); }, number<2>{}); auto v_lds_window_store = generate_tuple( [&](auto i_buf) { return make_lds_tile_window( - smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor(i_buf)); + smem_ptr, + Policy::template MakeVLdsStoreBlockDescriptor(i_buf)); }, number<2>{}); - statically_indexed_array( - nullptr, - Policy::template MakeKLdsLoadBlockDescriptor()), - Policy::template MakeKRegTileDistribution())), - 2> + statically_indexed_array< + decltype(make_tile_window( + make_lds_tile_window( + nullptr, + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution())), + 2> k_lds_window_load; - statically_indexed_array( - nullptr, - Policy::template MakeVLdsLoadBlockDescriptor()), - Policy::template MakeVRegTileDistribution())), - 2> + statically_indexed_array< + decltype(make_tile_window( + make_lds_tile_window( + nullptr, + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution())), + 2> v_lds_window_load; decltype(make_static_distributed_tensor( @@ -619,7 +639,7 @@ struct UnifiedAttentionPipeline k_lds_window_load(idx) = make_tile_window( make_lds_tile_window( static_cast(smem_ptr) + (idx)*Policy::template GetSmemSizeKV(), - Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKLdsLoadBlockDescriptor()), Policy::template MakeKRegTileDistribution()); }); @@ -628,7 +648,8 @@ struct UnifiedAttentionPipeline make_tile_window(make_lds_tile_window( static_cast(smem_ptr) + (idx + 2) * Policy::template GetSmemSizeKV(), - Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVLdsLoadBlockDescriptor()), Policy::template MakeVRegTileDistribution()); }); @@ -714,8 +735,8 @@ struct UnifiedAttentionPipeline // a large negative relative offset that the HW OOB check clamps to 0). // A robust fix would either plumb long_index_t through the gather load // path or compute a per-batch min-page shift in a pre-pass. - const auto k_dist = Policy::template MakeKDramTileDistribution(); - const auto v_dist = Policy::template MakeVDramTileDistribution(); + const auto k_dist = Policy::template MakeKDramTileDistribution(); + const auto v_dist = Policy::template MakeVDramTileDistribution(); using KDstrType = decltype(k_dist); using VDstrType = decltype(v_dist); constexpr index_t KNRepeat = @@ -729,8 +750,19 @@ struct UnifiedAttentionPipeline VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<1>{}] * VDstrType::DstrEncode::hs_lengthss_[number<0>{}][number<2>{}]; - const auto k_thread_coord = k_dist.calculate_index(); - const auto v_thread_coord = v_dist.calculate_index(); + // WG-relative warp index for the gather page-offset computation. When a + // single warp group loads a tile (V by WG0 / K by WG1), only that + // group's waves issue the load, and their absolute warp ids must be + // folded back into [0, NumWarps) so k_thread_n_pos / v_thread_n_pos (the + // per-wave token position baked into page_idx) match the group-relative + // distribution. For the cooperative case NumWarps == full block, so the + // modulo is the identity. The scatter-gather's own get_partition_index + // use is harmless here: the gather (token) dim is zeroed and replaced by + // page_idx, and the remaining (head-dim) coordinate is lane-based. + const auto k_part = ck_tile::array{get_warp_id() % KLoadNumWarps, get_lane_id()}; + const auto v_part = ck_tile::array{get_warp_id() % VLoadNumWarps, get_lane_id()}; + const auto k_thread_coord = k_dist.calculate_index(k_part); + const auto v_thread_coord = v_dist.calculate_index(v_part); const index_t k_thread_n_pos = k_thread_coord[number<0>{}]; const index_t v_thread_n_pos = v_thread_coord[number<0>{}]; @@ -1456,12 +1488,24 @@ struct UnifiedAttentionPipeline // num_blocks_start. const index_t num_iters_per_split = num_total_loop - num_blocks_start; auto K_mem_load = [&](auto k_lds_write_idx) { - if(cache_ptr_int32_overflow_possible) - async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window); - else - async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); + // FA4 "WG1 loads K": only warp group 1's waves issue the K async load + // (its KLoadNumWarps==4 layout fills the full shared K tile). WG0 + // skips it and reads K from shared LDS (barrier-synchronized). + if(k_load_active) + { + if(cache_ptr_int32_overflow_possible) + async_load_tile_raw_long(k_lds_window_store(k_lds_write_idx), k_dram_window); + else + async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); + } k_block_idx++; - if(k_block_idx < num_iters_per_split) + // Only the K-loading warp group needs K offsets refreshed: with + // kFA4WG1LoadsK, WG0 never issues a K load, so computing its page + // offsets (incl. the block-table page-index ds_read) is pure waste. + // k_block_idx itself stays uniform across all waves so loop control + // and buffer parity never diverge. Gating here also means each wave + // fetches a page-table index for exactly ONE tile (K *or* V), not both. + if(k_load_active && k_block_idx < num_iters_per_split) { refresh_k_offsets(k_block_idx); if constexpr(kRebaseKSrd) @@ -1474,12 +1518,22 @@ struct UnifiedAttentionPipeline }; auto V_mem_load = [&](auto v_lds_write_idx) { - if(cache_ptr_int32_overflow_possible) - async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window); - else - async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); + // FA4 "WG0 loads V": only warp group 0's waves issue the V async + // load (its VLoadNumWarps==4 layout fills the full shared V tile). + // WG1 skips the load; bookkeeping (v_block_idx / offsets) stays + // uniform across all waves so the loop's scalar state never diverges. + if(v_load_active) + { + if(cache_ptr_int32_overflow_possible) + async_load_tile_raw_long(v_lds_window_store(v_lds_write_idx), v_dram_window); + else + async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); + } v_block_idx++; - if(v_block_idx < num_iters_per_split) + // Symmetric to K: only the V-loading warp group (WG0) refreshes V + // offsets; WG1 skips it (it never issues a V load). v_block_idx stays + // uniform for loop/buffer bookkeeping. + if(v_load_active && v_block_idx < num_iters_per_split) { refresh_v_offsets(v_block_idx); if constexpr(kRebaseVSrd) @@ -2165,41 +2219,23 @@ struct UnifiedAttentionPipeline auto gemm1 = number<1>{}; // MATRIX phase: deferred PV(k-1) then QK(k). Pure matrix pipe. - // Consumes V(1-pi) / K(pi) resident in LDS; the union kv_tile holds - // v_tile for the PV then is overwritten with k_tile for the QK - // (V_lds → gemm1 → K_lds → gemm0 ordering). - // V-read hoist: issue the PV gemm's LDS→register read (v_rd == pi) - // EARLY — before the matrix phase's compute — so its ~LDS-latency - // overlaps the address-calc VALU of prefetch() (WG0) / the barrier - // exit (WG1) instead of being exposed right before the PV MFMA. The - // V buffer pi was populated by a prior prefetch and already waited - // on (vmcnt<0> at phase entry); prefetch writes K-buf[pi]/V-buf[1-pi] - // so there is no aliasing with this read. - // V-read hoist: issue the PV gemm's LDS→register read (v_rd == pi) - // EARLY — before the matrix phase's compute — so its ~LDS-latency - // overlaps the address-calc VALU of prefetch() (WG0) / the barrier - // exit (WG1) instead of being exposed right before the PV MFMA. The - // V buffer pi was populated by a prior prefetch and already waited - // on (vmcnt<0> at phase entry); prefetch writes K-buf[pi]/V-buf[1-pi] - // so there is no aliasing with this read. + // Consumes V(pi) / K(1-pi) resident in LDS; kv_tile holds v_tile for + // the PV and (separately) k_tile for the QK. // - // NOTE: do NOT also hoist K_lds_load here. The QK gemm reads K-buf - // [1-pi], and in the WG1 softmax-first prologue that buffer is not - // yet guaranteed resident this early (its async load completes a - // phase later) — hoisting K corrupts long-context runs. K stays - // issued between the PV and QK MFMAs (its latency hides under PV). - // V-read hoist: issue the PV gemm's LDS→register read (v_rd == pi) - // EARLY — before the matrix phase's compute — so its ~LDS-latency - // overlaps the address-calc VALU of prefetch() (WG0) / the barrier - // exit (WG1) instead of being exposed right before the PV MFMA. + // V-read into SOFTMAX (Stage B): the PV gemm's V tile (v_rd == pi) is + // now read in the *preceding* SOFTMAX phase rather than at the top of + // this MATRIX phase, so its ~LDS latency overlaps the full softmax + // VALU (exp / rowsum / P-cvt) instead of only the prefetch address + // calc. This is safe now that V is loaded by a single warp group + // (kFA4WG0LoadsV): WG0 reads V it loaded itself, so its own vmcnt<0> + // proves residency (no partner dependency); WG1 reads an already- + // barrier-published V buffer. The pre-read lands in v_tile and this + // MATRIX phase consumes it directly (see fa4_softmax / the WG0 prime). // - // NOTE: do NOT also hoist K_lds_load. K/V are loaded COOPERATIVELY - // by both warp groups; a wave's vmcnt only drains its OWN async - // loads, not the partner group's half, so a cooperatively-filled - // buffer is reliably resident only deeper into the phase. The PV - // gemm provides exactly that slack for the K read — moving K ahead - // of PV races the partner's load completion and corrupts long - // contexts. K stays issued between the PV and QK MFMAs. + // NOTE: do NOT hoist K_lds_load the same way. The QK gemm reads K-buf + // [1-pi] which the partner group (WG1) loads; WG0 has no own-vmcnt + // proof of its residency this early, only the barrier deeper in the + // phase. K stays issued between the PV and QK MFMAs (latency under PV). auto fa4_vload = [&](auto pi) { V_lds_load(pi); }; auto fa4_matrix = [&](auto pi) { @@ -2207,14 +2243,15 @@ struct UnifiedAttentionPipeline auto qk_sp = number<1>{} - pi; // QK target slot auto k_rd = number<1>{} - pi; - s_waitcnt_lgkmcnt<0>(); // wait the hoisted fa4_vload(pi) + s_waitcnt_lgkmcnt<0>(); // wait the V pre-read issued in prev SOFTMAX gemm(pv_sp, gemm1); // o_acc += P(pi) @ V(k-1) - // K read into its OWN registers (kv_tile no longer a union), so - // this ds_read executes on the LSU *during* the PV MFMA above - // instead of waiting for it to retire. The sched_barrier pins it - // here (program order) so it is NOT hoisted above the PV gemm -- - // that would race the partner WG's cooperative K load and - // corrupt long contexts (the residency hazard documented above). + // K read into its OWN registers (k_tile no longer aliases v_tile), + // so this ds_read executes on the LSU *during* the PV MFMA above + // rather than waiting for it to retire; the sched_barriers pin it + // here. K is now single-warp-group loaded (kFA4WG1LoadsK) so it is + // resident at the slot-A barrier, but issuing the read AFTER the + // PV gemm call (overlapping the in-flight MFMA) schedules strictly + // better than hoisting it ahead of PV — measured ~3-4% faster. __builtin_amdgcn_sched_barrier(0); K_lds_load(k_rd); // overlaps the PV MFMA (latency hidden) __builtin_amdgcn_sched_barrier(0); @@ -2256,20 +2293,32 @@ struct UnifiedAttentionPipeline if constexpr(cl_p == 0) { // ---- slot A: MATRIX(pi) ‖ (WG1: SOFTMAX) ---- + // V tile (buf pi) was pre-read into v_tile in the previous + // SOFTMAX phase (or the WG0 prime for the first tile). ASM_MARKER("fa4 MATRIX Wave0-3"); s_waitcnt_vmcnt<0>(); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); - fa4_vload(pi); // hoisted V read; latency hidden under prefetch prefetch(); fa4_matrix(pi); // ---- slot B: SOFTMAX(pi) ‖ (WG1: MATRIX) ---- + // Pre-read the next MATRIX's V tile (buf 1-pi == the buffer + // this iteration's prefetch just filled), overlapping the + // softmax VALU below; v_tile survives into the next slot-A + // MATRIX (PV consumes it via lgkmcnt<0>). The V buffer is + // filled cooperatively by WG0's 4 waves, so a wave reads + // slices written by its peers: drain the load (vmcnt<0>) and + // then cross the phase barrier so all 4 waves' writes are + // published BEFORE the read (drain-before-barrier; reading + // after only an own-vmcnt races the peers' slices). ASM_MARKER("fa4 SOFTMAX Wave0-3"); + s_waitcnt_vmcnt<0>(); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); + fa4_vload(number<1>{} - pi); fa4_softmax(pi); if(num_total_loop <= ++i_total_loops) @@ -2281,20 +2330,28 @@ struct UnifiedAttentionPipeline // WG1 is one phase ahead (primed by the FA4 prologue): it // softmaxes the tile it QK'd in its previous MATRIX phase // while WG0 runs the MATRIX of the same tile. + // + // Pre-read this iteration's slot-B MATRIX V tile (buf pi). + // That buffer was filled by WG0 and already drained+published + // by WG0's drain-before-barrier in its prior SOFTMAX slot, so + // the slot-A barrier just crossed guarantees all 4 writer + // waves' slices are visible. The read overlaps the softmax + // VALU below; v_tile survives into the slot-B MATRIX. ASM_MARKER("fa4 SOFTMAX Wave4-7"); s_waitcnt_vmcnt<0>(); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); + fa4_vload(pi); prefetch(); fa4_softmax(number<1>{} - pi); // ---- slot B: MATRIX(pi) ‖ (WG0: SOFTMAX) ---- + // v_tile holds buf pi from the slot-A pre-read above. ASM_MARKER("fa4 MATRIX Wave4-7"); __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); - fa4_vload(pi); // hoisted V read (overlaps barrier exit) fa4_matrix(pi); if(num_total_loop <= ++i_total_loops) @@ -2508,6 +2565,11 @@ struct UnifiedAttentionPipeline fmha_alu0(number<0>{}); fmha_alu_D_upd(); fmha_alu1(number<0>{}); // sp(0).p = P(0) + // Prime v_tile for the first MATRIX(0): V buf 0 was loaded by + // WG0 in the pre-stage, so its own vmcnt<0> proves residency. + // (Stage B reads each subsequent tile's V in the prior SOFTMAX.) + s_waitcnt_vmcnt<0>(); + V_lds_load(number<0>{}); while(core_loop_fa4(number<0>{})) ; } diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index 98cad6f23a..beeac94c82 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -141,7 +141,11 @@ struct UnifiedAttentionPipelineDefaultPolicy return 16 / sizeof(VDataType); } - template + // NumWarpsOverride mirrors MakeVDramTileDistribution: the FA4 "WG1 loads K" + // path passes NumThreadPerWarpGroup/WarpSize so warp group 1's waves alone + // tile the full K buffer (the partner group reads it from shared LDS). + template CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution() { using namespace ck_tile; @@ -149,7 +153,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t NumWarps = NumWarpsOverride; constexpr index_t WarpSize = ck_tile::get_warp_size(); constexpr index_t KVector = GetAlignmentK(); // this is for global load @@ -158,7 +162,8 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector)); + static_cast(kBlockSize); constexpr index_t N0 = NumIssues; constexpr index_t N1 = LaneGroups; @@ -175,7 +180,13 @@ struct UnifiedAttentionPipelineDefaultPolicy sequence<0, 1>>{}); } - template + // NumWarpsOverride lets the FA4 per-warp-group ("private V") path request a + // distribution where only NumWarps waves cooperate on the load (so each + // warp group loads the FULL V tile by itself, into its own LDS buffer, and + // its own vmcnt proves residency without waiting on the partner group). + // Default = the shape's NumWarps (the original block-cooperative load). + template CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() { using namespace ck_tile; @@ -183,7 +194,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t NumWarps = NumWarpsOverride; constexpr index_t WarpSize = ck_tile::get_warp_size(); // 64 constexpr index_t KVector = GetAlignmentV(); // this is for global load @@ -193,7 +204,10 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + // NumWarps-relative form (NumWarps may be < the full block when the FA4 + // per-warp-group path requests a private-V distribution). + static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector)); + static_cast(kBlockSize); constexpr index_t N0 = NumIssues; // 8 constexpr index_t N1 = LaneGroups; // 2 @@ -378,7 +392,19 @@ struct UnifiedAttentionPipelineDefaultPolicy static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords - template + // WarpIdShift handles a sub-block load issued by a NON-zero warp group via + // the raw async path. The raw store derives its LDS offset as + // M0 = base + size_per_wave * get_warp_id() (ABSOLUTE warp id 0..7) + // so a NumWarps-wide (e.g. 4-wave) layout only tiles correctly for warp ids + // 0..NumWarps-1. When warp group g (>0) alone fills the tile, its waves have + // absolute ids [g*NumWarps, (g+1)*NumWarps); shifting the descriptor base by + // -WarpIdShift*size_per_wave (WarpIdShift = g*NumWarps) maps them back to + // effective ids 0..NumWarps-1, i.e. the exact physical layout a warp-group-0 + // load would produce -- so the (unshifted) read descriptor reads it directly. + template CK_TILE_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) { @@ -388,7 +414,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t NumWarps = NumWarpsOverride; constexpr index_t WarpSize = ck_tile::get_warp_size(); [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK(); // this is for lds @@ -405,7 +431,8 @@ struct UnifiedAttentionPipelineDefaultPolicy WarpSize / LanesPerK; // how many groups (within a wave), they may load different N, but same K constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector)); + static_cast(kBlockSize); constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( make_tuple(number{}, // n0 @@ -418,7 +445,8 @@ struct UnifiedAttentionPipelineDefaultPolicy number{}, number{}, number<1>{}), - number()>{}, + number() - + WarpIdShift*(WarpSize * KVector + kPad)>{}, number{}, number<1>{}); @@ -436,7 +464,8 @@ struct UnifiedAttentionPipelineDefaultPolicy return k_lds_block_desc_issues_warps_lanes; } - template + template CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() { using namespace ck_tile; @@ -445,7 +474,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t NumWarps = NumWarpsOverride; constexpr index_t WarpSize = ck_tile::get_warp_size(); constexpr index_t KPack = GetSmemKPackK(); // this is for lds @@ -458,7 +487,8 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector)); + static_cast(kBlockSize); constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(make_tuple(number{}, // n0 @@ -553,7 +583,12 @@ struct UnifiedAttentionPipelineDefaultPolicy return max(max(SingleKSize, SingleVSize), VLoadDescSize); } - template + // NumWarpsOverride mirrors MakeVDramTileDistribution: the FA4 "WG0 loads V" + // path passes NumThreadPerWarpGroup/WarpSize (== 4) so warp group 0's waves + // alone tile the full V buffer. Default = the shape's NumWarps (cooperative). + template CK_TILE_DEVICE static constexpr auto MakeVLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) { @@ -563,7 +598,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t NumWarps = NumWarpsOverride; constexpr index_t WarpSize = ck_tile::get_warp_size(); [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK(); // this is for lds @@ -580,7 +615,8 @@ struct UnifiedAttentionPipelineDefaultPolicy WarpSize / LanesPerK; // how many groups (within a wave), they may load different N, but same K constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector)); + static_cast(kBlockSize); constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( make_tuple(number{}, // n0 @@ -611,7 +647,8 @@ struct UnifiedAttentionPipelineDefaultPolicy return v_lds_block_desc_issues_warps_lanes; } - template + template CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor() { using namespace ck_tile; @@ -620,7 +657,7 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t NumWarps = NumWarpsOverride; constexpr index_t WarpSize = ck_tile::get_warp_size(); constexpr index_t KPack = GetSmemVPackK(); // this is for lds @@ -633,7 +670,8 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (NumWarps * WarpSize * KVector)); + static_cast(kBlockSize); constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(make_tuple(number{}, // n0 @@ -688,6 +726,63 @@ struct UnifiedAttentionPipelineDefaultPolicy return kv_element_space_size_in_bytes; } + // FA4 "WG0 loads V" prototype: when the block runs as two warp groups, have + // ONLY warp group 0 (waves 0-3) load the full V tile into the shared V LDS + // buffer (V's DRAM dist + LDS descriptors use NumThreadPerWarpGroup/WarpSize + // == 4 waves so WG0 alone fills the tile). WG1 skips the V DRAM load + // entirely. No 2x DRAM, no extra LDS (V stays a shared 2-buffer). This + // decouples V's residency from the partner group's cooperative-load shard + // (WG0's own vmcnt proves the load) so the V LDS read can later move into + // the SOFTMAX phase. K stays block-cooperative across all 8 waves. + // Toggle to false to restore the block-cooperative (8-wave) V load. + static constexpr bool kFA4WG0LoadsV = true; + + // Symmetric K decoupling: warp group 1 (waves 4-7) alone loads the full K + // tile into the shared K LDS buffer; warp group 0 reads it from shared LDS. + // Together with kFA4WG0LoadsV this balances DRAM-load work (WG0->V, WG1->K) + // and lets each group issue only one tile's load/address instructions. + static constexpr bool kFA4WG1LoadsK = true; + + // Number of waves that cooperate on a V DRAM->LDS load. For the 2-warp-group + // FA4 path with kFA4WG0LoadsV, this is one warp group's waves (so WG0 alone + // fills the tile); otherwise it's the full block (original cooperative load). + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetVLoadNumWarps() + { + constexpr ck_tile::index_t NumWarpGroups = + Problem::kBlockSize / NumThreadPerWarpGroup; + if constexpr(kFA4WG0LoadsV && NumWarpGroups == 2) + return NumThreadPerWarpGroup / ck_tile::get_warp_size(); + else + return Problem::UnifiedAttentionShape::NumWarps; + } + + // K analogue of GetVLoadNumWarps (warp group 1 alone fills the K tile). + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetKLoadNumWarps() + { + constexpr ck_tile::index_t NumWarpGroups = + Problem::kBlockSize / NumThreadPerWarpGroup; + if constexpr(kFA4WG1LoadsK && NumWarpGroups == 2) + return NumThreadPerWarpGroup / ck_tile::get_warp_size(); + else + return Problem::UnifiedAttentionShape::NumWarps; + } + + // Raw-async warp-id shift for the K store (see MakeKLdsStoreBlockDescriptor): + // K is loaded by warp group 1, whose absolute warp ids start at one warp + // group's worth of waves, so the store base must shift by that many waves. + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetKStoreWarpShift() + { + constexpr ck_tile::index_t NumWarpGroups = + Problem::kBlockSize / NumThreadPerWarpGroup; + if constexpr(kFA4WG1LoadsK && NumWarpGroups == 2) + return NumThreadPerWarpGroup / ck_tile::get_warp_size(); // WG1's first abs warp id + else + return 0; + } + template CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() {