Sparge Attention (Composable Kernel)
A Composable Kernel port of SpargeAttn for AMD GPUs. Both the block-map pipeline (mean-pool → cosine sim → pooled QK → top-k LUT) and the sparse FMHA stage run on the GPU. Two attention backends are exposed: -pipeline=vsa (default, faster) and -pipeline=jenga (async K/V load variant).
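The scoring half of the block-map pipeline can be sketched on the CPU in plain Python for intuition. This is a minimal, hypothetical reference, not the kernel's code. The helper names (mean_pool, block_self_similarity, pooled_scores), the use of mean pairwise cosine similarity as the per-block similarity, and the 1/sqrt(d) scale on the pooled logits are all assumptions based on our reading of upstream.

```python
import math

# Hypothetical CPU sketch of the block-map scoring front end:
# mean-pool -> intra-block cosine similarity -> pooled QK -> row softmax.
# Matrices are plain lists of rows, x[token][channel].


def mean_pool(x, block):
    """Average each run of `block` consecutive tokens into one vector."""
    pooled = []
    for start in range(0, len(x), block):
        rows = x[start:start + block]
        pooled.append([sum(col) / len(rows) for col in zip(*rows)])
    return pooled


def block_self_similarity(x, block):
    """Mean pairwise cosine similarity of the tokens inside each block (assumed metric)."""
    sims = []
    for start in range(0, len(x), block):
        unit = []
        for r in x[start:start + block]:
            norm = math.sqrt(sum(v * v for v in r)) or 1.0  # guard all-zero tokens
            unit.append([v / norm for v in r])
        total = sum(sum(a * b for a, b in zip(u, w)) for u in unit for w in unit)
        sims.append(total / len(unit) ** 2)
    return sims


def pooled_scores(q, k, block_q, block_k):
    """Row-wise softmax of pooled(Q) . pooled(K)^T / sqrt(d): one row per Q block."""
    pq, pk = mean_pool(q, block_q), mean_pool(k, block_k)
    scale = 1.0 / math.sqrt(len(q[0]))
    probs = []
    for qi in pq:
        logits = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in pk]
        peak = max(logits)  # subtract the max for numerical stability
        e = [math.exp(l - peak) for l in logits]
        z = sum(e)
        probs.append([v / z for v in e])
    return probs


# Two Q blocks and two K blocks of two tokens each, d = 2.
q = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
k = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
P = pooled_scores(q, k, 2, 2)        # P[0][0] > P[0][1]: Q block 0 leans toward K block 0
sims = block_self_similarity(k, 2)   # both K blocks hold identical tokens, so 1.0 each
```

In this sketch a block with a low self-similarity score (below simthreshd1) is one whose mean is a poor summary of its tokens; how the kernel then treats such blocks is not modelled here.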
Status vs Upstream
Implemented:
- per-block mean-pool, cosine similarity, pooled QK
- top-k / cdfthreshd block selection, BlockMap LUT
- sparse FMHA (both vsa and jenga backends)
- per-head topk / simthreshd1 / cdfthreshd
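Turning pooled probabilities into a BlockMap LUT can be illustrated the same way. The sketch assumes topk is a fraction of the N_k K blocks rounded up, and that cdfthreshd keeps the smallest set of highest-probability blocks whose cumulative mass reaches the threshold. build_lut and its rounding rule are hypothetical, and it uses a sort for brevity rather than mirroring the kernel's selection loop. Per-head parameters amount to calling it once per head with that head's values.

```python
import math

# Hypothetical sketch: pooled probabilities -> per-Q-block list of K blocks to visit.


def ranked(row):
    """K-block indices by descending probability; the stable sort keeps lower indices first on ties."""
    return sorted(range(len(row)), key=lambda j: -row[j])


def select_topk(row, topk):
    """Keep ceil(topk * N_k) blocks (assumed rounding), at least one."""
    n_keep = max(1, math.ceil(topk * len(row)))
    return sorted(ranked(row)[:n_keep])


def select_cdf(row, cdfthreshd):
    """Keep the fewest top-ranked blocks whose cumulative probability reaches cdfthreshd."""
    kept, mass = [], 0.0
    for j in ranked(row):
        kept.append(j)
        mass += row[j]
        if mass >= cdfthreshd:
            break
    return sorted(kept)


def build_lut(probs, topk=None, cdfthreshd=None):
    """One sorted K-block index list per Q-block row; exactly one criterion must be given."""
    if (topk is None) == (cdfthreshd is None):
        raise ValueError("pass exactly one of topk or cdfthreshd")
    if topk is not None:
        return [select_topk(row, topk) for row in probs]
    return [select_cdf(row, cdfthreshd) for row in probs]


probs = [[0.1, 0.6, 0.2, 0.1],
         [0.4, 0.3, 0.2, 0.1]]
lut_topk = build_lut(probs, topk=0.5)        # [[1, 2], [0, 1]]: fixed budget of two blocks
lut_cdf = build_lut(probs, cdfthreshd=0.55)  # [[1], [0, 1]]: row 0 needs one block, row 1 two
```

The contrast between the two results is the point of cdfthreshd: the per-row block count adapts to how concentrated the pooled attention is, whereas topk spends the same budget on every row.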
Not yet ported (upstream pinned to commit ae5b629):
- K smoothing: pre-pool k -= km; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5) (spas_sage_attn/core.py:L53)
- is_causal mask in pooled score: required for causal-LM prefill (Llama, Qwen) (spas_sage_attn/utils.py:L338)
- attention_sink: column 0 forced ON; upstream hard-wires it to True at inference (spas_sage_attn/autotune.py:L355)
- Sort-based top-k selection: replaces our O(N_k^2) iterative argmax; matters at long sequence lengths (s ≥ 16k) (spas_sage_attn/utils.py:L345)
- Q/K int8 quant fusion in the pool kernel: enables a downstream int8 GEMM0 in the attention kernel (spas_sage_attn/utils.py:L371)
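Two of the unported items can be sketched in plain Python. smooth_k subtracts the per-channel mean over the sequence, which is our reading of upstream's k -= km; all helper names here are hypothetical. Because every key is shifted by the same km, each query's logits shift by the constant q·km, so the dense softmax is unchanged, while statistics that are not shift-invariant, such as the intra-block cosine similarity, do change. The two top-k helpers contrast an O(N_k^2) iterative argmax with a single O(N_k log N_k) sort; with a stable descending sort both return the same indices in the same order, ties included.

```python
import math
import random

# Hypothetical sketches of two not-yet-ported upstream features.


def smooth_k(k):
    """K smoothing: subtract the per-channel mean over the sequence (km) from every key."""
    n = len(k)
    km = [sum(col) / n for col in zip(*k)]
    return [[v - m for v, m in zip(row, km)] for row in k]


def softmax_logits(q_row, k):
    """Softmax over keys of q . k_j (scale omitted; it does not affect the invariance)."""
    logits = [sum(a * b for a, b in zip(q_row, kr)) for kr in k]
    peak = max(logits)
    e = [math.exp(l - peak) for l in logits]
    z = sum(e)
    return [v / z for v in e]


def topk_iterative_argmax(row, n_keep):
    """O(N_k^2): rescan for the largest unselected entry n_keep times."""
    taken = [False] * len(row)
    picked = []
    for _ in range(n_keep):
        best = -1
        for j, v in enumerate(row):
            if not taken[j] and (best < 0 or v > row[best]):  # strict '>' keeps the first max
                best = j
        taken[best] = True
        picked.append(best)
    return picked


def topk_sorted(row, n_keep):
    """Sort-based, O(N_k log N_k): one stable descending sort, then truncate."""
    return sorted(range(len(row)), key=lambda j: -row[j])[:n_keep]


rng = random.Random(0)
q_row = [rng.uniform(-1, 1) for _ in range(8)]
keys = [[rng.uniform(-1, 1) + 3.0 for _ in range(8)] for _ in range(16)]  # shared +3 bias
before = softmax_logits(q_row, keys)
after = softmax_logits(q_row, smooth_k(keys))  # same distribution, computed from centred keys

row = [0.2, 0.5, 0.5, 0.1, 0.7, 0.2]  # includes ties
picked = topk_sorted(row, 4)          # [4, 1, 2, 0], identical to topk_iterative_argmax(row, 4)
```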
PV-skip modes
The pv_threshold per-Q-tile skip in the attention kernel is implemented in three variants, selectable at runtime via -pv_mode={none|warp|block}:
- none: skip disabled; baseline matching the no-PV-skip codegen instance.
- warp (per-wavefront): each wavefront votes locally via a __shfl_xor butterfly AND; the flag is SGPR-resident. CK-tile-specific variant, not in upstream.
- block (per-block): block-wide consensus vote via LDS broadcast; aligned with upstream sm80 (qk_int_sv_f16_cuda_sm80.cuh:L334).
V loads stay unconditional in all modes: the guard wraps only the PV MMA, matching upstream and Algorithm 1 of the paper.
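The difference between the warp and block votes can be modelled with boolean lane flags. This is a hypothetical sketch. The per-lane predicate (skip when the tile's local row max plus pv_threshold is still below the running max, so every rescaled probability in the tile is below exp(-pv_threshold), about 2e-9 at 20) is an assumption about pv_threshold's semantics, and lanes here stand in for whatever rows each lane owns in the real tile. butterfly_and emulates the __shfl_xor reduction pattern; the block mode's LDS broadcast is reduced to a single all().

```python
WAVE = 64  # MI300X (CDNA3) executes 64-lane wavefronts


def lane_can_skip(m_local, m_running, pv_threshold):
    """Assumed predicate: every exp(s - m_running) in this tile is below exp(-pv_threshold)."""
    return m_local + pv_threshold < m_running


def butterfly_and(flags):
    """Emulate an XOR-butterfly AND (like __shfl_xor): after log2(n) steps every lane holds all(flags)."""
    vals = list(flags)
    offset = len(vals) // 2  # len(vals) must be a power of two
    while offset >= 1:
        vals = [vals[i] and vals[i ^ offset] for i in range(len(vals))]
        offset //= 2
    return vals


def skip_decisions(lane_flags, mode):
    """One PV-skip decision per wavefront for mode in {'none', 'warp', 'block'}."""
    waves = [lane_flags[i:i + WAVE] for i in range(0, len(lane_flags), WAVE)]
    if mode == "none":
        return [False] * len(waves)
    per_wave = [butterfly_and(w)[0] for w in waves]  # all lanes agree; read lane 0
    if mode == "warp":
        return per_wave
    if mode == "block":
        consensus = all(per_wave)  # stands in for the LDS broadcast of a block-wide AND
        return [consensus] * len(waves)
    raise ValueError(mode)


# Two wavefronts: every lane of wave 0 can skip, one lane of wave 1 cannot.
flags = [True] * WAVE + [True] * (WAVE - 1) + [False]
warp_votes = skip_decisions(flags, "warp")    # wave 0 skips, wave 1 computes PV
block_votes = skip_decisions(flags, "block")  # neither wave skips
```

A single dissenting lane vetoes the skip for its own wavefront under warp, but for the whole block under block, which is why warp can fire in cases where block does not.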
Benchmark setup: MI300X, b=2 h=16 s=8192 d=128 fp16, 5 seeds × 9 sparsity points. At this shape all three modes dispatch to the kM0=64 padK=0 tile bucket.
On this canonical recipe shape the ordering is none > warp > block (fastest first) at every measured sparsity, with no crossover. The per-block guard adds 33 to 35 VGPRs (6 to 9 spills) on this tile configuration, which lowers occupancy; warp adds 0 to 4 VGPRs. The default is -pv_mode=warp; switch to none for the no-skip baseline, or to block to exercise the upstream-aligned variant. A shape sweep is needed before recommending block as the default: the kM0=128 path shows Δ ≈ 0 VGPR for per-block and is a candidate.
Usage
ninja tile_example_sparge
./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001
Select a PV-skip variant with -pv_mode={none|warp|block} (default warp), and pass a finite threshold such as -pv_threshold=20 so that the per-Q-tile skip predicate can fire.
Add -v=1 for CPU validation, and use a small shape (-b=1 -h=2 -s=512): the full-shape CPU reference scales as O(s²) and takes 30+ minutes at s=8k and hours at s=16k.
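The O(s²) cost is easy to quantify. The sketch counts the multiply-accumulates of a dense reference (Q K^T and P V, each s·s·d per head) and includes a naive single-head reference. Both reference_macs and reference_attention are hypothetical helpers; the example's actual -v=1 reference may differ in structure and constant factors.

```python
import math


def reference_macs(b, h, s, d):
    """Multiply-accumulates of dense attention: Q K^T and P V each cost s*s*d per head."""
    return b * h * 2 * s * s * d


def reference_attention(q, k, v):
    """Naive single-head softmax(Q K^T / sqrt(d)) V; both loop nests are s x s."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:  # s query rows ...
        logits = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]  # ... times s keys
        peak = max(logits)
        p = [math.exp(l - peak) for l in logits]
        z = sum(p)
        out.append([sum(pj * vj[c] for pj, vj in zip(p, v)) / z for c in range(len(v[0]))])
    return out


full = reference_macs(2, 32, 16384, 128)  # the Usage example shape
small = reference_macs(1, 2, 512, 128)    # the suggested validation shape
print(full // small)                                                           # 32768
print(reference_macs(2, 32, 16384, 128) // reference_macs(2, 32, 8192, 128))  # 4
```

By this count the Usage example shape is 32768× the suggested validation shape, and doubling s from 8k to 16k at fixed b and h quadruples the work, consistent with 30+ minutes growing to hours.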
