sparse_attn: wire -mask and -attention_sink (block-map prune + attn mask)

This commit is contained in:
Gino Lu
2026-05-19 23:22:00 -04:00
parent b3ea819ff7
commit fb75da2467
5 changed files with 199 additions and 102 deletions

View File

@@ -9,11 +9,11 @@ Implemented:
- top-k / `cdfthreshd` block selection, BlockMap LUT
- sparse FMHA (both `vsa` and `jenga` backends)
- per-head `topk` / `simthreshd1` / `cdfthreshd`
- **is_causal mask in pooled score** (top-left only at block-map grain) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338))
- **attention_sink** — block-map column 0 force-on ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355))
Not yet ported (upstream pinned to commit [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a)):
- **K smoothing** — pre-pool `k -= km`; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5) ([spas_sage_attn/core.py:L53](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L53))
- **is_causal mask in pooled score** — required for causal-LM prefill (Llama, Qwen) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338))
- **attention_sink** — column 0 forced ON; upstream is hard-wired to `True` at inference ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355))
- **Sort-based top-k selection** — replaces our O(N_k^2) iterative argmax; matters at long seqlen (s ≥ 16k) ([spas_sage_attn/utils.py:L345](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L345))
- **Q/K int8 quant fusion in pool kernel** — enables a downstream int8 GEMM0 in the attn kernel ([spas_sage_attn/utils.py:L371](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L371))
@@ -40,7 +40,11 @@ ninja tile_example_sparge
Select a PV-skip variant with `-pv_mode={none|warp|block}` (default `warp`); finite `-pv_threshold=20` lets the per-Q-tile skip predicate fire.
Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k.
Mask + attention sink:
- `-mask` accepts the `01_fmha` grammar (`0` / `t` / `b` / `t:l,r` / `xt:N` / `g:y,x`, default `0`). The block-map selection prunes past-diagonal blocks only under `mask_top_left` (`t`); `b` / SWA / generic are forwarded to the attention kernel and emit a stderr WARN that the block-map selection is unchanged.
- `-attention_sink {0,1}` forces block-map column `kb=0` ON for every Q-block (default `0`). Under `-mask t` the sink never conflicts with the causal mask, since `kb=0` is always causal-valid.
Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k. When `-mask != 0` or `-attention_sink == 1`, the `[block_map cross-check]` and `[VSA LUT self-consistency]` cells are SKIPPED (the CPU reference does not model causal mask or sink); the `[attention output]` cell still runs but the dense reference applies no mask, so it will report FAIL on the kernel-correct output. Treat `-v=1` correctness as **block-map level only** in those configurations.
## References

View File

@@ -56,6 +56,11 @@ struct sparge_blockmap_args
const float* simthreshd1_per_head_ptr = nullptr;
const float* cdfthreshd_per_head_ptr = nullptr;
const float* topk_per_head_ptr = nullptr;
// R32 Items 2+3. Pipeline only honours mask_enum::mask_top_left; CLI warns
// on other types. Defaults (no_mask, sink off) preserve back-compat for
// callers that do not set these fields yet.
mask_enum mask_type = mask_enum::no_mask;
bool attention_sink = false;
};
struct sparge_blockmap_workspace_layout
@@ -105,7 +110,9 @@ auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args,
pooled_k_ws_ptr,
sim_k_ws_ptr,
args.topk_per_head_ptr,
args.cdfthreshd_per_head_ptr);
args.cdfthreshd_per_head_ptr,
static_cast<ck_tile::index_t>(args.mask_type),
args.attention_sink);
dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q);
return ck_tile::make_tuple(kargs, grids);

View File

@@ -19,6 +19,7 @@
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "01_fmha/mask.hpp" // R32: mask_info::decode, mask_enum
#include "fmha_fwd_trek.hpp"
#include "sparge_blockmap_trek.hpp"
#include "sparge_tool.hpp"
@@ -126,7 +127,23 @@ auto create_args(int argc, char* argv[])
"none = no skip (kNone binary; matches VSA baseline). "
"warp = per-wavefront butterfly vote (R25 A1; default). "
"block = per-block AND vote via 1 LDS slot + block_sync_lds (R30). "
"Overrides -pv_skip_compile when set explicitly.");
"Overrides -pv_skip_compile when set explicitly.")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
"'xt:window_size', xformer style masking from top-left, "
"window_size negative is causal, positive is swa\n"
"'xb:window_size', xformer style masking from bottom-r, "
"window_size negative is causal, positive is swa\n"
"'g:y,x', generic attention mask coordinate with y/x size "
"(only debug purpose for now)")
.insert("attention_sink",
"0",
"SpargeAttn: force block-map column 0 ON (kb=0 always selected). "
"0=off, 1=on. Block-map level only; independent of -mask sink prefix.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -161,6 +178,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
int pv_skip_compile = arg_parser.get_int("pv_skip_compile");
std::string pv_per_head_s = arg_parser.get_str("pv_threshold_per_head");
std::string pv_mode_str = arg_parser.get_str("pv_mode");
std::string mask_str = arg_parser.get_str("mask");
bool attention_sink = arg_parser.get_bool("attention_sink");
// R30: --pv_mode maps to the int dispatched at host.
// none -> 0 (kNone), warp -> 1 (kPerWave), block -> 2 (kPerBlock).
@@ -192,6 +211,15 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
if(hdim_v < 0)
hdim_v = hdim_q;
mask_info mask = mask_info::decode(mask_str, seqlen_q, seqlen_k);
if(mask.type != mask_enum::no_mask && mask.type != mask_enum::mask_top_left)
std::fprintf(stderr,
"[test_sparge] WARN: -mask='%s' (type=%d) - block-map only "
"filters mask_top_left; selection will not prune past-diagonal "
"blocks. attention kernel still applies the mask.\n",
mask_str.c_str(),
static_cast<int>(mask.type));
// If cdfthreshd >= 0, use CDF mode; otherwise use topk mode
if(cdfthreshd >= 0.0f)
topk = -1.0f;
@@ -281,6 +309,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer();
bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;
bmap_args.mask_type = mask.type; // R32 Item 2
bmap_args.attention_sink = attention_sink; // R32 Item 3
// K-stats workspace: caller-owned, sized via host helper, allocated once outside any timing.
const size_t ws_bytes = sparge_blockmap_get_workspace_size(bmap_traits, bmap_args);
@@ -350,7 +380,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_traits.hdim_v = hdim_v;
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
attn_traits.is_v_rowmajor = true;
attn_traits.mask_type = mask_enum::no_mask;
attn_traits.mask_type = mask.type;
attn_traits.bm0 = BLKQ;
fmha_jenga_fwd_args attn_args;
@@ -380,9 +410,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_args.batch_stride_k = k_strides[0];
attn_args.batch_stride_v = v_strides[0];
attn_args.batch_stride_o = o_strides[0];
attn_args.window_size_left = -1;
attn_args.window_size_right = -1;
attn_args.mask_type = 0;
attn_args.window_size_left = mask.left;
attn_args.window_size_right = mask.right;
attn_args.mask_type = static_cast<ck_tile::index_t>(mask.type);
avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
}
@@ -395,7 +425,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_traits.hdim_v = hdim_v;
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
attn_traits.is_v_rowmajor = true;
attn_traits.mask_type = mask_enum::no_mask;
attn_traits.mask_type = mask.type;
attn_traits.bm0 = BLKQ;
fmha_sparge_fwd_args attn_args;
@@ -435,9 +465,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
attn_args.batch_stride_k = k_strides[0];
attn_args.batch_stride_v = v_strides[0];
attn_args.batch_stride_o = o_strides[0];
attn_args.window_size_left = -1;
attn_args.window_size_right = -1;
attn_args.mask_type = 0;
attn_args.window_size_left = mask.left;
attn_args.window_size_right = mask.right;
attn_args.mask_type = static_cast<ck_tile::index_t>(mask.type);
avg_ms =
sparge_sparge_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
@@ -500,96 +530,111 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
auto k_ref = to_bhsd(k_host, i_perm);
auto v_ref = to_bhsd(v_host, i_perm);
sparge::SpargeParams sp;
sp.BLKQ = BLKQ;
sp.BLKK = BLKK;
sp.simthreshd1 = simthreshd1;
sp.cdfthreshd = cdfthreshd;
sp.topk = topk;
sp.i_perm = i_perm;
// R32: CPU reference lacks causal mask + attention_sink; skip block_map
// cross-check + VSA LUT self-consistency when either is in effect. The
// attention-output check below still runs (consumes GPU bmap).
const bool skip_cpu_bm_check = (mask.type != mask_enum::no_mask) || attention_sink;
auto block_map_cpu = sparge::build_block_map_meansim<T>(q_host, k_host, sp);
size_t bm_total = block_map_host.mData.size();
size_t bm_mismatch = 0;
size_t shown = 0;
constexpr size_t MAXSHOW = 10;
std::cout << "\n [block_map cross-check] total=" << bm_total;
for(size_t i = 0; i < bm_total; ++i)
bool bm_pass = true;
bool lut_pass = true;
if(!skip_cpu_bm_check)
{
uint8_t g = block_map_host.mData[i];
uint8_t c = block_map_cpu.mData[i];
if(g != c)
{
if(shown < MAXSHOW)
{
size_t k_idx = i % num_k_blocks;
size_t q_idx = (i / num_k_blocks) % num_q_blocks;
size_t h_idx = (i / (num_k_blocks * num_q_blocks)) % nhead;
size_t b_idx = i / (num_k_blocks * num_q_blocks * nhead);
std::cout << "\n miss[" << shown << "] (b=" << b_idx << ",h=" << h_idx
<< ",qb=" << q_idx << ",kb=" << k_idx << ") gpu=" << int(g)
<< " cpu=" << int(c);
++shown;
}
++bm_mismatch;
}
}
bool bm_pass = (bm_mismatch == 0);
float bm_ratio = bm_total ? 100.0f * float(bm_mismatch) / float(bm_total) : 0.0f;
std::cout << "\n [block_map cross-check] mismatch=" << bm_mismatch << "/" << bm_total
<< " (" << std::setprecision(4) << bm_ratio << "%) "
<< (bm_pass ? "PASS" : "FAIL");
auto cpu_lut = sparge::block_map_to_vsa_lut_delta<uint8_t>(block_map_cpu);
bool lut_pass = true;
size_t lut_fails = 0;
for(ck_tile::index_t b = 0; b < batch && lut_fails < MAXSHOW; ++b)
{
for(ck_tile::index_t h = 0; h < nhead && lut_fails < MAXSHOW; ++h)
sparge::SpargeParams sp;
sp.BLKQ = BLKQ;
sp.BLKK = BLKK;
sp.simthreshd1 = simthreshd1;
sp.cdfthreshd = cdfthreshd;
sp.topk = topk;
sp.i_perm = i_perm;
auto block_map_cpu = sparge::build_block_map_meansim<T>(q_host, k_host, sp);
size_t bm_total = block_map_host.mData.size();
size_t bm_mismatch = 0;
size_t shown = 0;
constexpr size_t MAXSHOW = 10;
std::cout << "\n [block_map cross-check] total=" << bm_total;
for(size_t i = 0; i < bm_total; ++i)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks && lut_fails < MAXSHOW; ++qb)
uint8_t g = block_map_host.mData[i];
uint8_t c = block_map_cpu.mData[i];
if(g != c)
{
int32_t valid = cpu_lut.valid_block_num(b, h, qb);
int32_t active_count = 0;
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
if(block_map_cpu(b, h, qb, kb))
++active_count;
int32_t recon_kb = 0;
bool delta_ok = true;
for(int32_t i = 0; i < valid; ++i)
if(shown < MAXSHOW)
{
int32_t d = cpu_lut.lut(b, h, qb, i);
if(d < 0)
{
delta_ok = false;
break;
}
recon_kb += d;
if(recon_kb >= num_k_blocks)
{
delta_ok = false;
break;
}
if(!block_map_cpu(b, h, qb, recon_kb))
{
delta_ok = false;
break;
}
size_t k_idx = i % num_k_blocks;
size_t q_idx = (i / num_k_blocks) % num_q_blocks;
size_t h_idx = (i / (num_k_blocks * num_q_blocks)) % nhead;
size_t b_idx = i / (num_k_blocks * num_q_blocks * nhead);
std::cout << "\n miss[" << shown << "] (b=" << b_idx << ",h=" << h_idx
<< ",qb=" << q_idx << ",kb=" << k_idx << ") gpu=" << int(g)
<< " cpu=" << int(c);
++shown;
}
if(valid != active_count || !delta_ok)
++bm_mismatch;
}
}
bm_pass = (bm_mismatch == 0);
float bm_ratio = bm_total ? 100.0f * float(bm_mismatch) / float(bm_total) : 0.0f;
std::cout << "\n [block_map cross-check] mismatch=" << bm_mismatch << "/" << bm_total
<< " (" << std::setprecision(4) << bm_ratio << "%) "
<< (bm_pass ? "PASS" : "FAIL");
auto cpu_lut = sparge::block_map_to_vsa_lut_delta<uint8_t>(block_map_cpu);
size_t lut_fails = 0;
for(ck_tile::index_t b = 0; b < batch && lut_fails < MAXSHOW; ++b)
{
for(ck_tile::index_t h = 0; h < nhead && lut_fails < MAXSHOW; ++h)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks && lut_fails < MAXSHOW; ++qb)
{
lut_pass = false;
if(lut_fails < MAXSHOW)
std::cout << "\n lut_fail (b=" << b << ",h=" << h << ",qb=" << qb
<< ") valid=" << valid << " active=" << active_count
<< " delta_ok=" << delta_ok;
++lut_fails;
int32_t valid = cpu_lut.valid_block_num(b, h, qb);
int32_t active_count = 0;
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
if(block_map_cpu(b, h, qb, kb))
++active_count;
int32_t recon_kb = 0;
bool delta_ok = true;
for(int32_t i = 0; i < valid; ++i)
{
int32_t d = cpu_lut.lut(b, h, qb, i);
if(d < 0)
{
delta_ok = false;
break;
}
recon_kb += d;
if(recon_kb >= num_k_blocks)
{
delta_ok = false;
break;
}
if(!block_map_cpu(b, h, qb, recon_kb))
{
delta_ok = false;
break;
}
}
if(valid != active_count || !delta_ok)
{
lut_pass = false;
if(lut_fails < MAXSHOW)
std::cout << "\n lut_fail (b=" << b << ",h=" << h << ",qb=" << qb
<< ") valid=" << valid << " active=" << active_count
<< " delta_ok=" << delta_ok;
++lut_fails;
}
}
}
}
std::cout << "\n [VSA LUT self-consistency] " << (lut_pass ? "PASS" : "FAIL");
} // end if(!skip_cpu_bm_check)
else
{
std::cout << "\n [block_map cross-check] SKIPPED (mask/sink active; CPU ref lacks)";
std::cout << "\n [VSA LUT self-consistency] SKIPPED";
}
std::cout << "\n [VSA LUT self-consistency] " << (lut_pass ? "PASS" : "FAIL");
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
ck_tile::reference_blocked_attention<T, uint8_t>(

View File

@@ -65,6 +65,12 @@ struct SpargeBlockMapKernel
// Per-head cdfthreshd (size = nhead_q floats). Required (non-null);
// only consulted on topk<=0 path.
const float* cdfthreshd_per_head;
// R32 Items 2+3. mask_type stored as index_t (not mask_enum) to keep this
// include-tree header independent of example/01_fmha/mask.hpp. Magic
// constant 1 == mask_enum::mask_top_left (01_fmha/mask.hpp:13-19).
index_t mask_type;
bool attention_sink;
};
CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr,
@@ -90,7 +96,9 @@ struct SpargeBlockMapKernel
const void* pooled_k_ws_ptr,
const void* sim_k_ws_ptr,
const float* topk_per_head,
const float* cdfthreshd_per_head)
const float* cdfthreshd_per_head,
index_t mask_type,
bool attention_sink)
{
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
return Kargs{q_ptr,
@@ -117,7 +125,9 @@ struct SpargeBlockMapKernel
sim_k_ws_ptr,
N_k,
topk_per_head,
cdfthreshd_per_head};
cdfthreshd_per_head,
mask_type,
attention_sink};
}
CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q)
@@ -220,7 +230,9 @@ struct SpargeBlockMapKernel
valid_out,
pooled_k_ws,
sim_k_ws,
static_cast<void*>(smem));
static_cast<void*>(smem),
kargs.mask_type,
kargs.attention_sink);
}
};

View File

@@ -263,10 +263,16 @@ struct SpargeBlockMapPipeline
int32_t* valid_block_num_ptr,
const KDataType* __restrict__ pooled_k_ws_ptr,
const uint8_t* __restrict__ sim_k_ws_ptr,
void* smem_ptr) const
void* smem_ptr,
index_t mask_type,
bool attention_sink) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
// mask_enum::mask_top_left == 1 (01_fmha/mask.hpp:16). The per-block causal
// test further down is written multiplicatively (kb*kN0 >= (qb+1)*kM0), so
// it handles BLKQ=64,BLKK=128 (kM0<kN0) as well as the kM0>=kN0 case.
const bool is_causal_tl = (mask_type == 1);
// K-loop no longer reduces; only Q-stats uses smem_float0.
// smem_float1 slab is allocated for layout compat but unused.
auto* smem_float0 = reinterpret_cast<float*>(smem_ptr);
@@ -320,13 +326,23 @@ struct SpargeBlockMapPipeline
// Not similar → force all K blocks ON, early exit
if(!sim_q)
{
// R32 Item 2: only fill causal-valid prefix when active.
const index_t causal_kb_end =
is_causal_tl ? min(N_k, integer_divide_ceil((qb + 1) * kM0, kN0)) : N_k;
for(index_t i = tid; i < N_k; i += kBlockSize)
block_map_ptr[i] = 1;
block_map_ptr[i] = (i < causal_kb_end) ? 1 : 0;
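// R32 Item 3: sink force. Redundant on this path: causal_kb_end >= 1
// whenever N_k >= 1 (and equals N_k with no mask), so the fill above has
// already set block_map_ptr[0] = 1 in both the masked and unmasked case.
// NOTE(review): presumably kept for parity with the selection path.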
if(attention_sink && tid == 0)
block_map_ptr[0] = 1;
__syncthreads(); // sink visible to LUT-build below
if(lut_ptr != nullptr && tid == 0)
{
int32_t valid = 0, prev = 0;
for(index_t kb = 0; kb < N_k; ++kb)
for(index_t kb = 0; kb < causal_kb_end; ++kb)
{
lut_ptr[valid] = static_cast<int32_t>(kb) - prev;
prev = static_cast<int32_t>(kb);
@@ -354,6 +370,10 @@ struct SpargeBlockMapPipeline
for(index_t kb = 0; kb < N_k; ++kb)
{
// R32 Item 2: top-left causal at block grain.
// (qb,kb) past-diagonal iff kb*kN0 >= (qb+1)*kM0.
const bool causal_killed = is_causal_tl && (kb * kN0 >= (qb + 1) * kM0);
const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
float pooled_k_mean[KPerThread];
for(index_t k = 0; k < KPerThread; ++k)
@@ -372,17 +392,17 @@ struct SpargeBlockMapPipeline
// ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1)
// AND have score = -inf so the selection step (topk / cdf) does NOT
// pick them again (would double-count toward topk budget).
// Both writes MUST stay together. Any selection rewrite
// (e.g. iterative argmax -> bitonic sort) must keep the -inf write.
if(!sim_k)
// R32: causal_killed gates the force-on so past-diagonal blocks are
// NOT forced ON; bmap stays 0, scores -inf so selection excludes them.
if(causal_killed)
smem_scores[kb] = -numeric<float>::infinity(); // bmap stays 0
else if(!sim_k)
{
smem_bmap[kb] = 1;
smem_scores[kb] = -numeric<float>::infinity();
}
else
{
smem_scores[kb] = dot * scale;
}
}
}
__syncthreads(); // guard selection's reads of smem_bmap / smem_scores
@@ -517,6 +537,15 @@ struct SpargeBlockMapPipeline
// ==================================================================
// Write outputs to global memory
// ==================================================================
// R32 Item 3: force smem_bmap[0]=1 BEFORE LUT collation reads it.
// Reuses existing LUT-build loop (R31 §4: don't manually insert into
// delta stream). Causal post-multiply unnecessary: D.2 sets killed
// scores to -inf; selection gate L490 `bv > 0` excludes them, so
// smem_bmap[bi]=1 never fires for killed blocks.
if(attention_sink && tid == 0)
smem_bmap[0] = 1;
__syncthreads();
for(index_t i = tid; i < N_k; i += kBlockSize)
block_map_ptr[i] = smem_bmap[i];