refactor(sparse_attn): caller-owned workspace + dtype-aware sizing

Replace process-lifetime lazy hipMalloc K-stats workspace with a caller-owned
buffer; expose sparge_blockmap_get_workspace_size() / compute_workspace_layout()
host helpers. Split the combined sparge_blockmap_fwd into stage launchers
(sparge_kstats_fwd_oneshot + sparge_blockmap_only_fwd_oneshot) so the chained
launch is timed end-to-end.

Make pooled_k storage dtype follow KDataType (fp16/bf16) instead of fp32 to halve
workspace footprint and match dense-FMHA precision. Tighten per-head superparam
pointers to required (non-null) and assert N_k <= 256 in jenga MakeKargs to
document the 256-bool LDS staging cap. Drop the obsolete VSA extra-LDS staging.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
Gino Lu
2026-05-17 02:34:23 -04:00
parent 668e107282
commit 7103eacc99
9 changed files with 402 additions and 399 deletions

View File

@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
@@ -133,34 +134,41 @@ struct FmhaFwdJengaKernel
};
// std::variant<> can't take in a list initializer, overload for backward compatibility
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* block_relation_onehot_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
// 256-bool LDS staging caps N_k <= 256 (for kN0=64 -> seqlen_k <= 16384).
// Not constexpr because the assert needs runtime evaluation.
CK_TILE_HOST static Kargs MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* block_relation_onehot_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
{
// 256-bool LDS staging caps N_k <= 256 per Q-tile.
// For kN0=64 this means seqlen_k <= 16384.
assert(ck_tile::integer_divide_ceil(seqlen_k, FmhaPipeline::kN0) <= 256 &&
"256-bool LDS staging caps N_k <= 256 (for kN0=64: seqlen_k <= 16384)");
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
@@ -248,7 +256,11 @@ struct FmhaFwdJengaKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads.
// Extra LDS stages 256 bools (4B-aligned for LDS loads) — caps N_k <= 256 per Q-tile,
// i.e. seqlen_k <= 256 * kN0 (for kN0=64 -> seqlen_k <= 16384). MakeKargs asserts this.
// The extra 1024B is jenga-specific: pipeline (block_fmha_pipeline_qr_ks_vs_async_jenga
// .hpp:261) stages block_relation_onehot here. Do NOT copy this `+ 256*sizeof(int)` to
// other sparse kernels (e.g. VSA) without first wiring a real reader.
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)];
// if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d",

View File

@@ -251,8 +251,7 @@ struct FmhaFwdVSAKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads.
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)];
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

View File

@@ -22,7 +22,7 @@ struct SpargeBlockMapKernel
static constexpr index_t kN0 = Pipeline::kN0;
static constexpr index_t D = Pipeline::D;
static constexpr index_t kAlignment = 16 / sizeof(QDataType);
static constexpr index_t kAlignment = 16 / sizeof(QDataType); // 16B = dwordx4 load width
struct Kargs
{
@@ -52,19 +52,18 @@ struct SpargeBlockMapKernel
void* lut_ptr;
void* valid_block_num_ptr;
// R20 K-stat workspace from Kernel A
const void* pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] fp32
const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8
// K-block stats workspace produced by SpargeKStatsKernel
const void*
pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype)
const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8
index_t N_k;
// R21A Phase 4: optional per-head topk (size = nhead_q floats).
// nullptr => use scalar `topk` for all heads.
// Per-head topk (size = nhead_q floats). Required (non-null).
const float* topk_per_head;
// R21B: optional per-head cdfthreshd (size = nhead_q floats).
// nullptr => use scalar `cdfthreshd` for all heads.
// Only consulted on topk<=0 path; bench currently always uses topk path.
// Per-head cdfthreshd (size = nhead_q floats). Required (non-null);
// only consulted on topk<=0 path.
const float* cdfthreshd_per_head;
};
@@ -90,8 +89,8 @@ struct SpargeBlockMapKernel
void* valid_block_num_ptr,
const void* pooled_k_ws_ptr,
const void* sim_k_ws_ptr,
const float* topk_per_head = nullptr,
const float* cdfthreshd_per_head = nullptr)
const float* topk_per_head,
const float* cdfthreshd_per_head)
{
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
return Kargs{q_ptr,
@@ -195,20 +194,15 @@ struct SpargeBlockMapKernel
// Shared memory
__shared__ char smem[Pipeline::GetSmemSize()];
// R20 K-stat workspace: pre-offset for this (b, hk).
const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk;
// K-stat workspace: pre-offset for this (b, hk).
const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk;
const index_t khead_off = (b * nhead_k + hk) * N_k;
const auto* pooled_k_ws =
reinterpret_cast<const float*>(kargs.pooled_k_ws_ptr) + khead_off * D;
const auto* sim_k_ws =
reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
reinterpret_cast<const KDataType*>(kargs.pooled_k_ws_ptr) + khead_off * D;
const auto* sim_k_ws = reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
// R21A Phase 4: per-head topk if provided, else scalar broadcast.
const float topk_eff =
(kargs.topk_per_head != nullptr) ? kargs.topk_per_head[hq] : kargs.topk;
// R21B: per-head cdfthreshd if provided, else scalar broadcast.
const float cdfthreshd_eff =
(kargs.cdfthreshd_per_head != nullptr) ? kargs.cdfthreshd_per_head[hq] : kargs.cdfthreshd;
const float topk_eff = kargs.topk_per_head[hq];
const float cdfthreshd_eff = kargs.cdfthreshd_per_head[hq];
Pipeline{}(q_window,
k_window,

View File

@@ -40,15 +40,13 @@ struct SpargeKStatsKernel
float simthreshd1;
void* pooled_k_ptr; // [batch, nhead_k, N_k, D] fp32
void* pooled_k_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype)
void* sim_k_ptr; // [batch, nhead_k, N_k] uint8
index_t N_k;
// R21A Phase 4 + R21B fix: optional per-head simthreshd1.
// Buffer is sized [nhead_q] floats to match SpargeAttn upstream contract
// (utils.py:324, Headnum=q.size(1)). Kernel only indexes the first
// nhead_k entries via [hk]. nullptr => use scalar `simthreshd1`.
// Per-head simthreshd1 pointer (size = nhead_q floats; kernel indexes [hk] only).
// Required (non-null); matches SpargeAttn upstream contract.
const float* simthreshd1_per_head;
};
@@ -62,7 +60,7 @@ struct SpargeKStatsKernel
float simthreshd1,
void* pooled_k_ptr,
void* sim_k_ptr,
const float* simthreshd1_per_head = nullptr)
const float* simthreshd1_per_head)
{
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
return Kargs{k_ptr,
@@ -111,17 +109,15 @@ struct SpargeKStatsKernel
{0, 0},
Pipeline::MakeKBlockDistribution());
const index_t N_k = kargs.N_k;
const index_t khead_off = (b * kargs.nhead_k + hk) * N_k;
auto* pooled_k_out = reinterpret_cast<float*>(kargs.pooled_k_ptr) + (khead_off + kb) * D;
auto* sim_k_out = reinterpret_cast<uint8_t*>(kargs.sim_k_ptr) + (khead_off + kb);
const index_t N_k = kargs.N_k;
const index_t khead_off = (b * kargs.nhead_k + hk) * N_k;
auto* pooled_k_out =
reinterpret_cast<KDataType*>(kargs.pooled_k_ptr) + (khead_off + kb) * D;
auto* sim_k_out = reinterpret_cast<uint8_t*>(kargs.sim_k_ptr) + (khead_off + kb);
__shared__ char smem[Pipeline::GetSmemSize()];
// R21A Phase 4: per-head simthreshd1 if provided, else scalar broadcast.
const float simthreshd1_eff = (kargs.simthreshd1_per_head != nullptr)
? kargs.simthreshd1_per_head[hk]
: kargs.simthreshd1;
const float simthreshd1_eff = kargs.simthreshd1_per_head[hk];
Pipeline{}(k_window,
kargs.seqlen_k,

View File

@@ -262,7 +262,7 @@ struct SpargeBlockMapPipeline
uint8_t* block_map_ptr,
int32_t* lut_ptr,
int32_t* valid_block_num_ptr,
const float* __restrict__ pooled_k_ws_ptr,
const KDataType* __restrict__ pooled_k_ws_ptr,
const uint8_t* __restrict__ sim_k_ws_ptr,
void* smem_ptr) const
{
@@ -356,10 +356,10 @@ struct SpargeBlockMapPipeline
for(index_t kb = 0; kb < N_k; ++kb)
{
const float* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
float pooled_k_mean[KPerThread];
for(index_t k = 0; k < KPerThread; ++k)
pooled_k_mean[k] = p_kb[k];
pooled_k_mean[k] = type_convert<float>(p_kb[k]);
float dot = 0.f;
for(index_t k = 0; k < KPerThread; ++k)
@@ -417,8 +417,7 @@ struct SpargeBlockMapPipeline
// cdfthreshd path (topk <= 0) still requires normalised scores so the
// accumulator `cumulative_prob` matches probabilities.
const bool topk_active = (topk > 0.f);
const float inv_sum =
(!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
const float inv_sum = (!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
if(!topk_active)
{
for(index_t i = tid; i < N_k; i += kBlockSize)

View File

@@ -49,8 +49,8 @@ struct SpargeKStatsPipeline
index_t seqlen_k,
index_t kb,
float simthreshd1,
float* __restrict__ pooled_k_out, // D floats
uint8_t* __restrict__ sim_k_out, // 1 byte
KDataType* __restrict__ pooled_k_out, // D KDataType (fp16/bf16)
uint8_t* __restrict__ sim_k_out, // 1 byte
void* smem_ptr) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
@@ -70,19 +70,19 @@ struct SpargeKStatsPipeline
const index_t m_idx = lane_id / KThreads;
// pooled_k_mean: column sum then cross-warp reduce.
// R21A: drop trailing sync (next cross_warp_reduce has its own leading sync).
// Drop trailing sync (next cross_warp_reduce has its own leading sync).
float pooled_k_mean[KPerThread];
Base::template column_reduce_thread_and_warp<NPerThread>(k_data, pooled_k_mean);
Base::template column_reduce_cross_warp<false>(pooled_k_mean, smem_reduce);
for(index_t k = 0; k < KPerThread; ++k)
pooled_k_mean[k] *= inv_bs_k;
// R21A: write pooled_k_mean to global early so its register liveness ends here,
// Write pooled_k_mean to global early so its register liveness ends here,
// freeing VGPR before k_sum_hat becomes live.
if(warp_id == 0 && m_idx == 0)
{
for(index_t k = 0; k < KPerThread; ++k)
pooled_k_out[k_idx * KPerThread + k] = pooled_k_mean[k];
pooled_k_out[k_idx * KPerThread + k] = type_convert<KDataType>(pooled_k_mean[k]);
}
// K row L2 norms + normalised column sum (k_sum_hat)
@@ -91,7 +91,7 @@ struct SpargeKStatsPipeline
float k_sum_hat[KPerThread];
Base::template column_reduce_normalised<NPerThread>(k_data, k_psq, k_sum_hat, bs_k);
// R21A: drop trailing sync (no further smem read; only intra-warp shuffle + global write).
// Drop trailing sync (no further smem read; only intra-warp shuffle + global write).
Base::template column_reduce_cross_warp<false>(k_sum_hat, smem_reduce);
// sim_k = (||k_sum_hat||^2 / bs_k^2) > simthreshd1