From 879d50836e72d54f0f101f988d02953b35ef330b Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Sun, 17 May 2026 02:35:07 -0400 Subject: [PATCH] cleanup(sparse_attn): R-tag rename + clang-format sweep MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Strip internal R-tag / phase labels (R20, R21A/B, Round 8/13f, Track F, B2.v3, Phase 1/2/3) from comments — replace with descriptive names so future readers don't need the change-log. Reflow long signature in fmha_fwd_trek.hpp. Co-Authored-By: Claude Opus 4 --- .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 4 +- .../pipeline/sparge_blockmap_pipeline.hpp | 62 +++++++++---------- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 62d40ffbe0..384f7bb56d 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -283,7 +283,9 @@ float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); template void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config&, fmha_jenga_fwd_args); -void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); +void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits, + fmha_jenga_fwd_args, + const ck_tile::stream_config&); // VSA uses the same traits structure as Jenga; aliases for clarity template CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread], float* __restrict__ smem_reduce) @@ -121,9 +120,9 @@ struct SpargeBlockMapPipeline const index_t k_idx = lane_id % KThreads; const index_t m_idx = lane_id / KThreads; - // B2.v3 column-stride pad: stride k_idx by (KPerThread+1)=9 instead of 8, - // changing per-lane bank from (k_idx*8+k)%32 to (k_idx*9+k)%32. For k=0, - // lanes (k_idx={0,4,8,12}) now hit banks {0,4,8,12} instead of all 0. + // Column-stride pad: stride k_idx by (KPerThread+1)=9 instead of 8, changing + // per-lane bank from (k_idx*8+k)%32 to (k_idx*9+k)%32. For k=0, lanes + // (k_idx={0,4,8,12}) hit banks {0,4,8,12} instead of all 0. if(m_idx == 0) for(index_t k = 0; k < KPerThread; ++k) smem_reduce[warp_id * kPerWarpFloats + k_idx * kColPaddedStride + k] = col_acc[k]; @@ -268,7 +267,7 @@ struct SpargeBlockMapPipeline { const index_t tid = static_cast(threadIdx.x); - // R20: K-loop no longer reduces, only Phase 1 uses smem_float0. + // K-loop no longer reduces; only Q-stats uses smem_float0. // smem_float1 slab is allocated for layout compat but unused. auto* smem_float0 = reinterpret_cast(smem_ptr); auto* smem_scores = @@ -282,7 +281,7 @@ struct SpargeBlockMapPipeline const float inv_bs_q = (bs_q > 0) ? (1.0f / static_cast(bs_q)) : 0.f; // ================================================================== - // Phase 1: Q Block Statistics + // Q Block Statistics // ================================================================== auto q_tile = load_tile(q_window_in); @@ -294,9 +293,8 @@ struct SpargeBlockMapPipeline row_reduce_sq_norm(q_data, psq, bs_q); // 1b. Column sum -> mean - // Track F (re-apply R8 b2): drop trailing sync. Next reduce reuses same slab - // (smem_float0) and has its own leading __syncthreads() before reading. - // pooled_q_mean is register-only between reduces. + // Drop trailing sync: next reduce reuses same slab (smem_float0) with its own + // leading __syncthreads() before reading. pooled_q_mean is register-only between reduces. float pooled_q_mean[KPerThread]; column_reduce_thread_and_warp(q_data, pooled_q_mean); column_reduce_cross_warp(pooled_q_mean, smem_float0); @@ -304,9 +302,9 @@ struct SpargeBlockMapPipeline pooled_q_mean[k] *= inv_bs_q; // 1c. Normalised sum_hat - // Track F (re-apply R8 b2): drop trailing sync. Next cross-warp reduce in - // K-loop iter 0 writes slab_a=smem_float0 (kb=0 even). Although same slab, - // its leading __syncthreads() covers the WAR. sum_hat register-only here. + // Drop trailing sync: next cross-warp reduce in K-loop iter 0 writes + // slab_a=smem_float0 (kb=0 even); its leading __syncthreads() covers the WAR. + // sum_hat is register-only here. float sum_hat[KPerThread]; column_reduce_normalised(q_data, psq, sum_hat, bs_q); column_reduce_cross_warp(sum_hat, smem_float0); @@ -342,15 +340,15 @@ struct SpargeBlockMapPipeline } // ================================================================== - // Phase 2: K Block Loop + // K Block Loop // ================================================================== for(index_t i = tid; i < N_k; i += kBlockSize) smem_bmap[i] = 0; __syncthreads(); - // R20: K-stats precomputed by Kernel A. Each thread loads its own - // KPerThread-slice of pooled_k_mean from DRAM workspace; sim_k is a single - // byte. No K-tile load, no cross-warp reduce in the K-loop. + // K-stats precomputed by SpargeKStatsKernel. Each thread loads its own + // KPerThread-slice of pooled_k_mean from DRAM workspace; sim_k is a single byte. + // No K-tile load, no cross-warp reduce in the K-loop. const index_t lane_id_kb = tid % WarpSize; const index_t k_idx_kb = lane_id_kb % KThreads; @@ -372,10 +370,10 @@ struct SpargeBlockMapPipeline { // INVARIANT (mirrors SpargeAttn ref utils.py:175-180): // ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1) - // AND have score = -inf so Phase 3 selection (topk / cdf) does NOT + // AND have score = -inf so the selection step (topk / cdf) does NOT // pick them again (would double-count toward topk budget). - // Both writes MUST stay together. Any Phase 3 selection rewrite - // (e.g. iterative argmax → bitonic sort) must keep the -inf write. + // Both writes MUST stay together. Any selection rewrite + // (e.g. iterative argmax -> bitonic sort) must keep the -inf write. if(!sim_k) { smem_bmap[kb] = 1; @@ -387,10 +385,10 @@ struct SpargeBlockMapPipeline } } } - __syncthreads(); // guard Phase 3's reads of smem_bmap / smem_scores + __syncthreads(); // guard selection's reads of smem_bmap / smem_scores // ================================================================== - // Phase 3: Softmax + Selection + // Softmax + Selection // ================================================================== // max