composable_kernel/example/ck_tile/50_sparse_attn
Gino Lu 840b8a37d9 test(sparse_attn): CPU-ref cross-check + BLKQ cite
Wire SpargeAttn CPU reference into test_sparge: build the block_map on host via
sparge::build_block_map_meansim and cross-check against the GPU-produced map;
self-check the VSA delta-LUT (valid count + reachable kb indices); split PASS/FAIL
into separate block_map / LUT / attention-output lines for clearer diagnosis.

Set sparge_tool::SpargeParams::BLKQ default to 64 to match SpargeAttn SM90
convention (cite upstream qk_int_sv_f8_cuda_sm90.cu:143-144); tighten bf16
tolerance back to the dense FMHA baseline (4e-2 atol, 1e-2 rtol).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
2026-05-17 02:35:51 -04:00
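The three-way check described in the commit message (block_map cross-check, delta-LUT self-check, attention output) can be sketched in plain Python. This is not the test_sparge code. The function names are hypothetical. Applying the bf16 tolerances as an allclose-style rule |out - ref| <= atol + rtol·|ref| is an assumption. The delta-LUT layout (first entry absolute, later entries are gaps between kept kb indices) is also assumed, not taken from the VSA source.

```python
# Hypothetical sketch of the block_map / LUT / attention-output checks.
# Names, the pass rule, and the delta-LUT layout are assumptions.

def allclose_failures(out, ref, atol=4e-2, rtol=1e-2):
    """Indices where |out - ref| > atol + rtol * |ref| (bf16 tolerances from the commit)."""
    return [i for i, (o, r) in enumerate(zip(out, ref))
            if abs(o - r) > atol + rtol * abs(r)]


def block_map_mismatches(gpu_map, cpu_map):
    """Number of BlockMap entries where the GPU and CPU maps disagree."""
    return sum(g != c for g_row, c_row in zip(gpu_map, cpu_map)
               for g, c in zip(g_row, c_row))


def delta_lut_ok(delta_lut, valid_count, n_kblocks):
    """Decode an assumed delta-LUT and require exactly valid_count
    strictly increasing kb indices inside [0, n_kblocks)."""
    if len(delta_lut) != valid_count:
        return False
    kb, prev = 0, -1
    for step in delta_lut:
        kb += step  # prefix sum turns gaps back into absolute kb indices
        if not (prev < kb < n_kblocks):
            return False
        prev = kb
    return True


def report(map_ok, lut_ok, out_ok):
    """One PASS/FAIL line per check, so a failure points at a single stage."""
    checks = (("block_map", map_ok), ("LUT", lut_ok), ("attention output", out_ok))
    return [f"{name}: {'PASS' if ok else 'FAIL'}" for name, ok in checks]


for line in report(block_map_mismatches([[1, 0]], [[1, 0]]) == 0,
                   delta_lut_ok([0, 2, 1], 3, 4),
                   not allclose_failures([1.0, 2.0], [1.03, 2.0])):
    print(line)
```

Splitting the verdict this way separates a wrong block selection from a correct selection with a numerically wrong attention output.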

Sparge Attention (Composable Kernel)

A Composable Kernel port of SpargeAttn for AMD GPUs. Both the block-map pipeline (mean-pool → cosine sim → pooled QK → top-k LUT) and the sparse FMHA stage run on the GPU. Two attention backends are selectable: -pipeline=vsa (the default, and the faster of the two) and -pipeline=jenga (a variant with asynchronous K/V loads).
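The first three block-map stages (mean-pool, cosine similarity, pooled QK) can be sketched on the CPU in plain Python. This is a hedged illustration, not the CK kernel or the upstream code. The function names are hypothetical and blocks are taken along the sequence dimension. Pooled scores become per-row probabilities through a softmax. The per-block similarity is the mean of the full cosine-similarity matrix, diagonal included. Our understanding is that upstream SpargeAttn compares this similarity against simthreshd1, but treat the exact rule here as an assumption.

```python
import math

# Hypothetical illustration of the BlockMap front end on tiny inputs.
# Tensors are plain lists of row vectors; blk is the block size in rows.

def mean_pool(x, blk):
    """Mean-pool consecutive groups of `blk` rows into one vector per block."""
    pooled = []
    for start in range(0, len(x), blk):
        rows = x[start:start + blk]
        pooled.append([sum(col) / len(rows) for col in zip(*rows)])
    return pooled


def mean_cosine_sim(x, blk):
    """Average pairwise cosine similarity of the rows inside each block
    (full n x n matrix, diagonal included; an assumption)."""
    sims = []
    for start in range(0, len(x), blk):
        rows = x[start:start + blk]
        unit = []
        for r in rows:
            norm = math.sqrt(sum(v * v for v in r)) or 1.0  # avoid divide-by-zero
            unit.append([v / norm for v in r])
        n = len(unit)
        total = sum(sum(a * b for a, b in zip(unit[i], unit[j]))
                    for i in range(n) for j in range(n))
        sims.append(total / (n * n))
    return sims


def pooled_qk(q_pool, k_pool, scale):
    """Row-wise softmax over scaled dot products of pooled Q and pooled K."""
    probs = []
    for qv in q_pool:
        logits = [scale * sum(a * b for a, b in zip(qv, kv)) for kv in k_pool]
        m = max(logits)  # subtract the max for numerical stability
        e = [math.exp(z - m) for z in logits]
        total = sum(e)
        probs.append([v / total for v in e])
    return probs


q = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.1, 1.0]]
k = [[1.0, 0.0], [0.9, 0.0], [0.0, 1.0], [0.0, 0.8]]
q_pool, k_pool = mean_pool(q, 2), mean_pool(k, 2)
print(mean_cosine_sim(q, 2))
print(pooled_qk(q_pool, k_pool, scale=1 / math.sqrt(2)))
```

Each query block ends up with one probability per key block. That row is the input to the top-k LUT stage.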

Status vs Upstream

Implemented:

  • per-block mean-pool, cosine similarity, pooled QK
  • top-k / cdfthreshd block selection, BlockMap LUT
  • sparse FMHA (both vsa and jenga backends)
  • per-head topk / simthreshd1 / cdfthreshd
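The top-k and cdfthreshd selection rules, and their conversion into a BlockMap and a LUT, can be sketched as follows. This is a minimal sketch under stated assumptions. The function names are hypothetical, and so is the rounding rule for topk·n. The delta encoding of the LUT is an illustrative layout, not the VSA kernel's format. Per-head topk, simthreshd1 and cdfthreshd would just mean calling the selector with a different threshold for each head's rows.

```python
import math

# Hypothetical block-selection sketch: `row` holds the pooled attention
# probabilities of one query block over all key blocks.

def select_topk(row, topk):
    """Keep the ceil(topk * n) highest-scoring key blocks (at least one)."""
    # The tiny epsilon keeps float noise from rounding topk * n up past an
    # integer; the rounding rule itself is an assumption.
    n_keep = max(1, math.ceil(topk * len(row) - 1e-9))
    order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
    return sorted(order[:n_keep])


def select_cdf(row, cdfthreshd):
    """Keep the fewest highest-scoring blocks whose mass reaches cdfthreshd."""
    order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
    kept, acc = [], 0.0
    for j in order:
        kept.append(j)
        acc += row[j]
        if acc >= cdfthreshd:
            break
    return sorted(kept)


def to_block_map(kept_per_row, n_kblocks):
    """Dense 0/1 BlockMap: one row per query block, one column per key block."""
    block_map = []
    for kept in kept_per_row:
        keep = set(kept)
        block_map.append([1 if j in keep else 0 for j in range(n_kblocks)])
    return block_map


def to_delta_lut(kept):
    """Hypothetical delta encoding: first kb index absolute, then the gaps."""
    if not kept:
        return []
    return [kept[0]] + [b - a for a, b in zip(kept, kept[1:])]


row = [0.5, 0.1, 0.3, 0.05, 0.05]
print(select_topk(row, 0.4))      # [0, 2]
print(select_cdf(row, 0.85))      # [0, 1, 2]
print(to_block_map([[0, 2]], 5))  # [[1, 0, 1, 0, 0]]
print(to_delta_lut([1, 4, 5]))    # [1, 3, 1]
```

top-k fixes the kept fraction per row. cdfthreshd adapts it: a peaked row keeps few blocks, and a flat row keeps many.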

Not yet ported (upstream pinned to commit ae5b629):

Performance

At b=2 h=32 s=16384 fp16, sparge (vsa backend) reaches 1.78× the FMHA throughput at topk=0.4 and 5.04× at topk=0.1. It stays above 1.0× (never slower than dense FMHA) across the full topk range measured.

Speedup vs sparsity

Speedup vs FMHA, b=2 h=32 s=16384 d=128 fp16. The shape was chosen to match Fig. 10 of the SpargeAttn paper (arXiv:2502.18137; Mochi-1, 22K context, head_dim=128); s=16384 is the closest grid point. Gray-outlined points have >30% spread across repetitions.
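The plotted quantities could be derived from per-repetition timings roughly as follows. This is a sketch with assumed definitions. The plot may not use a ratio of medians for speedup, and the spread metric, (max - min) / median, is a guess at what ">30% inter-rep spread" means. The timings in the example are made up, not measurements.

```python
import statistics

# Hypothetical post-processing for the speedup plot: per-rep kernel times in ms.

def speedup(fmha_ms, sparge_ms):
    """Speedup vs FMHA as a ratio of median times (higher is better)."""
    return statistics.median(fmha_ms) / statistics.median(sparge_ms)


def high_spread(times_ms, limit=0.30):
    """Assumed spread metric: (max - min) / median across repetitions."""
    med = statistics.median(times_ms)
    return (max(times_ms) - min(times_ms)) / med > limit


fmha = [10.0, 10.2, 9.8]   # made-up timings, not measurements
sparge = [2.0, 2.1, 1.9]
print(speedup(fmha, sparge))          # 5.0
print(high_spread([2.0, 2.8, 1.9]))   # True
```

Medians keep a single slow repetition from dominating the ratio. The spread flag marks points where that protection may not be enough.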

Kernel breakdown

BlockMap time (_pre) stacked on attention time (_attn), b=2 h=32 d=128 fp16 topk=0.4. BlockMap accounts for roughly 17% of the total runtime at s=16384.
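The 1.78× end-to-end figure and the 17% BlockMap share can be combined in a back-of-envelope split. If sparge's total time is T_f / 1.78 and attention is 83% of it, the attention stage alone runs 1.78 / 0.83 ≈ 2.14× faster than FMHA. The idealized bound at topk=0.4 is 1 / 0.4 = 2.5×. Two assumptions apply: that 1.78× includes BlockMap, and that attention cost scales linearly with the kept-block fraction. The function names are illustrative.

```python
# Back-of-envelope split of the end-to-end speedup, using the figures quoted
# in this README. Treating 1.78x as end-to-end (BlockMap included) and
# attention cost as linear in the kept fraction are assumptions.

def attention_only_speedup(total_speedup, blockmap_fraction):
    """FMHA time / sparse-attention time, given that BlockMap takes
    `blockmap_fraction` of the total sparge time."""
    return total_speedup / (1.0 - blockmap_fraction)


def ideal_speedup(topk):
    """Upper bound if attention cost scaled exactly with the kept-block fraction."""
    return 1.0 / topk


attn = attention_only_speedup(1.78, 0.17)
print(f"attention-only: {attn:.2f}x, ideal at topk=0.4: {ideal_speedup(0.4):.2f}x")
# attention-only: 2.14x, ideal at topk=0.4: 2.50x
```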

Usage

ninja tile_example_sparge
./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001
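A topk sweep over both backends can be scripted around the same command line. This hypothetical helper uses only the flags shown above. The topk grid is an illustrative choice, not the grid used for the plots.

```python
import shlex

# Hypothetical sweep helper around the example binary; only flags shown in
# the Usage section are used. The topk grid is an illustrative choice.
BINARY = "./bin/tile_example_sparge"


def sparge_cmd(topk, pipeline="vsa", b=2, h=32, s=16384, d=128, simthreshd1=0.001):
    """Build the argv list for one tile_example_sparge run."""
    return [BINARY, f"-pipeline={pipeline}", f"-b={b}", f"-h={h}", f"-s={s}",
            f"-d={d}", f"-topk={topk}", f"-simthreshd1={simthreshd1}"]


if __name__ == "__main__":
    for pipeline in ("vsa", "jenga"):
        for topk in (0.1, 0.2, 0.4, 0.6, 0.8):
            print(shlex.join(sparge_cmd(topk, pipeline=pipeline)))
```

To run each configuration instead of printing it, pass the list to subprocess.run(cmd, check=True) from the build directory.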

Add -v=1 for CPU validation, but use a small shape (-b=1 -h=2 -s=512): the full-shape CPU reference scales as O(s²), taking 30+ minutes at s=8k and hours at s=16k.
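The O(s²) scaling can be turned into a quick runtime estimate. This rough sketch assumes cost grows as b·h·s² at fixed d. It also assumes that the "30+ minutes at s=8k" figure was taken at the full b=2 h=32 shape with s=8192. Under those assumptions s=16k comes out near two hours, and the recommended small shape takes well under a second.

```python
# Rough CPU-reference runtime extrapolation under the O(s^2) scaling stated
# in the Usage section. Assumes cost ~ b * h * s^2 at fixed d and that the
# "30+ minutes at s=8k" figure was taken at the full b=2 h=32 shape.

def scaled_minutes(base_minutes, base_shape, new_shape):
    """Scale a measured runtime from base_shape=(b, h, s) to new_shape."""
    b0, h0, s0 = base_shape
    b1, h1, s1 = new_shape
    return base_minutes * (b1 * h1) / (b0 * h0) * (s1 / s0) ** 2


full_8k = (2, 32, 8192)
print(scaled_minutes(30, full_8k, (2, 32, 16384)))           # 120.0
print(scaled_minutes(30, full_8k, (1, 2, 512)) * 60, "seconds")
```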

References