sparse_attn: wire -mask and -attention_sink (block-map prune + attn mask)

This commit is contained in:
Gino Lu
2026-05-19 23:22:00 -04:00
parent b3ea819ff7
commit fb75da2467
5 changed files with 199 additions and 102 deletions

View File

@@ -65,6 +65,12 @@ struct SpargeBlockMapKernel
// Per-head cdfthreshd (size = nhead_q floats). Required (non-null);
// only consulted on topk<=0 path.
const float* cdfthreshd_per_head;
// R32 Items 2+3. mask_type stored as index_t (not mask_enum) to keep this
// include-tree header independent of example/01_fmha/mask.hpp. Magic
// constant 1 == mask_enum::mask_top_left (01_fmha/mask.hpp:13-19).
index_t mask_type;
bool attention_sink;
};
CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr,
@@ -90,7 +96,9 @@ struct SpargeBlockMapKernel
const void* pooled_k_ws_ptr,
const void* sim_k_ws_ptr,
const float* topk_per_head,
const float* cdfthreshd_per_head)
const float* cdfthreshd_per_head,
index_t mask_type,
bool attention_sink)
{
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
return Kargs{q_ptr,
@@ -117,7 +125,9 @@ struct SpargeBlockMapKernel
sim_k_ws_ptr,
N_k,
topk_per_head,
cdfthreshd_per_head};
cdfthreshd_per_head,
mask_type,
attention_sink};
}
CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q)
@@ -220,7 +230,9 @@ struct SpargeBlockMapKernel
valid_out,
pooled_k_ws,
sim_k_ws,
static_cast<void*>(smem));
static_cast<void*>(smem),
kargs.mask_type,
kargs.attention_sink);
}
};

View File

@@ -263,10 +263,16 @@ struct SpargeBlockMapPipeline
int32_t* valid_block_num_ptr,
const KDataType* __restrict__ pooled_k_ws_ptr,
const uint8_t* __restrict__ sim_k_ws_ptr,
void* smem_ptr) const
void* smem_ptr,
index_t mask_type,
bool attention_sink) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
// mask_enum::mask_top_left == 1 (01_fmha/mask.hpp:16). Multiplicative
// form handles BLKQ=64,BLKK=128 (kM0<kN0) and the kM0>=kN0 case.
const bool is_causal_tl = (mask_type == 1);
// K-loop no longer reduces; only Q-stats uses smem_float0.
// smem_float1 slab is allocated for layout compat but unused.
auto* smem_float0 = reinterpret_cast<float*>(smem_ptr);
@@ -320,13 +326,23 @@ struct SpargeBlockMapPipeline
// Not similar → force all K blocks ON, early exit
if(!sim_q)
{
// R32 Item 2: only fill causal-valid prefix when active.
const index_t causal_kb_end =
is_causal_tl ? min(N_k, integer_divide_ceil((qb + 1) * kM0, kN0)) : N_k;
for(index_t i = tid; i < N_k; i += kBlockSize)
block_map_ptr[i] = 1;
block_map_ptr[i] = (i < causal_kb_end) ? 1 : 0;
// R32 Item 3: sink force. Under top-left causal, kb=0 always
// causal-valid for qb>=0 -> no-op; meaningful for mask=no + sink=1.
if(attention_sink && tid == 0)
block_map_ptr[0] = 1;
__syncthreads(); // sink visible to LUT-build below
if(lut_ptr != nullptr && tid == 0)
{
int32_t valid = 0, prev = 0;
for(index_t kb = 0; kb < N_k; ++kb)
for(index_t kb = 0; kb < causal_kb_end; ++kb)
{
lut_ptr[valid] = static_cast<int32_t>(kb) - prev;
prev = static_cast<int32_t>(kb);
@@ -354,6 +370,10 @@ struct SpargeBlockMapPipeline
for(index_t kb = 0; kb < N_k; ++kb)
{
// R32 Item 2: top-left causal at block grain.
// (qb,kb) past-diagonal iff kb*kN0 >= (qb+1)*kM0.
const bool causal_killed = is_causal_tl && (kb * kN0 >= (qb + 1) * kM0);
const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
float pooled_k_mean[KPerThread];
for(index_t k = 0; k < KPerThread; ++k)
@@ -372,17 +392,17 @@ struct SpargeBlockMapPipeline
// ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1)
// AND have score = -inf so the selection step (topk / cdf) does NOT
// pick them again (would double-count toward topk budget).
// Both writes MUST stay together. Any selection rewrite
// (e.g. iterative argmax -> bitonic sort) must keep the -inf write.
if(!sim_k)
// R32: causal_killed gates the force-on so past-diagonal blocks are
// NOT forced ON; bmap stays 0, scores -inf so selection excludes them.
if(causal_killed)
smem_scores[kb] = -numeric<float>::infinity(); // bmap stays 0
else if(!sim_k)
{
smem_bmap[kb] = 1;
smem_scores[kb] = -numeric<float>::infinity();
}
else
{
smem_scores[kb] = dot * scale;
}
}
}
__syncthreads(); // guard selection's reads of smem_bmap / smem_scores
@@ -517,6 +537,15 @@ struct SpargeBlockMapPipeline
// ==================================================================
// Write outputs to global memory
// ==================================================================
// R32 Item 3: force smem_bmap[0]=1 BEFORE LUT collation reads it.
// Reuses existing LUT-build loop (R31 §4: don't manually insert into
// delta stream). Causal post-multiply unnecessary: D.2 sets killed
// scores to -inf; selection gate L490 `bv > 0` excludes them, so
// smem_bmap[bi]=1 never fires for killed blocks.
if(attention_sink && tid == 0)
smem_bmap[0] = 1;
__syncthreads();
for(index_t i = tid; i < N_k; i += kBlockSize)
block_map_ptr[i] = smem_bmap[i];