mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
refactor(sparse_attn): caller-owned workspace + dtype-aware sizing
Replace process-lifetime lazy hipMalloc K-stats workspace with a caller-owned buffer; expose sparge_blockmap_get_workspace_size() / compute_workspace_layout() host helpers. Split the combined sparge_blockmap_fwd into stage launchers (sparge_kstats_fwd_oneshot + sparge_blockmap_only_fwd_oneshot) so the chained launch is timed end-to-end. Make pooled_k storage dtype follow KDataType (fp16/bf16) instead of fp32 to halve workspace footprint and match dense-FMHA precision. Tighten per-head superparam pointers to required (non-null) and assert N_k <= 256 in jenga MakeKargs to document the 256-bool LDS staging cap. Drop the obsolete VSA extra-LDS staging. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
@@ -133,34 +134,41 @@ struct FmhaFwdJengaKernel
|
||||
};
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* block_relation_onehot_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
// 256-bool LDS staging caps N_k <= 256 (for kN0=64 -> seqlen_k <= 16384).
|
||||
// Not constexpr because the assert needs runtime evaluation.
|
||||
CK_TILE_HOST static Kargs MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* block_relation_onehot_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
float scale_s,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
// 256-bool LDS staging caps N_k <= 256 per Q-tile.
|
||||
// For kN0=64 this means seqlen_k <= 16384.
|
||||
assert(ck_tile::integer_divide_ceil(seqlen_k, FmhaPipeline::kN0) <= 256 &&
|
||||
"256-bool LDS staging caps N_k <= 256 (for kN0=64: seqlen_k <= 16384)");
|
||||
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
@@ -248,7 +256,11 @@ struct FmhaFwdJengaKernel
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// allocate LDS
|
||||
// Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads.
|
||||
// Extra LDS stages 256 bools (4B-aligned for LDS loads) — caps N_k <= 256 per Q-tile,
|
||||
// i.e. seqlen_k <= 256 * kN0 (for kN0=64 -> seqlen_k <= 16384). MakeKargs asserts this.
|
||||
// The extra 1024B is jenga-specific: pipeline (block_fmha_pipeline_qr_ks_vs_async_jenga
|
||||
// .hpp:261) stages block_relation_onehot here. Do NOT copy this `+ 256*sizeof(int)` to
|
||||
// other sparse kernels (e.g. VSA) without first wiring a real reader.
|
||||
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)];
|
||||
|
||||
// if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d",
|
||||
|
||||
@@ -251,8 +251,7 @@ struct FmhaFwdVSAKernel
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// allocate LDS
|
||||
// Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads.
|
||||
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
@@ -22,7 +22,7 @@ struct SpargeBlockMapKernel
|
||||
static constexpr index_t kN0 = Pipeline::kN0;
|
||||
static constexpr index_t D = Pipeline::D;
|
||||
|
||||
static constexpr index_t kAlignment = 16 / sizeof(QDataType);
|
||||
static constexpr index_t kAlignment = 16 / sizeof(QDataType); // 16B = dwordx4 load width
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
@@ -52,19 +52,18 @@ struct SpargeBlockMapKernel
|
||||
void* lut_ptr;
|
||||
void* valid_block_num_ptr;
|
||||
|
||||
// R20 K-stat workspace from Kernel A
|
||||
const void* pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] fp32
|
||||
const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8
|
||||
// K-block stats workspace produced by SpargeKStatsKernel
|
||||
const void*
|
||||
pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype)
|
||||
const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8
|
||||
|
||||
index_t N_k;
|
||||
|
||||
// R21A Phase 4: optional per-head topk (size = nhead_q floats).
|
||||
// nullptr => use scalar `topk` for all heads.
|
||||
// Per-head topk (size = nhead_q floats). Required (non-null).
|
||||
const float* topk_per_head;
|
||||
|
||||
// R21B: optional per-head cdfthreshd (size = nhead_q floats).
|
||||
// nullptr => use scalar `cdfthreshd` for all heads.
|
||||
// Only consulted on topk<=0 path; bench currently always uses topk path.
|
||||
// Per-head cdfthreshd (size = nhead_q floats). Required (non-null);
|
||||
// only consulted on topk<=0 path.
|
||||
const float* cdfthreshd_per_head;
|
||||
};
|
||||
|
||||
@@ -90,8 +89,8 @@ struct SpargeBlockMapKernel
|
||||
void* valid_block_num_ptr,
|
||||
const void* pooled_k_ws_ptr,
|
||||
const void* sim_k_ws_ptr,
|
||||
const float* topk_per_head = nullptr,
|
||||
const float* cdfthreshd_per_head = nullptr)
|
||||
const float* topk_per_head,
|
||||
const float* cdfthreshd_per_head)
|
||||
{
|
||||
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
|
||||
return Kargs{q_ptr,
|
||||
@@ -195,20 +194,15 @@ struct SpargeBlockMapKernel
|
||||
// Shared memory
|
||||
__shared__ char smem[Pipeline::GetSmemSize()];
|
||||
|
||||
// R20 K-stat workspace: pre-offset for this (b, hk).
|
||||
const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk;
|
||||
// K-stat workspace: pre-offset for this (b, hk).
|
||||
const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk;
|
||||
const index_t khead_off = (b * nhead_k + hk) * N_k;
|
||||
const auto* pooled_k_ws =
|
||||
reinterpret_cast<const float*>(kargs.pooled_k_ws_ptr) + khead_off * D;
|
||||
const auto* sim_k_ws =
|
||||
reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
|
||||
reinterpret_cast<const KDataType*>(kargs.pooled_k_ws_ptr) + khead_off * D;
|
||||
const auto* sim_k_ws = reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
|
||||
|
||||
// R21A Phase 4: per-head topk if provided, else scalar broadcast.
|
||||
const float topk_eff =
|
||||
(kargs.topk_per_head != nullptr) ? kargs.topk_per_head[hq] : kargs.topk;
|
||||
// R21B: per-head cdfthreshd if provided, else scalar broadcast.
|
||||
const float cdfthreshd_eff =
|
||||
(kargs.cdfthreshd_per_head != nullptr) ? kargs.cdfthreshd_per_head[hq] : kargs.cdfthreshd;
|
||||
const float topk_eff = kargs.topk_per_head[hq];
|
||||
const float cdfthreshd_eff = kargs.cdfthreshd_per_head[hq];
|
||||
|
||||
Pipeline{}(q_window,
|
||||
k_window,
|
||||
|
||||
@@ -40,15 +40,13 @@ struct SpargeKStatsKernel
|
||||
|
||||
float simthreshd1;
|
||||
|
||||
void* pooled_k_ptr; // [batch, nhead_k, N_k, D] fp32
|
||||
void* pooled_k_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype)
|
||||
void* sim_k_ptr; // [batch, nhead_k, N_k] uint8
|
||||
|
||||
index_t N_k;
|
||||
|
||||
// R21A Phase 4 + R21B fix: optional per-head simthreshd1.
|
||||
// Buffer is sized [nhead_q] floats to match SpargeAttn upstream contract
|
||||
// (utils.py:324, Headnum=q.size(1)). Kernel only indexes the first
|
||||
// nhead_k entries via [hk]. nullptr => use scalar `simthreshd1`.
|
||||
// Per-head simthreshd1 pointer (size = nhead_q floats; kernel indexes [hk] only).
|
||||
// Required (non-null); matches SpargeAttn upstream contract.
|
||||
const float* simthreshd1_per_head;
|
||||
};
|
||||
|
||||
@@ -62,7 +60,7 @@ struct SpargeKStatsKernel
|
||||
float simthreshd1,
|
||||
void* pooled_k_ptr,
|
||||
void* sim_k_ptr,
|
||||
const float* simthreshd1_per_head = nullptr)
|
||||
const float* simthreshd1_per_head)
|
||||
{
|
||||
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
|
||||
return Kargs{k_ptr,
|
||||
@@ -111,17 +109,15 @@ struct SpargeKStatsKernel
|
||||
{0, 0},
|
||||
Pipeline::MakeKBlockDistribution());
|
||||
|
||||
const index_t N_k = kargs.N_k;
|
||||
const index_t khead_off = (b * kargs.nhead_k + hk) * N_k;
|
||||
auto* pooled_k_out = reinterpret_cast<float*>(kargs.pooled_k_ptr) + (khead_off + kb) * D;
|
||||
auto* sim_k_out = reinterpret_cast<uint8_t*>(kargs.sim_k_ptr) + (khead_off + kb);
|
||||
const index_t N_k = kargs.N_k;
|
||||
const index_t khead_off = (b * kargs.nhead_k + hk) * N_k;
|
||||
auto* pooled_k_out =
|
||||
reinterpret_cast<KDataType*>(kargs.pooled_k_ptr) + (khead_off + kb) * D;
|
||||
auto* sim_k_out = reinterpret_cast<uint8_t*>(kargs.sim_k_ptr) + (khead_off + kb);
|
||||
|
||||
__shared__ char smem[Pipeline::GetSmemSize()];
|
||||
|
||||
// R21A Phase 4: per-head simthreshd1 if provided, else scalar broadcast.
|
||||
const float simthreshd1_eff = (kargs.simthreshd1_per_head != nullptr)
|
||||
? kargs.simthreshd1_per_head[hk]
|
||||
: kargs.simthreshd1;
|
||||
const float simthreshd1_eff = kargs.simthreshd1_per_head[hk];
|
||||
|
||||
Pipeline{}(k_window,
|
||||
kargs.seqlen_k,
|
||||
|
||||
@@ -262,7 +262,7 @@ struct SpargeBlockMapPipeline
|
||||
uint8_t* block_map_ptr,
|
||||
int32_t* lut_ptr,
|
||||
int32_t* valid_block_num_ptr,
|
||||
const float* __restrict__ pooled_k_ws_ptr,
|
||||
const KDataType* __restrict__ pooled_k_ws_ptr,
|
||||
const uint8_t* __restrict__ sim_k_ws_ptr,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
@@ -356,10 +356,10 @@ struct SpargeBlockMapPipeline
|
||||
|
||||
for(index_t kb = 0; kb < N_k; ++kb)
|
||||
{
|
||||
const float* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
|
||||
const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
|
||||
float pooled_k_mean[KPerThread];
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
pooled_k_mean[k] = p_kb[k];
|
||||
pooled_k_mean[k] = type_convert<float>(p_kb[k]);
|
||||
|
||||
float dot = 0.f;
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
@@ -417,8 +417,7 @@ struct SpargeBlockMapPipeline
|
||||
// cdfthreshd path (topk <= 0) still requires normalised scores so the
|
||||
// accumulator `cumulative_prob` matches probabilities.
|
||||
const bool topk_active = (topk > 0.f);
|
||||
const float inv_sum =
|
||||
(!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
|
||||
const float inv_sum = (!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
|
||||
if(!topk_active)
|
||||
{
|
||||
for(index_t i = tid; i < N_k; i += kBlockSize)
|
||||
|
||||
@@ -49,8 +49,8 @@ struct SpargeKStatsPipeline
|
||||
index_t seqlen_k,
|
||||
index_t kb,
|
||||
float simthreshd1,
|
||||
float* __restrict__ pooled_k_out, // D floats
|
||||
uint8_t* __restrict__ sim_k_out, // 1 byte
|
||||
KDataType* __restrict__ pooled_k_out, // D KDataType (fp16/bf16)
|
||||
uint8_t* __restrict__ sim_k_out, // 1 byte
|
||||
void* smem_ptr) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
@@ -70,19 +70,19 @@ struct SpargeKStatsPipeline
|
||||
const index_t m_idx = lane_id / KThreads;
|
||||
|
||||
// pooled_k_mean: column sum then cross-warp reduce.
|
||||
// R21A: drop trailing sync (next cross_warp_reduce has its own leading sync).
|
||||
// Trailing sync is skipped here: the next cross-warp reduce begins with its own leading sync.
|
||||
float pooled_k_mean[KPerThread];
|
||||
Base::template column_reduce_thread_and_warp<NPerThread>(k_data, pooled_k_mean);
|
||||
Base::template column_reduce_cross_warp<false>(pooled_k_mean, smem_reduce);
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
pooled_k_mean[k] *= inv_bs_k;
|
||||
|
||||
// R21A: write pooled_k_mean to global early so its register liveness ends here,
|
||||
// Write pooled_k_mean to global early so its register liveness ends here,
|
||||
// freeing VGPR before k_sum_hat becomes live.
|
||||
if(warp_id == 0 && m_idx == 0)
|
||||
{
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
pooled_k_out[k_idx * KPerThread + k] = pooled_k_mean[k];
|
||||
pooled_k_out[k_idx * KPerThread + k] = type_convert<KDataType>(pooled_k_mean[k]);
|
||||
}
|
||||
|
||||
// K row L2 norms + normalised column sum (k_sum_hat)
|
||||
@@ -91,7 +91,7 @@ struct SpargeKStatsPipeline
|
||||
|
||||
float k_sum_hat[KPerThread];
|
||||
Base::template column_reduce_normalised<NPerThread>(k_data, k_psq, k_sum_hat, bs_k);
|
||||
// R21A: drop trailing sync (no further smem read; only intra-warp shuffle + global write).
|
||||
// Trailing sync is skipped here: nothing below reads smem again (only intra-warp shuffle + global write).
|
||||
Base::template column_reduce_cross_warp<false>(k_sum_hat, smem_reduce);
|
||||
|
||||
// sim_k = (||k_sum_hat||^2 / bs_k^2) > simthreshd1
|
||||
|
||||
Reference in New Issue
Block a user