mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
cleanup(sparse_attn): R-tag rename + clang-format sweep
Strip internal R-tag / phase labels (R20, R21A/B, Round 8/13f, Track F, B2.v3, Phase 1/2/3) from comments — replace with descriptive names so future readers don't need the change-log. Reflow long signature in fmha_fwd_trek.hpp. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
@@ -283,7 +283,9 @@ float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
|
||||
template <typename Traits_>
|
||||
void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config&, fmha_jenga_fwd_args);
|
||||
|
||||
void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&);
|
||||
void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits,
|
||||
fmha_jenga_fwd_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
// VSA uses the same traits structure as Jenga; aliases for clarity
|
||||
template <ck_tile::index_t HDim_,
|
||||
|
||||
@@ -31,20 +31,20 @@ struct SpargeBlockMapPipeline
|
||||
static constexpr index_t kBlockPerCu = 1;
|
||||
static constexpr index_t kMaxKBlocks = 1024;
|
||||
|
||||
// LDS layout (non-overlapping, all used simultaneously in Phase 2):
|
||||
// LDS layout (non-overlapping, all used simultaneously in K-block loop):
|
||||
// [0 .. kReduceBytes) cross-warp reduction scratch slab 0
|
||||
// [kReduceBytes .. 2*kReduceBytes) cross-warp reduction scratch slab 1
|
||||
// (Round 8 b1: ping-pong for K-loop double buffer)
|
||||
// (ping-pong for K-loop double buffer)
|
||||
// [kScoreOffset ..) scores[N_k]
|
||||
// [kBmapOffset ..) block_map[N_k]
|
||||
// [kSmallOffset ..) Phase 3 argmax scratch (2*NumWarps floats)
|
||||
// B2.v3 column-stride pad: replace k_idx*KPerThread with k_idx*(KPerThread+1)
|
||||
// to break the 4-way intra-warp bank conflict. New per-warp slab size:
|
||||
// KThreads * (KPerThread + 1) floats.
|
||||
// [kSmallOffset ..) softmax/selection argmax scratch (2*NumWarps
|
||||
// floats)
|
||||
// Column-stride pad: k_idx*(KPerThread+1) instead of k_idx*KPerThread to break
|
||||
// the 4-way intra-warp bank conflict. Per-warp slab size: KThreads * (KPerThread + 1) floats.
|
||||
static constexpr index_t kColPaddedStride = KPerThread + 1;
|
||||
static constexpr index_t kPerWarpFloats = KThreads * kColPaddedStride;
|
||||
static constexpr index_t kReduceBytes = NumWarps * kPerWarpFloats * sizeof(float);
|
||||
static constexpr index_t kReduceTotalBytes = 2 * kReduceBytes; // Round 8 b1: 2 slabs
|
||||
static constexpr index_t kReduceTotalBytes = 2 * kReduceBytes; // 2 slabs (K-loop ping-pong)
|
||||
static constexpr index_t kScoreOffset = kReduceTotalBytes;
|
||||
static constexpr index_t kBmapOffset = kScoreOffset + kMaxKBlocks * sizeof(float);
|
||||
static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t);
|
||||
@@ -106,11 +106,10 @@ struct SpargeBlockMapPipeline
|
||||
}
|
||||
|
||||
// Cross-warp LDS reduction for column sums.
|
||||
// Round 13f: templated TrailingSync flag. When false, the trailing __syncthreads()
|
||||
// is dropped — only safe when the next access targets a *different* slab and the
|
||||
// intervening work does not read smem_reduce. Used at the slab_b call in Phase 2
|
||||
// K-loop, where the next iter's first cross-warp reduce writes to slab_a (different
|
||||
// address) and is preceded by its own leading sync.
|
||||
// Templated TrailingSync flag: when false, the trailing __syncthreads() is dropped —
|
||||
// only safe when the next access targets a *different* slab and the intervening work
|
||||
// does not read smem_reduce. Used at the slab_b call in the K-loop, where the next
|
||||
// iter's first cross-warp reduce writes to slab_a and is preceded by its own leading sync.
|
||||
template <bool TrailingSync = true>
|
||||
CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread],
|
||||
float* __restrict__ smem_reduce)
|
||||
@@ -121,9 +120,9 @@ struct SpargeBlockMapPipeline
|
||||
const index_t k_idx = lane_id % KThreads;
|
||||
const index_t m_idx = lane_id / KThreads;
|
||||
|
||||
// B2.v3 column-stride pad: stride k_idx by (KPerThread+1)=9 instead of 8,
|
||||
// changing per-lane bank from (k_idx*8+k)%32 to (k_idx*9+k)%32. For k=0,
|
||||
// lanes (k_idx={0,4,8,12}) now hit banks {0,4,8,12} instead of all 0.
|
||||
// Column-stride pad: stride k_idx by (KPerThread+1)=9 instead of 8, changing
|
||||
// per-lane bank from (k_idx*8+k)%32 to (k_idx*9+k)%32. For k=0, lanes
|
||||
// (k_idx={0,4,8,12}) hit banks {0,4,8,12} instead of all 0.
|
||||
if(m_idx == 0)
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
smem_reduce[warp_id * kPerWarpFloats + k_idx * kColPaddedStride + k] = col_acc[k];
|
||||
@@ -268,7 +267,7 @@ struct SpargeBlockMapPipeline
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
|
||||
// R20: K-loop no longer reduces, only Phase 1 uses smem_float0.
|
||||
// K-loop no longer reduces; only Q-stats uses smem_float0.
|
||||
// smem_float1 slab is allocated for layout compat but unused.
|
||||
auto* smem_float0 = reinterpret_cast<float*>(smem_ptr);
|
||||
auto* smem_scores =
|
||||
@@ -282,7 +281,7 @@ struct SpargeBlockMapPipeline
|
||||
const float inv_bs_q = (bs_q > 0) ? (1.0f / static_cast<float>(bs_q)) : 0.f;
|
||||
|
||||
// ==================================================================
|
||||
// Phase 1: Q Block Statistics
|
||||
// Q Block Statistics
|
||||
// ==================================================================
|
||||
auto q_tile = load_tile(q_window_in);
|
||||
|
||||
@@ -294,9 +293,8 @@ struct SpargeBlockMapPipeline
|
||||
row_reduce_sq_norm<MPerThread>(q_data, psq, bs_q);
|
||||
|
||||
// 1b. Column sum -> mean
|
||||
// Track F (re-apply R8 b2): drop trailing sync. Next reduce reuses same slab
|
||||
// (smem_float0) and has its own leading __syncthreads() before reading.
|
||||
// pooled_q_mean is register-only between reduces.
|
||||
// Drop trailing sync: next reduce reuses same slab (smem_float0) with its own
|
||||
// leading __syncthreads() before reading. pooled_q_mean is register-only between reduces.
|
||||
float pooled_q_mean[KPerThread];
|
||||
column_reduce_thread_and_warp<MPerThread>(q_data, pooled_q_mean);
|
||||
column_reduce_cross_warp<false>(pooled_q_mean, smem_float0);
|
||||
@@ -304,9 +302,9 @@ struct SpargeBlockMapPipeline
|
||||
pooled_q_mean[k] *= inv_bs_q;
|
||||
|
||||
// 1c. Normalised sum_hat
|
||||
// Track F (re-apply R8 b2): drop trailing sync. Next cross-warp reduce in
|
||||
// K-loop iter 0 writes slab_a=smem_float0 (kb=0 even). Although same slab,
|
||||
// its leading __syncthreads() covers the WAR. sum_hat register-only here.
|
||||
// Drop trailing sync: next cross-warp reduce in K-loop iter 0 writes
|
||||
// slab_a=smem_float0 (kb=0 even); its leading __syncthreads() covers the WAR.
|
||||
// sum_hat is register-only here.
|
||||
float sum_hat[KPerThread];
|
||||
column_reduce_normalised<MPerThread>(q_data, psq, sum_hat, bs_q);
|
||||
column_reduce_cross_warp<false>(sum_hat, smem_float0);
|
||||
@@ -342,15 +340,15 @@ struct SpargeBlockMapPipeline
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// Phase 2: K Block Loop
|
||||
// K Block Loop
|
||||
// ==================================================================
|
||||
for(index_t i = tid; i < N_k; i += kBlockSize)
|
||||
smem_bmap[i] = 0;
|
||||
__syncthreads();
|
||||
|
||||
// R20: K-stats precomputed by Kernel A. Each thread loads its own
|
||||
// KPerThread-slice of pooled_k_mean from DRAM workspace; sim_k is a single
|
||||
// byte. No K-tile load, no cross-warp reduce in the K-loop.
|
||||
// K-stats precomputed by SpargeKStatsKernel. Each thread loads its own
|
||||
// KPerThread-slice of pooled_k_mean from DRAM workspace; sim_k is a single byte.
|
||||
// No K-tile load, no cross-warp reduce in the K-loop.
|
||||
const index_t lane_id_kb = tid % WarpSize;
|
||||
const index_t k_idx_kb = lane_id_kb % KThreads;
|
||||
|
||||
@@ -372,10 +370,10 @@ struct SpargeBlockMapPipeline
|
||||
{
|
||||
// INVARIANT (mirrors SpargeAttn ref utils.py:175-180):
|
||||
// ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1)
|
||||
// AND have score = -inf so Phase 3 selection (topk / cdf) does NOT
|
||||
// AND have score = -inf so the selection step (topk / cdf) does NOT
|
||||
// pick them again (would double-count toward topk budget).
|
||||
// Both writes MUST stay together. Any Phase 3 selection rewrite
|
||||
// (e.g. iterative argmax → bitonic sort) must keep the -inf write.
|
||||
// Both writes MUST stay together. Any selection rewrite
|
||||
// (e.g. iterative argmax -> bitonic sort) must keep the -inf write.
|
||||
if(!sim_k)
|
||||
{
|
||||
smem_bmap[kb] = 1;
|
||||
@@ -387,10 +385,10 @@ struct SpargeBlockMapPipeline
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads(); // guard Phase 3's reads of smem_bmap / smem_scores
|
||||
__syncthreads(); // guard selection's reads of smem_bmap / smem_scores
|
||||
|
||||
// ==================================================================
|
||||
// Phase 3: Softmax + Selection
|
||||
// Softmax + Selection
|
||||
// ==================================================================
|
||||
|
||||
// max
|
||||
|
||||
Reference in New Issue
Block a user