cleanup(sparse_attn): R-tag rename + clang-format sweep

Strip internal R-tag / phase labels (R20, R21A/B, Round 8/13f, Track F, B2.v3,
Phase 1/2/3) from comments — replace with descriptive names so future readers
don't need the change-log. Reflow long signature in fmha_fwd_trek.hpp.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
Gino Lu
2026-05-17 02:35:07 -04:00
parent 7103eacc99
commit 879d50836e
2 changed files with 33 additions and 33 deletions

View File

@@ -283,7 +283,9 @@ float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
template <typename Traits_>
void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits,
fmha_jenga_fwd_args,
const ck_tile::stream_config&);
// VSA uses the same traits structure as Jenga; aliases for clarity
template <ck_tile::index_t HDim_,

View File

@@ -31,20 +31,20 @@ struct SpargeBlockMapPipeline
static constexpr index_t kBlockPerCu = 1;
static constexpr index_t kMaxKBlocks = 1024;
// LDS layout (non-overlapping, all used simultaneously in Phase 2):
// LDS layout (non-overlapping, all used simultaneously in K-block loop):
// [0 .. kReduceBytes) cross-warp reduction scratch slab 0
// [kReduceBytes .. 2*kReduceBytes) cross-warp reduction scratch slab 1
// (Round 8 b1: ping-pong for K-loop double buffer)
// (ping-pong for K-loop double buffer)
// [kScoreOffset ..) scores[N_k]
// [kBmapOffset ..) block_map[N_k]
// [kSmallOffset ..) Phase 3 argmax scratch (2*NumWarps floats)
// B2.v3 column-stride pad: replace k_idx*KPerThread with k_idx*(KPerThread+1)
// to break the 4-way intra-warp bank conflict. New per-warp slab size:
// KThreads * (KPerThread + 1) floats.
// [kSmallOffset ..) softmax/selection argmax scratch (2*NumWarps
// floats)
// Column-stride pad: k_idx*(KPerThread+1) instead of k_idx*KPerThread to break
// the 4-way intra-warp bank conflict. Per-warp slab size: KThreads * (KPerThread + 1) floats.
static constexpr index_t kColPaddedStride = KPerThread + 1;
static constexpr index_t kPerWarpFloats = KThreads * kColPaddedStride;
static constexpr index_t kReduceBytes = NumWarps * kPerWarpFloats * sizeof(float);
static constexpr index_t kReduceTotalBytes = 2 * kReduceBytes; // Round 8 b1: 2 slabs
static constexpr index_t kReduceTotalBytes = 2 * kReduceBytes; // 2 slabs (K-loop ping-pong)
static constexpr index_t kScoreOffset = kReduceTotalBytes;
static constexpr index_t kBmapOffset = kScoreOffset + kMaxKBlocks * sizeof(float);
static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t);
@@ -106,11 +106,10 @@ struct SpargeBlockMapPipeline
}
// Cross-warp LDS reduction for column sums.
// Round 13f: templated TrailingSync flag. When false, the trailing __syncthreads()
// is dropped — only safe when the next access targets a *different* slab and the
// intervening work does not read smem_reduce. Used at the slab_b call in Phase 2
// K-loop, where the next iter's first cross-warp reduce writes to slab_a (different
// address) and is preceded by its own leading sync.
// Templated TrailingSync flag: when false, the trailing __syncthreads() is dropped —
// only safe when the next access targets a *different* slab and the intervening work
// does not read smem_reduce. Used at the slab_b call in the K-loop, where the next
// iter's first cross-warp reduce writes to slab_a and is preceded by its own leading sync.
template <bool TrailingSync = true>
CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread],
float* __restrict__ smem_reduce)
@@ -121,9 +120,9 @@ struct SpargeBlockMapPipeline
const index_t k_idx = lane_id % KThreads;
const index_t m_idx = lane_id / KThreads;
// B2.v3 column-stride pad: stride k_idx by (KPerThread+1)=9 instead of 8,
// changing per-lane bank from (k_idx*8+k)%32 to (k_idx*9+k)%32. For k=0,
// lanes (k_idx={0,4,8,12}) now hit banks {0,4,8,12} instead of all 0.
// Column-stride pad: stride k_idx by (KPerThread+1)=9 instead of 8, changing
// per-lane bank from (k_idx*8+k)%32 to (k_idx*9+k)%32. For k=0, lanes
// (k_idx={0,4,8,12}) hit banks {0,4,8,12} instead of all 0.
if(m_idx == 0)
for(index_t k = 0; k < KPerThread; ++k)
smem_reduce[warp_id * kPerWarpFloats + k_idx * kColPaddedStride + k] = col_acc[k];
@@ -268,7 +267,7 @@ struct SpargeBlockMapPipeline
{
const index_t tid = static_cast<index_t>(threadIdx.x);
// R20: K-loop no longer reduces, only Phase 1 uses smem_float0.
// K-loop no longer reduces; only Q-stats uses smem_float0.
// smem_float1 slab is allocated for layout compat but unused.
auto* smem_float0 = reinterpret_cast<float*>(smem_ptr);
auto* smem_scores =
@@ -282,7 +281,7 @@ struct SpargeBlockMapPipeline
const float inv_bs_q = (bs_q > 0) ? (1.0f / static_cast<float>(bs_q)) : 0.f;
// ==================================================================
// Phase 1: Q Block Statistics
// Q Block Statistics
// ==================================================================
auto q_tile = load_tile(q_window_in);
@@ -294,9 +293,8 @@ struct SpargeBlockMapPipeline
row_reduce_sq_norm<MPerThread>(q_data, psq, bs_q);
// 1b. Column sum -> mean
// Track F (re-apply R8 b2): drop trailing sync. Next reduce reuses same slab
// (smem_float0) and has its own leading __syncthreads() before reading.
// pooled_q_mean is register-only between reduces.
// Drop trailing sync: next reduce reuses same slab (smem_float0) with its own
// leading __syncthreads() before reading. pooled_q_mean is register-only between reduces.
float pooled_q_mean[KPerThread];
column_reduce_thread_and_warp<MPerThread>(q_data, pooled_q_mean);
column_reduce_cross_warp<false>(pooled_q_mean, smem_float0);
@@ -304,9 +302,9 @@ struct SpargeBlockMapPipeline
pooled_q_mean[k] *= inv_bs_q;
// 1c. Normalised sum_hat
// Track F (re-apply R8 b2): drop trailing sync. Next cross-warp reduce in
// K-loop iter 0 writes slab_a=smem_float0 (kb=0 even). Although same slab,
// its leading __syncthreads() covers the WAR. sum_hat register-only here.
// Drop trailing sync: next cross-warp reduce in K-loop iter 0 writes
// slab_a=smem_float0 (kb=0 even); its leading __syncthreads() covers the WAR.
// sum_hat is register-only here.
float sum_hat[KPerThread];
column_reduce_normalised<MPerThread>(q_data, psq, sum_hat, bs_q);
column_reduce_cross_warp<false>(sum_hat, smem_float0);
@@ -342,15 +340,15 @@ struct SpargeBlockMapPipeline
}
// ==================================================================
// Phase 2: K Block Loop
// K Block Loop
// ==================================================================
for(index_t i = tid; i < N_k; i += kBlockSize)
smem_bmap[i] = 0;
__syncthreads();
// R20: K-stats precomputed by Kernel A. Each thread loads its own
// KPerThread-slice of pooled_k_mean from DRAM workspace; sim_k is a single
// byte. No K-tile load, no cross-warp reduce in the K-loop.
// K-stats precomputed by SpargeKStatsKernel. Each thread loads its own
// KPerThread-slice of pooled_k_mean from DRAM workspace; sim_k is a single byte.
// No K-tile load, no cross-warp reduce in the K-loop.
const index_t lane_id_kb = tid % WarpSize;
const index_t k_idx_kb = lane_id_kb % KThreads;
@@ -372,10 +370,10 @@ struct SpargeBlockMapPipeline
{
// INVARIANT (mirrors SpargeAttn ref utils.py:175-180):
// ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1)
// AND have score = -inf so Phase 3 selection (topk / cdf) does NOT
// AND have score = -inf so the selection step (topk / cdf) does NOT
// pick them again (would double-count toward topk budget).
// Both writes MUST stay together. Any Phase 3 selection rewrite
// (e.g. iterative argmax bitonic sort) must keep the -inf write.
// Both writes MUST stay together. Any selection rewrite
// (e.g. iterative argmax -> bitonic sort) must keep the -inf write.
if(!sim_k)
{
smem_bmap[kb] = 1;
@@ -387,10 +385,10 @@ struct SpargeBlockMapPipeline
}
}
}
__syncthreads(); // guard Phase 3's reads of smem_bmap / smem_scores
__syncthreads(); // guard selection's reads of smem_bmap / smem_scores
// ==================================================================
// Phase 3: Softmax + Selection
// Softmax + Selection
// ==================================================================
// max