mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
sparse_attn: wire -mask and -attention_sink (block-map prune + attn mask)
This commit is contained in:
@@ -65,6 +65,12 @@ struct SpargeBlockMapKernel
|
||||
// Per-head cdfthreshd (size = nhead_q floats). Required (non-null);
|
||||
// only consulted on topk<=0 path.
|
||||
const float* cdfthreshd_per_head;
|
||||
|
||||
// R32 Items 2+3. mask_type stored as index_t (not mask_enum) to keep this
|
||||
// include-tree header independent of example/01_fmha/mask.hpp. Magic
|
||||
// constant 1 == mask_enum::mask_top_left (01_fmha/mask.hpp:13-19).
|
||||
index_t mask_type;
|
||||
bool attention_sink;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr,
|
||||
@@ -90,7 +96,9 @@ struct SpargeBlockMapKernel
|
||||
const void* pooled_k_ws_ptr,
|
||||
const void* sim_k_ws_ptr,
|
||||
const float* topk_per_head,
|
||||
const float* cdfthreshd_per_head)
|
||||
const float* cdfthreshd_per_head,
|
||||
index_t mask_type,
|
||||
bool attention_sink)
|
||||
{
|
||||
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
|
||||
return Kargs{q_ptr,
|
||||
@@ -117,7 +125,9 @@ struct SpargeBlockMapKernel
|
||||
sim_k_ws_ptr,
|
||||
N_k,
|
||||
topk_per_head,
|
||||
cdfthreshd_per_head};
|
||||
cdfthreshd_per_head,
|
||||
mask_type,
|
||||
attention_sink};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q)
|
||||
@@ -220,7 +230,9 @@ struct SpargeBlockMapKernel
|
||||
valid_out,
|
||||
pooled_k_ws,
|
||||
sim_k_ws,
|
||||
static_cast<void*>(smem));
|
||||
static_cast<void*>(smem),
|
||||
kargs.mask_type,
|
||||
kargs.attention_sink);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -263,10 +263,16 @@ struct SpargeBlockMapPipeline
|
||||
int32_t* valid_block_num_ptr,
|
||||
const KDataType* __restrict__ pooled_k_ws_ptr,
|
||||
const uint8_t* __restrict__ sim_k_ws_ptr,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
index_t mask_type,
|
||||
bool attention_sink) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
|
||||
// mask_enum::mask_top_left == 1 (01_fmha/mask.hpp:16). Multiplicative
|
||||
// form handles BLKQ=64,BLKK=128 (kM0<kN0) and the kM0>=kN0 case.
|
||||
const bool is_causal_tl = (mask_type == 1);
|
||||
|
||||
// K-loop no longer reduces; only Q-stats uses smem_float0.
|
||||
// smem_float1 slab is allocated for layout compat but unused.
|
||||
auto* smem_float0 = reinterpret_cast<float*>(smem_ptr);
|
||||
@@ -320,13 +326,23 @@ struct SpargeBlockMapPipeline
|
||||
// Not similar → force all K blocks ON, early exit
|
||||
if(!sim_q)
|
||||
{
|
||||
// R32 Item 2: only fill causal-valid prefix when active.
|
||||
const index_t causal_kb_end =
|
||||
is_causal_tl ? min(N_k, integer_divide_ceil((qb + 1) * kM0, kN0)) : N_k;
|
||||
|
||||
for(index_t i = tid; i < N_k; i += kBlockSize)
|
||||
block_map_ptr[i] = 1;
|
||||
block_map_ptr[i] = (i < causal_kb_end) ? 1 : 0;
|
||||
|
||||
// R32 Item 3: sink force. Under top-left causal, kb=0 always
|
||||
// causal-valid for qb>=0 -> no-op; meaningful for mask=no + sink=1.
|
||||
if(attention_sink && tid == 0)
|
||||
block_map_ptr[0] = 1;
|
||||
__syncthreads(); // sink visible to LUT-build below
|
||||
|
||||
if(lut_ptr != nullptr && tid == 0)
|
||||
{
|
||||
int32_t valid = 0, prev = 0;
|
||||
for(index_t kb = 0; kb < N_k; ++kb)
|
||||
for(index_t kb = 0; kb < causal_kb_end; ++kb)
|
||||
{
|
||||
lut_ptr[valid] = static_cast<int32_t>(kb) - prev;
|
||||
prev = static_cast<int32_t>(kb);
|
||||
@@ -354,6 +370,10 @@ struct SpargeBlockMapPipeline
|
||||
|
||||
for(index_t kb = 0; kb < N_k; ++kb)
|
||||
{
|
||||
// R32 Item 2: top-left causal at block grain.
|
||||
// (qb,kb) past-diagonal iff kb*kN0 >= (qb+1)*kM0.
|
||||
const bool causal_killed = is_causal_tl && (kb * kN0 >= (qb + 1) * kM0);
|
||||
|
||||
const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
|
||||
float pooled_k_mean[KPerThread];
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
@@ -372,17 +392,17 @@ struct SpargeBlockMapPipeline
|
||||
// ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1)
|
||||
// AND have score = -inf so the selection step (topk / cdf) does NOT
|
||||
// pick them again (would double-count toward topk budget).
|
||||
// Both writes MUST stay together. Any selection rewrite
|
||||
// (e.g. iterative argmax -> bitonic sort) must keep the -inf write.
|
||||
if(!sim_k)
|
||||
// R32: causal_killed gates the force-on so past-diagonal blocks are
|
||||
// NOT forced ON; bmap stays 0, scores -inf so selection excludes them.
|
||||
if(causal_killed)
|
||||
smem_scores[kb] = -numeric<float>::infinity(); // bmap stays 0
|
||||
else if(!sim_k)
|
||||
{
|
||||
smem_bmap[kb] = 1;
|
||||
smem_scores[kb] = -numeric<float>::infinity();
|
||||
}
|
||||
else
|
||||
{
|
||||
smem_scores[kb] = dot * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads(); // guard selection's reads of smem_bmap / smem_scores
|
||||
@@ -517,6 +537,15 @@ struct SpargeBlockMapPipeline
|
||||
// ==================================================================
|
||||
// Write outputs to global memory
|
||||
// ==================================================================
|
||||
// R32 Item 3: force smem_bmap[0]=1 BEFORE LUT collation reads it.
|
||||
// Reuses existing LUT-build loop (R31 §4: don't manually insert into
|
||||
// delta stream). Causal post-multiply unnecessary: D.2 sets killed
|
||||
// scores to -inf; selection gate L490 `bv > 0` excludes them, so
|
||||
// smem_bmap[bi]=1 never fires for killed blocks.
|
||||
if(attention_sink && tid == 0)
|
||||
smem_bmap[0] = 1;
|
||||
__syncthreads();
|
||||
|
||||
for(index_t i = tid; i < N_k; i += kBlockSize)
|
||||
block_map_ptr[i] = smem_bmap[i];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user