Mirror of https://github.com/ROCm/composable_kernel.git
sparse_attn: split KStats kernel, add README + perf charts
- Split SpargeKStatsKernel/Pipeline out of BlockMap (Kernel A produces
per-block K stats workspace consumed by Kernel B), removing redundant
K-stat recomputation across Q-blocks.
- Add example/ck_tile/50_sparse_attn/README.md (status vs upstream pinned
to ae5b629, unported items, usage, references).
- Add example/ck_tile/50_sparse_attn/docs/{speedup_vs_sparsity,kernel_breakdown}.png
+ reusable plot_sparge_perf.py (b=2 h=32 s=16384 d=128 fp16 perf snapshot).
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
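For orientation, a rough host-side sketch of the new two-pass launch order. The `launch` helper and the pipeline aliases are placeholders (the real wiring lives in the example/ck_tile/50_sparse_attn harness); only the MakeKargs argument shapes and GridSize calls below are taken from the diff:

    // Sketch only: Kernel A fills the K-stat workspace once per (b, hk, kb);
    // Kernel B then reads it for every Q-block instead of recomputing.
    using KStats   = ck_tile::SpargeKStatsKernel<KStatsPipeline>;     // Kernel A
    using BlockMap = ck_tile::SpargeBlockMapKernel<BlockMapPipeline>; // Kernel B

    auto a_kargs = KStats::MakeKargs(k_ptr, seqlen_k, hdim_q, nhead_k,
                                     stride_k, nhead_stride_k, batch_stride_k,
                                     simthreshd1, pooled_k_ws, sim_k_ws);
    launch(KStats{}, KStats::GridSize(batch, nhead_k, seqlen_k),
           KStats::BlockSize(), a_kargs);

    auto b_kargs = BlockMap::MakeKargs(q_ptr, /* ...unchanged args..., */ scale,
                                       block_map, lut, valid_block_num,
                                       pooled_k_ws, sim_k_ws);
    launch(BlockMap{}, BlockMap::GridSize(batch, nhead_q, seqlen_q),
           BlockMap::BlockSize(), b_kargs);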
@@ -52,7 +52,20 @@ struct SpargeBlockMapKernel
         void* lut_ptr;
         void* valid_block_num_ptr;
 
+        // R20 K-stat workspace from Kernel A
+        const void* pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] fp32
+        const void* sim_k_ws_ptr;    // [batch, nhead_k, N_k] uint8
+
         index_t N_k;
+
+        // R21A Phase 4: optional per-head topk (size = nhead_q floats).
+        // nullptr => use scalar `topk` for all heads.
+        const float* topk_per_head;
+
+        // R21B: optional per-head cdfthreshd (size = nhead_q floats).
+        // nullptr => use scalar `cdfthreshd` for all heads.
+        // Only consulted on topk<=0 path; bench currently always uses topk path.
+        const float* cdfthreshd_per_head;
     };
 
     CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr,
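The shape comments above pin down the R20 workspace footprint. A standalone sanity check for the bench shape from the commit message (b=2 h=32 s=16384 d=128); kN0=64 is an assumed tile size, substitute the pipeline's actual value:

    #include <cstdio>

    int main()
    {
        const long batch = 2, nhead_k = 32, seqlen_k = 16384, D = 128, kN0 = 64;
        const long N_k = (seqlen_k + kN0 - 1) / kN0; // integer_divide_ceil
        const long pooled_bytes = batch * nhead_k * N_k * D * 4L; // fp32
        const long sim_bytes    = batch * nhead_k * N_k;          // uint8
        std::printf("N_k=%ld pooled=%ld B (%.1f MiB) sim=%ld B\n",
                    N_k, pooled_bytes, pooled_bytes / (1024.0 * 1024.0), sim_bytes);
        // -> N_k=256, pooled = 8.0 MiB, sim = 16384 B for this shape.
    }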
@@ -74,7 +87,11 @@ struct SpargeBlockMapKernel
                                                  float scale,
                                                  void* block_map_ptr,
                                                  void* lut_ptr,
-                                                 void* valid_block_num_ptr)
+                                                 void* valid_block_num_ptr,
+                                                 const void* pooled_k_ws_ptr,
+                                                 const void* sim_k_ws_ptr,
+                                                 const float* topk_per_head = nullptr,
+                                                 const float* cdfthreshd_per_head = nullptr)
     {
         const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
         return Kargs{q_ptr,
@@ -97,7 +114,11 @@ struct SpargeBlockMapKernel
                      block_map_ptr,
                      lut_ptr,
                      valid_block_num_ptr,
-                     N_k};
+                     pooled_k_ws_ptr,
+                     sim_k_ws_ptr,
+                     N_k,
+                     topk_per_head,
+                     cdfthreshd_per_head};
     }
 
     CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q)
@@ -174,6 +195,21 @@ struct SpargeBlockMapKernel
         // Shared memory
         __shared__ char smem[Pipeline::GetSmemSize()];
 
+        // R20 K-stat workspace: pre-offset for this (b, hk).
+        const index_t nhead_k   = kargs.nhead_q / kargs.nhead_ratio_qk;
+        const index_t khead_off = (b * nhead_k + hk) * N_k;
+        const auto* pooled_k_ws =
+            reinterpret_cast<const float*>(kargs.pooled_k_ws_ptr) + khead_off * D;
+        const auto* sim_k_ws =
+            reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
+
+        // R21A Phase 4: per-head topk if provided, else scalar broadcast.
+        const float topk_eff =
+            (kargs.topk_per_head != nullptr) ? kargs.topk_per_head[hq] : kargs.topk;
+        // R21B: per-head cdfthreshd if provided, else scalar broadcast.
+        const float cdfthreshd_eff =
+            (kargs.cdfthreshd_per_head != nullptr) ? kargs.cdfthreshd_per_head[hq] : kargs.cdfthreshd;
+
         Pipeline{}(q_window,
                    k_window,
                    kargs.seqlen_q,
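The per-head overrides are plain [nhead_q] float buffers owned by the host; passing nullptr keeps the scalar broadcast. A minimal sketch (values invented):

    #include <vector>

    int main()
    {
        const int nhead_q = 32;
        std::vector<float> topk_per_head(nhead_q, 0.5f);
        topk_per_head[0] = 0.25f; // e.g. keep head 0 sparser than the rest

        // MakeKargs(..., topk_per_head.data(), /*cdfthreshd_per_head=*/nullptr)
        // then yields topk_eff = topk_per_head[hq] per head in the kernel.
    }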
@@ -182,12 +218,14 @@ struct SpargeBlockMapKernel
                    N_k,
                    kargs.nhead_ratio_qk,
                    kargs.simthreshd1,
-                   kargs.cdfthreshd,
-                   kargs.topk,
+                   cdfthreshd_eff,
+                   topk_eff,
                    kargs.scale,
                    bmap_ptr,
                    lut_out,
                    valid_out,
+                   pooled_k_ws,
+                   sim_k_ws,
                    static_cast<void*>(smem));
     }
 };
include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp (new file, 136 lines)
@@ -0,0 +1,136 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once

#include "ck_tile/core.hpp"
#include <type_traits>

namespace ck_tile {

// Kernel A wrapper: grid (N_k, nhead_k, batch). Each work-group precomputes
// K-block stats (pooled_k_mean[D], sim_k) for one (b, hk, kb) into a workspace
// that Kernel B (block_map) reads instead of recomputing per Q-block.
template <typename Pipeline_>
struct SpargeKStatsKernel
{
    using Pipeline = remove_cvref_t<Pipeline_>;

    static constexpr index_t kBlockSize  = Pipeline::kBlockSize;
    static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;

    using QDataType = typename Pipeline::QDataType;
    using KDataType = typename Pipeline::KDataType;

    static constexpr index_t kN0 = Pipeline::kN0;
    static constexpr index_t D   = Pipeline::D;

    static constexpr index_t kAlignment = 16 / sizeof(KDataType);

    struct Kargs
    {
        const void* k_ptr;

        index_t seqlen_k;
        index_t hdim_q;
        index_t nhead_k;

        index_t stride_k;
        index_t nhead_stride_k;
        index_t batch_stride_k;

        float simthreshd1;

        void* pooled_k_ptr; // [batch, nhead_k, N_k, D] fp32
        void* sim_k_ptr;    // [batch, nhead_k, N_k] uint8

        index_t N_k;

        // R21A Phase 4 + R21B fix: optional per-head simthreshd1.
        // Buffer is sized [nhead_q] floats to match SpargeAttn upstream contract
        // (utils.py:324, Headnum=q.size(1)). Kernel only indexes the first
        // nhead_k entries via [hk]. nullptr => use scalar `simthreshd1`.
        const float* simthreshd1_per_head;
    };

    CK_TILE_HOST static constexpr auto MakeKargs(const void* k_ptr,
                                                 index_t seqlen_k,
                                                 index_t hdim_q,
                                                 index_t nhead_k,
                                                 index_t stride_k,
                                                 index_t nhead_stride_k,
                                                 index_t batch_stride_k,
                                                 float simthreshd1,
                                                 void* pooled_k_ptr,
                                                 void* sim_k_ptr,
                                                 const float* simthreshd1_per_head = nullptr)
    {
        const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
        return Kargs{k_ptr,
                     seqlen_k,
                     hdim_q,
                     nhead_k,
                     stride_k,
                     nhead_stride_k,
                     batch_stride_k,
                     simthreshd1,
                     pooled_k_ptr,
                     sim_k_ptr,
                     N_k,
                     simthreshd1_per_head};
    }

    CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_k, index_t seqlen_k)
    {
        const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
        return dim3(N_k, nhead_k, batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        const index_t kb = static_cast<index_t>(blockIdx.x);
        const index_t hk = static_cast<index_t>(blockIdx.y);
        const index_t b  = static_cast<index_t>(blockIdx.z);

        const auto* k_base = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
                             b * kargs.batch_stride_k + hk * kargs.nhead_stride_k +
                             kb * kN0 * kargs.stride_k;

        const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            k_base,
            make_tuple(kargs.seqlen_k - kb * kN0, D),
            make_tuple(kargs.stride_k, 1),
            number<kAlignment>{},
            number<1>{});
        const auto k_dram = pad_tensor_view(
            k_dram_naive, make_tuple(number<kN0>{}, number<D>{}), sequence<true, false>{});

        auto k_window = make_tile_window(k_dram,
                                         make_tuple(number<kN0>{}, number<D>{}),
                                         {0, 0},
                                         Pipeline::MakeKBlockDistribution());

        const index_t N_k       = kargs.N_k;
        const index_t khead_off = (b * kargs.nhead_k + hk) * N_k;
        auto* pooled_k_out = reinterpret_cast<float*>(kargs.pooled_k_ptr) + (khead_off + kb) * D;
        auto* sim_k_out    = reinterpret_cast<uint8_t*>(kargs.sim_k_ptr) + (khead_off + kb);

        __shared__ char smem[Pipeline::GetSmemSize()];

        // R21A Phase 4: per-head simthreshd1 if provided, else scalar broadcast.
        const float simthreshd1_eff = (kargs.simthreshd1_per_head != nullptr)
                                          ? kargs.simthreshd1_per_head[hk]
                                          : kargs.simthreshd1;

        Pipeline{}(k_window,
                   kargs.seqlen_k,
                   kb,
                   simthreshd1_eff,
                   pooled_k_out,
                   sim_k_out,
                   static_cast<void*>(smem));
    }
};

} // namespace ck_tile
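Grid shape check for Kernel A under the same assumed kN0=64, matching GridSize's dim3(N_k, nhead_k, batch):

    #include <cstdio>

    int main()
    {
        const int seqlen_k = 16384, kN0 = 64, nhead_k = 32, batch = 2;
        const int N_k = (seqlen_k + kN0 - 1) / kN0;
        std::printf("grid = (%d, %d, %d) = %d work-groups\n",
                    N_k, nhead_k, batch, N_k * nhead_k * batch); // (256, 32, 2)
    }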
@@ -32,14 +32,22 @@ struct SpargeBlockMapPipeline
     static constexpr index_t kMaxKBlocks = 1024;
 
     // LDS layout (non-overlapping, all used simultaneously in Phase 2):
-    //   [0 .. kReduceBytes)   cross-warp reduction scratch
-    //   [kScoreOffset ..)     scores[N_k]
-    //   [kBmapOffset ..)      block_map[N_k]
-    //   [kSmallOffset ..)     Phase 3 argmax scratch (2*NumWarps floats)
-    static constexpr index_t kReduceBytes = NumWarps * D * sizeof(float);
-    static constexpr index_t kScoreOffset = kReduceBytes;
-    static constexpr index_t kBmapOffset  = kScoreOffset + kMaxKBlocks * sizeof(float);
-    static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t);
+    //   [0 .. kReduceBytes)              cross-warp reduction scratch slab 0
+    //   [kReduceBytes .. 2*kReduceBytes) cross-warp reduction scratch slab 1
+    //                                    (Round 8 b1: ping-pong for K-loop double buffer)
+    //   [kScoreOffset ..)                scores[N_k]
+    //   [kBmapOffset ..)                 block_map[N_k]
+    //   [kSmallOffset ..)                Phase 3 argmax scratch (2*NumWarps floats)
+    // B2.v3 column-stride pad: replace k_idx*KPerThread with k_idx*(KPerThread+1)
+    // to break the 4-way intra-warp bank conflict. New per-warp slab size:
+    // KThreads * (KPerThread + 1) floats.
+    static constexpr index_t kColPaddedStride  = KPerThread + 1;
+    static constexpr index_t kPerWarpFloats    = KThreads * kColPaddedStride;
+    static constexpr index_t kReduceBytes      = NumWarps * kPerWarpFloats * sizeof(float);
+    static constexpr index_t kReduceTotalBytes = 2 * kReduceBytes; // Round 8 b1: 2 slabs
+    static constexpr index_t kScoreOffset = kReduceTotalBytes;
+    static constexpr index_t kBmapOffset  = kScoreOffset + kMaxKBlocks * sizeof(float);
+    static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t);
 
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
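To make the new layout concrete, a standalone recomputation of the offsets. KPerThread=8 and KThreads=16 are implied by this file's bank-conflict comments; NumWarps=4 is an illustrative assumption:

    #include <cstdio>

    int main()
    {
        const int NumWarps = 4, KThreads = 16, KPerThread = 8, kMaxKBlocks = 1024;
        const int kColPaddedStride  = KPerThread + 1;              // 9
        const int kPerWarpFloats    = KThreads * kColPaddedStride; // 144
        const int kReduceBytes      = NumWarps * kPerWarpFloats * 4;
        const int kReduceTotalBytes = 2 * kReduceBytes;            // 2 slabs
        const int kScoreOffset      = kReduceTotalBytes;
        const int kBmapOffset       = kScoreOffset + kMaxKBlocks * 4;
        const int kSmallOffset      = kBmapOffset + kMaxKBlocks;
        std::printf("reduce=%dB scores@%d bmap@%d small@%d\n",
                    kReduceTotalBytes, kScoreOffset, kBmapOffset, kSmallOffset);
        // -> reduce=4608B scores@4608 bmap@8704 small@9728 under these values.
    }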
@@ -98,6 +106,12 @@ struct SpargeBlockMapPipeline
     }
 
     // Cross-warp LDS reduction for column sums.
+    // Round 13f: templated TrailingSync flag. When false, the trailing __syncthreads()
+    // is dropped — only safe when the next access targets a *different* slab and the
+    // intervening work does not read smem_reduce. Used at the slab_b call in Phase 2
+    // K-loop, where the next iter's first cross-warp reduce writes to slab_a (different
+    // address) and is preceded by its own leading sync.
+    template <bool TrailingSync = true>
     CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread],
                                                         float* __restrict__ smem_reduce)
     {
@@ -107,17 +121,21 @@ struct SpargeBlockMapPipeline
         const index_t k_idx = lane_id % KThreads;
         const index_t m_idx = lane_id / KThreads;
 
+        // B2.v3 column-stride pad: stride k_idx by (KPerThread+1)=9 instead of 8,
+        // changing per-lane bank from (k_idx*8+k)%32 to (k_idx*9+k)%32. For k=0,
+        // lanes (k_idx={0,4,8,12}) now hit banks {0,4,8,12} instead of all 0.
         if(m_idx == 0)
             for(index_t k = 0; k < KPerThread; ++k)
-                smem_reduce[warp_id * D + k_idx * KPerThread + k] = col_acc[k];
+                smem_reduce[warp_id * kPerWarpFloats + k_idx * kColPaddedStride + k] = col_acc[k];
         __syncthreads();
 
         for(index_t k = 0; k < KPerThread; ++k)
             col_acc[k] = 0.f;
         for(index_t w = 0; w < NumWarps; ++w)
             for(index_t k = 0; k < KPerThread; ++k)
-                col_acc[k] += smem_reduce[w * D + k_idx * KPerThread + k];
-        __syncthreads();
+                col_acc[k] += smem_reduce[w * kPerWarpFloats + k_idx * kColPaddedStride + k];
+        if constexpr(TrailingSync)
+            __syncthreads();
     }
 
     // Compute ||v||^2 per row: sum along KPerThread then xor-shuffle across k_idx.
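A standalone reproduction of the bank argument (32 LDS banks, fp32 words, KThreads=16 as the comment implies): count how many of the 16 k_idx lanes land in the same bank for their k=0 word:

    #include <cstdio>

    int main()
    {
        const int strides[2] = {8, 9}; // KPerThread vs KPerThread + 1
        for(int stride : strides)
        {
            int bank_hits[32] = {};
            for(int k_idx = 0; k_idx < 16; ++k_idx)
                ++bank_hits[(k_idx * stride) % 32]; // bank of this lane's k=0 word
            int worst = 0;
            for(int b = 0; b < 32; ++b)
                if(bank_hits[b] > worst) worst = bank_hits[b];
            std::printf("stride %d -> worst %d-way conflict\n", stride, worst);
        }
        // stride 8 -> 4-way conflict; stride 9 -> conflict-free.
    }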
@@ -162,7 +180,8 @@ struct SpargeBlockMapPipeline
 
         for(index_t m = 0; m < SeqPerThread; ++m)
         {
-            float inv_norm = (row_norms[m] > 0.f) ? (1.0f / __builtin_sqrtf(row_norms[m])) : 0.f;
+            // Round 12: hardware fast rsqrt (v_rsq_f32, ~1 ULP) replaces sw sqrt+rcp.
+            float inv_norm = (row_norms[m] > 0.f) ? rsqrtf(row_norms[m]) : 0.f;
             index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx;
             if(gsq < actual_seq)
                 for(index_t k = 0; k < KPerThread; ++k)
@@ -230,9 +249,9 @@ struct SpargeBlockMapPipeline
     // ======================================================================
     template <typename QWindowType, typename KWindowType>
     CK_TILE_DEVICE void operator()(const QWindowType& q_window_in,
-                                   const KWindowType& k_window_in,
+                                   const KWindowType& /*k_window_in*/,
                                    index_t seqlen_q,
-                                   index_t seqlen_k,
+                                   index_t /*seqlen_k*/,
                                    index_t qb,
                                    index_t N_k,
                                    index_t /*nhead_ratio_qk*/,
@@ -243,11 +262,15 @@ struct SpargeBlockMapPipeline
                                    uint8_t* block_map_ptr,
                                    int32_t* lut_ptr,
                                    int32_t* valid_block_num_ptr,
+                                   const float* __restrict__ pooled_k_ws_ptr,
+                                   const uint8_t* __restrict__ sim_k_ws_ptr,
                                    void* smem_ptr) const
     {
         const index_t tid = static_cast<index_t>(threadIdx.x);
 
-        auto* smem_float = reinterpret_cast<float*>(smem_ptr);
+        // R20: K-loop no longer reduces, only Phase 1 uses smem_float0.
+        // smem_float1 slab is allocated for layout compat but unused.
+        auto* smem_float0 = reinterpret_cast<float*>(smem_ptr);
         auto* smem_scores =
             reinterpret_cast<float*>(reinterpret_cast<char*>(smem_ptr) + kScoreOffset);
         auto* smem_bmap =
@@ -271,16 +294,22 @@ struct SpargeBlockMapPipeline
         row_reduce_sq_norm<MPerThread>(q_data, psq, bs_q);
 
         // 1b. Column sum -> mean
+        // Track F (re-apply R8 b2): drop trailing sync. Next reduce reuses same slab
+        // (smem_float0) and has its own leading __syncthreads() before reading.
+        // pooled_q_mean is register-only between reduces.
         float pooled_q_mean[KPerThread];
         column_reduce_thread_and_warp<MPerThread>(q_data, pooled_q_mean);
-        column_reduce_cross_warp(pooled_q_mean, smem_float);
+        column_reduce_cross_warp<false>(pooled_q_mean, smem_float0);
         for(index_t k = 0; k < KPerThread; ++k)
             pooled_q_mean[k] *= inv_bs_q;
 
         // 1c. Normalised sum_hat
+        // Track F (re-apply R8 b2): drop trailing sync. Next cross-warp reduce in
+        // K-loop iter 0 writes slab_a=smem_float0 (kb=0 even). Although same slab,
+        // its leading __syncthreads() covers the WAR. sum_hat register-only here.
         float sum_hat[KPerThread];
         column_reduce_normalised<MPerThread>(q_data, psq, sum_hat, bs_q);
-        column_reduce_cross_warp(sum_hat, smem_float);
+        column_reduce_cross_warp<false>(sum_hat, smem_float0);
 
         // 1d. sim_q = ||sum_hat||^2 / bs_q^2
         float sh_sq = 0.f;
@@ -319,49 +348,34 @@ struct SpargeBlockMapPipeline
             smem_bmap[i] = 0;
         __syncthreads();
 
-        auto k_window = k_window_in;
+        // R20: K-stats precomputed by Kernel A. Each thread loads its own
+        // KPerThread-slice of pooled_k_mean from DRAM workspace; sim_k is a single
+        // byte. No K-tile load, no cross-warp reduce in the K-loop.
+        const index_t lane_id_kb = tid % WarpSize;
+        const index_t k_idx_kb   = lane_id_kb % KThreads;
 
         for(index_t kb = 0; kb < N_k; ++kb)
         {
-            const index_t bs_k = min(static_cast<index_t>(kN0), seqlen_k - kb * kN0);
-            const float inv_bs_k = (bs_k > 0) ? (1.0f / static_cast<float>(bs_k)) : 0.f;
-
-            auto k_tile = load_tile(k_window);
-
-            float k_data[NPerThread * KPerThread];
-            tile_to_float<NPerThread * KPerThread>(k_tile, k_data);
-
             // K mean
+            const float* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
             float pooled_k_mean[KPerThread];
-            column_reduce_thread_and_warp<NPerThread>(k_data, pooled_k_mean);
-            column_reduce_cross_warp(pooled_k_mean, smem_float);
             for(index_t k = 0; k < KPerThread; ++k)
-                pooled_k_mean[k] *= inv_bs_k;
+                pooled_k_mean[k] = p_kb[k];
 
             // dot(pooled_q_mean, pooled_k_mean)
             float dot = 0.f;
             for(index_t k = 0; k < KPerThread; ++k)
                 dot += pooled_q_mean[k] * pooled_k_mean[k];
             dot = reduce_across_k(dot);
 
-            // K L2 norms + normalised sum_hat
-            float k_psq[NPerThread];
-            row_reduce_sq_norm<NPerThread>(k_data, k_psq, bs_k);
-
-            float k_sum_hat[KPerThread];
-            column_reduce_normalised<NPerThread>(k_data, k_psq, k_sum_hat, bs_k);
-            column_reduce_cross_warp(k_sum_hat, smem_float);
-
-            // sim_k
-            float ksh_sq = 0.f;
-            for(index_t k = 0; k < KPerThread; ++k)
-                ksh_sq += k_sum_hat[k] * k_sum_hat[k];
-            ksh_sq = reduce_across_k(ksh_sq);
-            const float denom_k = static_cast<float>(bs_k) * static_cast<float>(bs_k);
-            const bool sim_k = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1);
+            const bool sim_k = (sim_k_ws_ptr[kb] != 0);
 
             if(tid == 0)
             {
+                // INVARIANT (mirrors SpargeAttn ref utils.py:175-180):
+                //   ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1)
+                //   AND have score = -inf so Phase 3 selection (topk / cdf) does NOT
+                //   pick them again (would double-count toward topk budget).
+                //   Both writes MUST stay together. Any Phase 3 selection rewrite
+                //   (e.g. iterative argmax → bitonic sort) must keep the -inf write.
                 if(!sim_k)
                 {
                     smem_bmap[kb] = 1;
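Kernel A writes pooled stats at (khead_off + kb)*D; Kernel B pre-offsets its pointer by khead_off*D and then adds kb*D per iteration, so the two agree by distributivity. A brute-force check, reusing the assumed shape from earlier (N_k=256 from kN0=64; KThreads=16/KPerThread=8):

    #include <cassert>
    #include <cstdio>

    int main()
    {
        const long nhead_k = 32, N_k = 256, D = 128, KThreads = 16, KPerThread = 8;
        for(long b = 0; b < 2; ++b)
            for(long hk = 0; hk < nhead_k; ++hk)
                for(long kb = 0; kb < N_k; ++kb)
                    for(long k_idx = 0; k_idx < KThreads; ++k_idx)
                        for(long k = 0; k < KPerThread; ++k)
                        {
                            const long khead_off = (b * nhead_k + hk) * N_k;
                            // Kernel A: pooled_k_out = base + (khead_off + kb) * D
                            const long wr = (khead_off + kb) * D + k_idx * KPerThread + k;
                            // Kernel B: ws = base + khead_off * D; p_kb = ws + kb * D + ...
                            const long rd = khead_off * D + kb * D + k_idx * KPerThread + k;
                            assert(wr == rd);
                        }
        std::printf("write/read offsets agree\n");
    }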
@@ -372,10 +386,8 @@ struct SpargeBlockMapPipeline
                     smem_scores[kb] = dot * scale;
                 }
             }
-            __syncthreads();
-
-            move_tile_window(k_window, {kN0, 0});
         }
+        __syncthreads(); // guard Phase 3's reads of smem_bmap / smem_scores
 
         // ==================================================================
         // Phase 3: Softmax + Selection
@@ -399,15 +411,24 @@ struct SpargeBlockMapPipeline
         }
         const float sum_exp = block_reduce_sum(lsum, smem_small);
 
-        // normalise
-        const float inv_sum = (sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
-        for(index_t i = tid; i < N_k; i += kBlockSize)
-            smem_scores[i] *= inv_sum;
-        __syncthreads();
+        // Round 13i: argmax is invariant under positive scaling (inv_sum > 0). When
+        // topk > 0 we never read normalised values for cdfthreshd, so skip the
+        // normalise pass entirely (saves N_k LDS writes + 1 __syncthreads). The
+        // cdfthreshd path (topk <= 0) still requires normalised scores so the
+        // accumulator `cumulative_prob` matches probabilities.
+        const bool topk_active = (topk > 0.f);
+        const float inv_sum =
+            (!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
+        if(!topk_active)
+        {
+            for(index_t i = tid; i < N_k; i += kBlockSize)
+                smem_scores[i] *= inv_sum;
+            __syncthreads();
+        }
 
         // Selection: iterative argmax
         index_t num_to_select =
-            (topk > 0.f)
+            topk_active
                 ? max(static_cast<index_t>(1), static_cast<index_t>(topk * static_cast<float>(N_k)))
                 : N_k;
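A tiny demonstration of the Round 13i claim, that argmax is unchanged by a uniform positive scale (numbers invented):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<float> scores = {0.5f, 3.25f, 1.0f, 2.75f}; // exp-scores
        const float inv_sum = 1.0f / (0.5f + 3.25f + 1.0f + 2.75f);
        std::vector<float> normalised(scores);
        for(float& s : normalised) s *= inv_sum; // the pass the topk path skips

        auto a = std::max_element(scores.begin(), scores.end()) - scores.begin();
        auto b = std::max_element(normalised.begin(), normalised.end()) - normalised.begin();
        std::printf("argmax raw=%td normalised=%td\n", a, b); // both 1
    }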
@@ -448,6 +469,11 @@ struct SpargeBlockMapPipeline
         }
         __syncthreads();
 
+        // Round 13g: collapse 2 syncs/round into 1. tid==0 computes the global
+        // winner AND writes the sentinel (smem_bmap=1, smem_scores=-1) in the same
+        // critical section, gated by bv>0. All threads then read smem_small[0] for
+        // the early break / cumulative_prob accumulation. Saves 1 __syncthreads per
+        // round (~32 syncs @ N_k=64 topk=0.5).
         if(tid == 0)
         {
             float bv = smem_small[0];
@@ -462,24 +488,22 @@ struct SpargeBlockMapPipeline
                     bi = wi;
                 }
             }
+            // Write sentinel into bmap/scores in the same critical section.
+            // Guarded by bv > 0 so we never poison a valid score with -1.
+            if(bv > 0.f)
+            {
+                smem_bmap[bi]   = 1;
+                smem_scores[bi] = -1.f;
+            }
             smem_small[0] = bv;
             smem_small[1] = bit_cast<float>(static_cast<int32_t>(bi));
         }
         __syncthreads();
 
-        float g_val = smem_small[0];
         index_t g_idx = bit_cast<int32_t>(smem_small[1]);
+        float g_val   = smem_small[0];
 
         if(g_val <= 0.f)
             break;
 
-        if(tid == 0)
-        {
-            smem_bmap[g_idx] = 1;
-            smem_scores[g_idx] = -1.f;
-        }
-        __syncthreads();
-
         if(topk > 0.f)
         {
             if(round + 1 >= num_to_select)
(new file, 110 lines)
@@ -0,0 +1,110 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp"

namespace ck_tile {

// Kernel A of the K-stat precompute split: one work-group per (b, hk, kb)
// computes pooled_k_mean and sim_k for that K-block once. Kernel B then reads
// from the workspace instead of recomputing per Q-block.
template <typename Problem_>
struct SpargeKStatsPipeline
{
    using Problem   = remove_cvref_t<Problem_>;
    using Base      = SpargeBlockMapPipeline<Problem>;
    using QDataType = typename Base::QDataType;
    using KDataType = typename Base::KDataType;

    static constexpr index_t kBlockSize = Base::kBlockSize;
    static constexpr index_t kM0        = Base::kM0;
    static constexpr index_t kN0        = Base::kN0;
    static constexpr index_t D          = Base::D;
    static constexpr index_t NumWarps   = Base::NumWarps;
    static constexpr index_t WarpSize   = Base::WarpSize;

    static constexpr index_t KPerThread       = Base::KPerThread;
    static constexpr index_t KThreads         = Base::KThreads;
    static constexpr index_t SeqThreadPerWarp = Base::SeqThreadPerWarp;
    static constexpr index_t NPerThread       = Base::NPerThread;

    static constexpr index_t kBlockPerCu = 1;

    static constexpr index_t kColPaddedStride = Base::kColPaddedStride;
    static constexpr index_t kPerWarpFloats   = Base::kPerWarpFloats;
    static constexpr index_t kReduceBytes     = NumWarps * kPerWarpFloats * sizeof(float);

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return kReduceBytes; }

    CK_TILE_HOST_DEVICE static constexpr auto MakeKBlockDistribution()
    {
        return Base::MakeKBlockDistribution();
    }

    // operator(): one work-group, one K-block. Writes D fp32 + 1 uint8 to workspace.
    template <typename KWindowType>
    CK_TILE_DEVICE void operator()(const KWindowType& k_window,
                                   index_t seqlen_k,
                                   index_t kb,
                                   float simthreshd1,
                                   float* __restrict__ pooled_k_out, // D floats
                                   uint8_t* __restrict__ sim_k_out,  // 1 byte
                                   void* smem_ptr) const
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);
        auto* smem_reduce = reinterpret_cast<float*>(smem_ptr);

        const index_t bs_k   = min(static_cast<index_t>(kN0), seqlen_k - kb * kN0);
        const float inv_bs_k = (bs_k > 0) ? (1.0f / static_cast<float>(bs_k)) : 0.f;

        auto k_tile = load_tile(k_window);

        float k_data[NPerThread * KPerThread];
        Base::template tile_to_float<NPerThread * KPerThread>(k_tile, k_data);

        const index_t warp_id = tid / WarpSize;
        const index_t lane_id = tid % WarpSize;
        const index_t k_idx   = lane_id % KThreads;
        const index_t m_idx   = lane_id / KThreads;

        // pooled_k_mean: column sum then cross-warp reduce.
        // R21A: drop trailing sync (next cross_warp_reduce has its own leading sync).
        float pooled_k_mean[KPerThread];
        Base::template column_reduce_thread_and_warp<NPerThread>(k_data, pooled_k_mean);
        Base::template column_reduce_cross_warp<false>(pooled_k_mean, smem_reduce);
        for(index_t k = 0; k < KPerThread; ++k)
            pooled_k_mean[k] *= inv_bs_k;

        // R21A: write pooled_k_mean to global early so its register liveness ends here,
        // freeing VGPR before k_sum_hat becomes live.
        if(warp_id == 0 && m_idx == 0)
        {
            for(index_t k = 0; k < KPerThread; ++k)
                pooled_k_out[k_idx * KPerThread + k] = pooled_k_mean[k];
        }

        // K row L2 norms + normalised column sum (k_sum_hat)
        float k_psq[NPerThread];
        Base::template row_reduce_sq_norm<NPerThread>(k_data, k_psq, bs_k);

        float k_sum_hat[KPerThread];
        Base::template column_reduce_normalised<NPerThread>(k_data, k_psq, k_sum_hat, bs_k);
        // R21A: drop trailing sync (no further smem read; only intra-warp shuffle + global write).
        Base::template column_reduce_cross_warp<false>(k_sum_hat, smem_reduce);

        // sim_k = (||k_sum_hat||^2 / bs_k^2) > simthreshd1
        float ksh_sq = 0.f;
        for(index_t k = 0; k < KPerThread; ++k)
            ksh_sq += k_sum_hat[k] * k_sum_hat[k];
        ksh_sq = Base::reduce_across_k(ksh_sq);
        const float denom_k = static_cast<float>(bs_k) * static_cast<float>(bs_k);
        const bool sim_k    = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1);

        if(tid == 0)
            *sim_k_out = sim_k ? static_cast<uint8_t>(1) : static_cast<uint8_t>(0);
    }
};

} // namespace ck_tile