From fb75da24670edb56768fbf6c5ec0f8ed8c3aa03c Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Tue, 19 May 2026 23:22:00 -0400 Subject: [PATCH] sparse_attn: wire -mask and -attention_sink (block-map prune + attn mask) --- example/ck_tile/50_sparse_attn/README.md | 10 +- .../50_sparse_attn/sparge_blockmap_trek.hpp | 9 +- .../ck_tile/50_sparse_attn/test_sparge.cpp | 219 +++++++++++------- .../kernel/sparge_blockmap_kernel.hpp | 18 +- .../pipeline/sparge_blockmap_pipeline.hpp | 45 +++- 5 files changed, 199 insertions(+), 102 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/README.md b/example/ck_tile/50_sparse_attn/README.md index 0a7b513748..593f4a85ef 100644 --- a/example/ck_tile/50_sparse_attn/README.md +++ b/example/ck_tile/50_sparse_attn/README.md @@ -9,11 +9,11 @@ Implemented: - top-k / `cdfthreshd` block selection, BlockMap LUT - sparse FMHA (both `vsa` and `jenga` backends) - per-head `topk` / `simthreshd1` / `cdfthreshd` +- **is_causal mask in pooled score** (top-left only at block-map grain) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338)) +- **attention_sink** — block-map column 0 force-on ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355)) Not yet ported (upstream pinned to commit [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a)): - **K smoothing** — pre-pool `k -= km`; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5) ([spas_sage_attn/core.py:L53](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L53)) -- **is_causal mask in pooled score** — required for causal-LM prefill (Llama, Qwen) 
([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338)) -- **attention_sink** — column 0 forced ON; upstream is hard-wired to `True` at inference ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355)) - **Sort-based top-k selection** — replaces our O(N_k^2) iterative argmax; matters at long seqlen (s ≥ 16k) ([spas_sage_attn/utils.py:L345](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L345)) - **Q/K int8 quant fusion in pool kernel** — enables a downstream int8 GEMM0 in the attn kernel ([spas_sage_attn/utils.py:L371](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L371)) @@ -40,7 +40,11 @@ ninja tile_example_sparge Select a PV-skip variant with `-pv_mode={none|warp|block}` (default `warp`); finite `-pv_threshold=20` lets the per-Q-tile skip predicate fire. -Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k. +Mask + attention sink: +- `-mask` accepts the `01_fmha` grammar (`0` / `t` / `b` / `t:l,r` / `xt:N` / `g:y,x`, default `0`). The block-map selection prunes past-diagonal blocks only under `mask_top_left` (`t`); `b` / SWA / generic are forwarded to the attention kernel and emit a stderr WARN that the block-map selection is unchanged. +- `-attention_sink {0,1}` forces block-map column `kb=0` ON for every Q-block (default `0`). Under `-mask t` this is degenerate since `kb=0` is always causal-valid. + +Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k. 
When `-mask != 0` or `-attention_sink == 1`, the `[block_map cross-check]` and `[VSA LUT self-consistency]` cells are SKIPPED (the CPU reference does not model causal mask or sink); the `[attention output]` cell still runs but the dense reference applies no mask, so it will report FAIL on the kernel-correct output. Do not treat a `-v=1` result as a correctness verdict in those configurations: the block-map cells are not checked and the `[attention output]` cell compares against an unmasked reference. ## References diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp index c0178f70de..7591b94e54 100644 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp @@ -56,6 +56,11 @@ struct sparge_blockmap_args const float* simthreshd1_per_head_ptr = nullptr; const float* cdfthreshd_per_head_ptr = nullptr; const float* topk_per_head_ptr = nullptr; + + // R32 Items 2+3. Pipeline only honours mask_enum::mask_top_left; CLI warns + // on other types. Defaults preserve back-compat for callers not yet setting. 
+ mask_enum mask_type = mask_enum::no_mask; + bool attention_sink = false; }; struct sparge_blockmap_workspace_layout @@ -105,7 +110,9 @@ auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args, pooled_k_ws_ptr, sim_k_ws_ptr, args.topk_per_head_ptr, - args.cdfthreshd_per_head_ptr); + args.cdfthreshd_per_head_ptr, + static_cast(args.mask_type), + args.attention_sink); dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q); return ck_tile::make_tuple(kargs, grids); diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp index 8ba1b97e84..c368cb0304 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -19,6 +19,7 @@ #include "ck_tile/host/reference/reference_blocked_attention.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "01_fmha/mask.hpp" // R32: mask_info::decode, mask_enum #include "fmha_fwd_trek.hpp" #include "sparge_blockmap_trek.hpp" #include "sparge_tool.hpp" @@ -126,7 +127,23 @@ auto create_args(int argc, char* argv[]) "none = no skip (kNone binary; matches VSA baseline). " "warp = per-wavefront butterfly vote (R25 A1; default). " "block = per-block AND vote via 1 LDS slot + block_sync_lds (R30). 
" - "Overrides -pv_skip_compile when set explicitly."); + "Overrides -pv_skip_compile when set explicitly.") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, " + "window_size negative is causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, " + "window_size negative is causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size " + "(only debug purpose for now)") + .insert("attention_sink", + "0", + "SpargeAttn: force block-map column 0 ON (kb=0 always selected). " + "0=off, 1=on. Block-map level only; independent of -mask sink prefix."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -161,6 +178,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) int pv_skip_compile = arg_parser.get_int("pv_skip_compile"); std::string pv_per_head_s = arg_parser.get_str("pv_threshold_per_head"); std::string pv_mode_str = arg_parser.get_str("pv_mode"); + std::string mask_str = arg_parser.get_str("mask"); + bool attention_sink = arg_parser.get_bool("attention_sink"); // R30: --pv_mode maps to the int dispatched at host. // none -> 0 (kNone), warp -> 1 (kPerWave), block -> 2 (kPerBlock). @@ -192,6 +211,15 @@ bool run_test(const ck_tile::ArgParser& arg_parser) if(hdim_v < 0) hdim_v = hdim_q; + mask_info mask = mask_info::decode(mask_str, seqlen_q, seqlen_k); + if(mask.type != mask_enum::no_mask && mask.type != mask_enum::mask_top_left) + std::fprintf(stderr, + "[test_sparge] WARN: -mask='%s' (type=%d) - block-map only " + "filters mask_top_left; selection will not prune past-diagonal " + "blocks. 
attention kernel still applies the mask.\n", + mask_str.c_str(), + static_cast(mask.type)); + // If cdfthreshd >= 0, use CDF mode; otherwise use topk mode if(cdfthreshd >= 0.0f) topk = -1.0f; @@ -281,6 +309,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer(); bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr; bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr; + bmap_args.mask_type = mask.type; // R32 Item 2 + bmap_args.attention_sink = attention_sink; // R32 Item 3 // K-stats workspace: caller-owned, sized via host helper, allocated once outside any timing. const size_t ws_bytes = sparge_blockmap_get_workspace_size(bmap_traits, bmap_args); @@ -350,7 +380,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_traits.hdim_v = hdim_v; attn_traits.data_type = std::is_same_v ? "fp16" : "bf16"; attn_traits.is_v_rowmajor = true; - attn_traits.mask_type = mask_enum::no_mask; + attn_traits.mask_type = mask.type; attn_traits.bm0 = BLKQ; fmha_jenga_fwd_args attn_args; @@ -380,9 +410,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_args.batch_stride_k = k_strides[0]; attn_args.batch_stride_v = v_strides[0]; attn_args.batch_stride_o = o_strides[0]; - attn_args.window_size_left = -1; - attn_args.window_size_right = -1; - attn_args.mask_type = 0; + attn_args.window_size_left = mask.left; + attn_args.window_size_right = mask.right; + attn_args.mask_type = static_cast(mask.type); avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); } @@ -395,7 +425,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_traits.hdim_v = hdim_v; attn_traits.data_type = std::is_same_v ? 
"fp16" : "bf16"; attn_traits.is_v_rowmajor = true; - attn_traits.mask_type = mask_enum::no_mask; + attn_traits.mask_type = mask.type; attn_traits.bm0 = BLKQ; fmha_sparge_fwd_args attn_args; @@ -435,9 +465,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser) attn_args.batch_stride_k = k_strides[0]; attn_args.batch_stride_v = v_strides[0]; attn_args.batch_stride_o = o_strides[0]; - attn_args.window_size_left = -1; - attn_args.window_size_right = -1; - attn_args.mask_type = 0; + attn_args.window_size_left = mask.left; + attn_args.window_size_right = mask.right; + attn_args.mask_type = static_cast(mask.type); avg_ms = sparge_sparge_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); @@ -500,96 +530,111 @@ bool run_test(const ck_tile::ArgParser& arg_parser) auto k_ref = to_bhsd(k_host, i_perm); auto v_ref = to_bhsd(v_host, i_perm); - sparge::SpargeParams sp; - sp.BLKQ = BLKQ; - sp.BLKK = BLKK; - sp.simthreshd1 = simthreshd1; - sp.cdfthreshd = cdfthreshd; - sp.topk = topk; - sp.i_perm = i_perm; + // R32: CPU reference lacks causal mask + attention_sink; skip block_map + // cross-check + VSA LUT self-consistency when either is in effect. The + // attention-output check below still runs (consumes GPU bmap). 
+ const bool skip_cpu_bm_check = (mask.type != mask_enum::no_mask) || attention_sink; - auto block_map_cpu = sparge::build_block_map_meansim(q_host, k_host, sp); - - size_t bm_total = block_map_host.mData.size(); - size_t bm_mismatch = 0; - size_t shown = 0; - constexpr size_t MAXSHOW = 10; - std::cout << "\n [block_map cross-check] total=" << bm_total; - for(size_t i = 0; i < bm_total; ++i) + bool bm_pass = true; + bool lut_pass = true; + if(!skip_cpu_bm_check) { - uint8_t g = block_map_host.mData[i]; - uint8_t c = block_map_cpu.mData[i]; - if(g != c) - { - if(shown < MAXSHOW) - { - size_t k_idx = i % num_k_blocks; - size_t q_idx = (i / num_k_blocks) % num_q_blocks; - size_t h_idx = (i / (num_k_blocks * num_q_blocks)) % nhead; - size_t b_idx = i / (num_k_blocks * num_q_blocks * nhead); - std::cout << "\n miss[" << shown << "] (b=" << b_idx << ",h=" << h_idx - << ",qb=" << q_idx << ",kb=" << k_idx << ") gpu=" << int(g) - << " cpu=" << int(c); - ++shown; - } - ++bm_mismatch; - } - } - bool bm_pass = (bm_mismatch == 0); - float bm_ratio = bm_total ? 100.0f * float(bm_mismatch) / float(bm_total) : 0.0f; - std::cout << "\n [block_map cross-check] mismatch=" << bm_mismatch << "/" << bm_total - << " (" << std::setprecision(4) << bm_ratio << "%) " - << (bm_pass ? 
"PASS" : "FAIL"); - auto cpu_lut = sparge::block_map_to_vsa_lut_delta(block_map_cpu); - bool lut_pass = true; - size_t lut_fails = 0; - for(ck_tile::index_t b = 0; b < batch && lut_fails < MAXSHOW; ++b) - { - for(ck_tile::index_t h = 0; h < nhead && lut_fails < MAXSHOW; ++h) + sparge::SpargeParams sp; + sp.BLKQ = BLKQ; + sp.BLKK = BLKK; + sp.simthreshd1 = simthreshd1; + sp.cdfthreshd = cdfthreshd; + sp.topk = topk; + sp.i_perm = i_perm; + + auto block_map_cpu = sparge::build_block_map_meansim(q_host, k_host, sp); + + size_t bm_total = block_map_host.mData.size(); + size_t bm_mismatch = 0; + size_t shown = 0; + constexpr size_t MAXSHOW = 10; + std::cout << "\n [block_map cross-check] total=" << bm_total; + for(size_t i = 0; i < bm_total; ++i) { - for(ck_tile::index_t qb = 0; qb < num_q_blocks && lut_fails < MAXSHOW; ++qb) + uint8_t g = block_map_host.mData[i]; + uint8_t c = block_map_cpu.mData[i]; + if(g != c) { - int32_t valid = cpu_lut.valid_block_num(b, h, qb); - int32_t active_count = 0; - for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) - if(block_map_cpu(b, h, qb, kb)) - ++active_count; - int32_t recon_kb = 0; - bool delta_ok = true; - for(int32_t i = 0; i < valid; ++i) + if(shown < MAXSHOW) { - int32_t d = cpu_lut.lut(b, h, qb, i); - if(d < 0) - { - delta_ok = false; - break; - } - recon_kb += d; - if(recon_kb >= num_k_blocks) - { - delta_ok = false; - break; - } - if(!block_map_cpu(b, h, qb, recon_kb)) - { - delta_ok = false; - break; - } + size_t k_idx = i % num_k_blocks; + size_t q_idx = (i / num_k_blocks) % num_q_blocks; + size_t h_idx = (i / (num_k_blocks * num_q_blocks)) % nhead; + size_t b_idx = i / (num_k_blocks * num_q_blocks * nhead); + std::cout << "\n miss[" << shown << "] (b=" << b_idx << ",h=" << h_idx + << ",qb=" << q_idx << ",kb=" << k_idx << ") gpu=" << int(g) + << " cpu=" << int(c); + ++shown; } - if(valid != active_count || !delta_ok) + ++bm_mismatch; + } + } + bm_pass = (bm_mismatch == 0); + float bm_ratio = bm_total ? 
100.0f * float(bm_mismatch) / float(bm_total) : 0.0f; + std::cout << "\n [block_map cross-check] mismatch=" << bm_mismatch << "/" << bm_total + << " (" << std::setprecision(4) << bm_ratio << "%) " + << (bm_pass ? "PASS" : "FAIL"); + + auto cpu_lut = sparge::block_map_to_vsa_lut_delta(block_map_cpu); + size_t lut_fails = 0; + for(ck_tile::index_t b = 0; b < batch && lut_fails < MAXSHOW; ++b) + { + for(ck_tile::index_t h = 0; h < nhead && lut_fails < MAXSHOW; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks && lut_fails < MAXSHOW; ++qb) { - lut_pass = false; - if(lut_fails < MAXSHOW) - std::cout << "\n lut_fail (b=" << b << ",h=" << h << ",qb=" << qb - << ") valid=" << valid << " active=" << active_count - << " delta_ok=" << delta_ok; - ++lut_fails; + int32_t valid = cpu_lut.valid_block_num(b, h, qb); + int32_t active_count = 0; + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + if(block_map_cpu(b, h, qb, kb)) + ++active_count; + int32_t recon_kb = 0; + bool delta_ok = true; + for(int32_t i = 0; i < valid; ++i) + { + int32_t d = cpu_lut.lut(b, h, qb, i); + if(d < 0) + { + delta_ok = false; + break; + } + recon_kb += d; + if(recon_kb >= num_k_blocks) + { + delta_ok = false; + break; + } + if(!block_map_cpu(b, h, qb, recon_kb)) + { + delta_ok = false; + break; + } + } + if(valid != active_count || !delta_ok) + { + lut_pass = false; + if(lut_fails < MAXSHOW) + std::cout << "\n lut_fail (b=" << b << ",h=" << h << ",qb=" << qb + << ") valid=" << valid << " active=" << active_count + << " delta_ok=" << delta_ok; + ++lut_fails; + } } } } + std::cout << "\n [VSA LUT self-consistency] " << (lut_pass ? "PASS" : "FAIL"); + } // end if(!skip_cpu_bm_check) + else + { + std::cout << "\n [block_map cross-check] SKIPPED (mask/sink active; CPU ref lacks)"; + std::cout << "\n [VSA LUT self-consistency] SKIPPED"; } - std::cout << "\n [VSA LUT self-consistency] " << (lut_pass ? 
"PASS" : "FAIL"); ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); ck_tile::reference_blocked_attention( diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp index 9006ee7696..bb1cdbfec4 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp @@ -65,6 +65,12 @@ struct SpargeBlockMapKernel // Per-head cdfthreshd (size = nhead_q floats). Required (non-null); // only consulted on topk<=0 path. const float* cdfthreshd_per_head; + + // R32 Items 2+3. mask_type stored as index_t (not mask_enum) to keep this + // include-tree header independent of example/01_fmha/mask.hpp. Magic + // constant 1 == mask_enum::mask_top_left (01_fmha/mask.hpp:13-19). + index_t mask_type; + bool attention_sink; }; CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr, @@ -90,7 +96,9 @@ struct SpargeBlockMapKernel const void* pooled_k_ws_ptr, const void* sim_k_ws_ptr, const float* topk_per_head, - const float* cdfthreshd_per_head) + const float* cdfthreshd_per_head, + index_t mask_type, + bool attention_sink) { const index_t N_k = integer_divide_ceil(seqlen_k, kN0); return Kargs{q_ptr, @@ -117,7 +125,9 @@ struct SpargeBlockMapKernel sim_k_ws_ptr, N_k, topk_per_head, - cdfthreshd_per_head}; + cdfthreshd_per_head, + mask_type, + attention_sink}; } CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q) @@ -220,7 +230,9 @@ struct SpargeBlockMapKernel valid_out, pooled_k_ws, sim_k_ws, - static_cast(smem)); + static_cast(smem), + kargs.mask_type, + kargs.attention_sink); } }; diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp index e04c6e2d4f..176063cee1 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp +++ 
b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp @@ -263,10 +263,16 @@ struct SpargeBlockMapPipeline int32_t* valid_block_num_ptr, const KDataType* __restrict__ pooled_k_ws_ptr, const uint8_t* __restrict__ sim_k_ws_ptr, - void* smem_ptr) const + void* smem_ptr, + index_t mask_type, + bool attention_sink) const { const index_t tid = static_cast(threadIdx.x); + // mask_enum::mask_top_left == 1 (01_fmha/mask.hpp:16). Multiplicative + // form handles BLKQ=64,BLKK=128 (kM0 != kN0 case). + const bool is_causal_tl = (mask_type == 1); + // K-loop no longer reduces; only Q-stats uses smem_float0. // smem_float1 slab is allocated for layout compat but unused. auto* smem_float0 = reinterpret_cast(smem_ptr); @@ -320,13 +326,23 @@ struct SpargeBlockMapPipeline // Not similar → force all K blocks ON, early exit if(!sim_q) { + // R32 Item 2: only fill causal-valid prefix when active. + const index_t causal_kb_end = + is_causal_tl ? min(N_k, integer_divide_ceil((qb + 1) * kM0, kN0)) : N_k; + for(index_t i = tid; i < N_k; i += kBlockSize) - block_map_ptr[i] = 1; + block_map_ptr[i] = (i < causal_kb_end) ? 1 : 0; + + // R32 Item 3: sink force. Under top-left causal, kb=0 always + // causal-valid for qb>=0 -> no-op; meaningful for mask=no + sink=1. + if(attention_sink && tid == 0) + block_map_ptr[0] = 1; + __syncthreads(); // sink visible to LUT-build below if(lut_ptr != nullptr && tid == 0) { int32_t valid = 0, prev = 0; - for(index_t kb = 0; kb < N_k; ++kb) + for(index_t kb = 0; kb < causal_kb_end; ++kb) { lut_ptr[valid] = static_cast(kb) - prev; prev = static_cast(kb); @@ -354,6 +370,10 @@ struct SpargeBlockMapPipeline for(index_t kb = 0; kb < N_k; ++kb) { + // R32 Item 2: top-left causal at block grain. + // (qb,kb) past-diagonal iff kb*kN0 >= (qb+1)*kM0. 
+ const bool causal_killed = is_causal_tl && (kb * kN0 >= (qb + 1) * kM0); + const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread; float pooled_k_mean[KPerThread]; for(index_t k = 0; k < KPerThread; ++k) @@ -372,17 +392,17 @@ struct SpargeBlockMapPipeline // ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1) // AND have score = -inf so the selection step (topk / cdf) does NOT // pick them again (would double-count toward topk budget). - // Both writes MUST stay together. Any selection rewrite - // (e.g. iterative argmax -> bitonic sort) must keep the -inf write. - if(!sim_k) + // R32: causal_killed gates the force-on so past-diagonal blocks are + // NOT forced ON; bmap stays 0, scores -inf so selection excludes them. + if(causal_killed) + smem_scores[kb] = -numeric::infinity(); // bmap stays 0 + else if(!sim_k) { smem_bmap[kb] = 1; smem_scores[kb] = -numeric::infinity(); } else - { smem_scores[kb] = dot * scale; - } } } __syncthreads(); // guard selection's reads of smem_bmap / smem_scores @@ -517,6 +537,15 @@ struct SpargeBlockMapPipeline // ================================================================== // Write outputs to global memory // ================================================================== + // R32 Item 3: force smem_bmap[0]=1 BEFORE LUT collation reads it. + // Reuses existing LUT-build loop (R31 §4: don't manually insert into + // delta stream). Causal post-multiply unnecessary: D.2 sets killed + // scores to -inf; selection gate L490 `bv > 0` excludes them, so + // smem_bmap[bi]=1 never fires for killed blocks. + if(attention_sink && tid == 0) + smem_bmap[0] = 1; + __syncthreads(); + for(index_t i = tid; i < N_k; i += kBlockSize) block_map_ptr[i] = smem_bmap[i];