Gino Lu b00e5449c8 sparse_attn: split KStats kernel, add README + perf charts
- Split SpargeKStatsKernel/Pipeline out of BlockMap (Kernel A produces
  per-block K stats workspace consumed by Kernel B), removing redundant
  K-stat recomputation across Q-blocks.
- Add example/ck_tile/50_sparse_attn/README.md (status vs upstream pinned
  to ae5b629, unported items, usage, references).
- Add example/ck_tile/50_sparse_attn/docs/{speedup_vs_sparsity,kernel_breakdown}.png
  + reusable plot_sparge_perf.py (b=2 h=32 s=16384 d=128 fp16 perf snapshot).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
2026-05-05 03:13:24 -04:00


# Sparge Attention (Composable Kernel)

A Composable Kernel port of SpargeAttn for AMD GPUs. Both the block-map pipeline (mean-pool → cosine similarity → pooled QK → top-k LUT) and the sparse FMHA stage run on the GPU. Two attention backends are exposed via `-pipeline=vsa` (the default, and faster) and `-pipeline=jenga` (an async K/V-load variant).
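The block-map pipeline above can be sketched host-side in NumPy. This is an illustrative model with assumed shapes and a simplified `simthreshd1` gate, not the CK kernel code:

```python
import numpy as np

def build_block_map(q, k, block_size=64, topk=0.4, simthreshd1=0.001):
    """Sketch of the block-map stage: mean-pool Q/K per block, score
    pooled QK, keep the top-k fraction of K blocks per Q block as a
    boolean LUT. The simthreshd1 gate here is a simplified stand-in:
    Q blocks whose tokens are not cosine-similar to their block mean
    fall back to dense (all K blocks kept for that row)."""
    s, d = q.shape
    nb = s // block_size
    qb = q.reshape(nb, block_size, d)
    qp = qb.mean(axis=1)                              # pooled Q, (nb, d)
    kp = k.reshape(nb, block_size, d).mean(axis=1)    # pooled K, (nb, d)
    scores = qp @ kp.T / np.sqrt(d)                   # pooled QK, (nb, nb)
    keep = max(1, int(round(topk * nb)))
    lut = np.zeros((nb, nb), dtype=bool)
    idx = np.argsort(-scores, axis=1)[:, :keep]       # top-k K blocks per row
    np.put_along_axis(lut, idx, True, axis=1)
    # cosine similarity of each token to its block mean (simthreshd1 gate)
    cos = (qb * qp[:, None, :]).sum(-1) / (
        np.linalg.norm(qb, axis=-1) * np.linalg.norm(qp, axis=-1)[:, None] + 1e-6)
    lut[cos.mean(axis=1) < simthreshd1] = True        # dense-fallback rows
    return lut
```

The sparse FMHA stage then visits only the K blocks flagged in each Q-block's LUT row.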

## Status vs Upstream

Implemented:

- per-block mean-pool, cosine similarity, pooled QK
- top-k / `cdfthreshd` block selection, BlockMap LUT
- sparse FMHA (both `vsa` and `jenga` backends)
- per-head `topk` / `simthreshd1` / `cdfthreshd`
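As a rough illustration of how `cdfthreshd` differs from a fixed `topk` fraction (assumed semantics, sketched rather than taken from the CK source): instead of keeping a fixed count of K blocks per Q block, it keeps the smallest set whose softmax mass reaches a threshold, so peaked score rows keep few blocks and flat rows keep many.

```python
import numpy as np

def cdf_select(scores, cdfthreshd=0.95):
    """Keep the fewest K blocks (by descending pooled score) whose
    cumulative softmax probability reaches cdfthreshd."""
    p = np.exp(scores - scores.max())                 # stable softmax
    p /= p.sum()
    order = np.argsort(-p)                            # blocks by mass, desc
    n_keep = int(np.searchsorted(np.cumsum(p[order]), cdfthreshd)) + 1
    mask = np.zeros(scores.shape, dtype=bool)
    mask[order[:n_keep]] = True
    return mask
```

A sharply peaked row keeps a single block, while a uniform row keeps all of them.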

Not yet ported (upstream pinned to commit ae5b629):

## Performance

At b=2 h=32 s=16384 fp16, sparge (`vsa` backend) reaches 1.78× FMHA throughput at `topk=0.4`, 5.04× at `topk=0.1`, and stays above 1.0× across the full `topk` range.

### Speedup vs sparsity

Speedup vs FMHA, b=2 h=32 s=16384 d=128 fp16. Shape chosen to match Fig. 10 of the SpargeAttn paper (arXiv:2502.18137; Mochi-1, 22K context, head_dim=128); s=16384 is the closest grid point. Gray-outlined points have >30% inter-rep spread.

### Kernel breakdown

BlockMap (_pre) stacked on attention (_attn), b=2 h=32 d=128 fp16 topk=0.4. BlockMap is roughly 17% of total at s=16384.
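The breakdown suggests a simple Amdahl-style bound. The arithmetic below is illustrative, derived only from the quoted numbers, and assumes the BlockMap cost stays roughly fixed as `topk` shrinks:

```python
# Normalize dense FMHA to 1 time unit.
fmha = 1.0
sparge_total = fmha / 1.78          # 1.78x speedup at topk=0.4, s=16384
pre = 0.17 * sparge_total           # BlockMap (_pre) share of sparge time
attn = sparge_total - pre           # sparse attention (_attn) share
# Even if the attention part shrank to zero cost, the fixed BlockMap
# stage alone would cap the end-to-end speedup near 1.78 / 0.17:
bound = fmha / pre
```

Under these assumptions the BlockMap stage caps speedup at roughly 10.5× regardless of how aggressive the sparsity gets, which is consistent with splitting the K-stats kernel out so its work is not recomputed across Q blocks.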

## Usage

    ninja tile_example_sparge
    ./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001

Add `-v=1` for CPU validation, and use a small shape (`-b=1 -h=2 -s=512`): the full-shape CPU reference scales as O(s²), taking 30+ minutes at s=8k and hours at s=16k.
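For intuition on why the full-shape check is slow, here is a minimal NumPy sketch of the kind of O(s²) computation a dense CPU reference has to perform per head (an illustration, not the example's actual validation code):

```python
import numpy as np

def ref_attention(q, k, v):
    """Naive dense attention for one head: materializes the full (s, s)
    score matrix, so time and memory grow as O(s^2) in sequence length."""
    s, d = q.shape
    scores = q @ k.T / np.sqrt(d)                 # (s, s) -- the O(s^2) cost
    scores -= scores.max(axis=1, keepdims=True)   # numerically stable softmax
    p = np.exp(scores)
    p /= p.sum(axis=1, keepdims=True)
    return p @ v                                  # (s, d)
```

At s=512 the score matrix is 512×512 per head; at s=16384 it is 1024× larger, which is where the multi-hour runtimes come from.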

## References