
Sparge Attention (Composable Kernel)

A Composable Kernel port of SpargeAttn for AMD GPUs. Both the block-map pipeline (mean-pool → cosine similarity → pooled QK → top-k LUT) and the sparse FMHA stage run on the GPU. Two attention backends are exposed: -pipeline=vsa (the default, and the faster of the two) and -pipeline=jenga (a variant with asynchronous K/V loads).
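The first three block-map stages can be sketched on the CPU for a single head. This is a hypothetical reference, not the CK-tile kernel. The helper names, the row-major [tokens x d] layout, and the definition of block self-similarity as the mean pairwise cosine similarity (self pairs included) are assumptions made for illustration.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Row-major [num_tokens x d] matrix for one head (hypothetical layout).
using Mat = std::vector<float>;

// Mean-pool each run of `block` consecutive tokens into one d-vector.
// Assumes num_tokens is a multiple of `block`.
Mat mean_pool_blocks(const Mat& x, std::size_t num_tokens, std::size_t d, std::size_t block) {
    std::size_t nb = num_tokens / block;
    Mat pooled(nb * d, 0.0f);
    for (std::size_t b = 0; b < nb; ++b)
        for (std::size_t t = 0; t < block; ++t)
            for (std::size_t c = 0; c < d; ++c)
                pooled[b * d + c] += x[(b * block + t) * d + c] / static_cast<float>(block);
    return pooled;
}

// Assumed self-similarity score: mean pairwise cosine similarity in a block.
// Uses sum_{i,j} cos(x_i, x_j) = ||sum_i x_i / ||x_i|| ||^2, divided by block^2.
std::vector<float> block_self_similarity(const Mat& x, std::size_t num_tokens,
                                         std::size_t d, std::size_t block) {
    std::size_t nb = num_tokens / block;
    std::vector<float> sim(nb, 0.0f);
    std::vector<float> acc(d);
    for (std::size_t b = 0; b < nb; ++b) {
        std::fill(acc.begin(), acc.end(), 0.0f);
        for (std::size_t t = 0; t < block; ++t) {
            const float* row = &x[(b * block + t) * d];
            float norm = 0.0f;
            for (std::size_t c = 0; c < d; ++c) norm += row[c] * row[c];
            norm = std::sqrt(norm) + 1e-12f;  // guard against all-zero tokens
            for (std::size_t c = 0; c < d; ++c) acc[c] += row[c] / norm;
        }
        float s = 0.0f;
        for (float v : acc) s += v * v;
        sim[b] = s / static_cast<float>(block * block);
    }
    return sim;
}

// Pooled QK score between pooled Q block i and pooled K block j, scaled by 1/sqrt(d).
float pooled_qk(const Mat& qp, const Mat& kp, std::size_t i, std::size_t j, std::size_t d) {
    float dot = 0.0f;
    for (std::size_t c = 0; c < d; ++c) dot += qp[i * d + c] * kp[j * d + c];
    return dot / std::sqrt(static_cast<float>(d));
}
```

A score near 1 means the tokens in a block point the same way, so the mean-pooled vector is a faithful summary. Upstream compares a score of this kind against simthreshd1 to flag blocks whose pooled estimate should not be trusted.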

Status vs Upstream

Implemented:

  • per-block mean-pool, cosine similarity, pooled QK
  • top-k / cdfthreshd block selection, BlockMap LUT
  • sparse FMHA (both vsa and jenga backends)
  • per-head topk / simthreshd1 / cdfthreshd
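For one Q-block row, the selection step can be sketched as follows: softmax the pooled scores, rank the K blocks, then keep either a topk fraction of them or the shortest prefix whose probability mass reaches cdfthreshd. The function names, the ceil rounding with a minimum of one block, and encoding the LUT as ascending K-block indices are assumptions; the CK implementation may round, break ties, or lay out the LUT differently.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Numerically stable softmax over one row of pooled QK scores.
std::vector<float> softmax_row(const std::vector<float>& s) {
    float m = *std::max_element(s.begin(), s.end());
    std::vector<float> p(s.size());
    float sum = 0.0f;
    for (std::size_t j = 0; j < s.size(); ++j) { p[j] = std::exp(s[j] - m); sum += p[j]; }
    for (float& v : p) v /= sum;
    return p;
}

// K-block indices sorted by descending probability (stable for ties).
std::vector<std::size_t> order_desc(const std::vector<float>& p) {
    std::vector<std::size_t> idx(p.size());
    std::iota(idx.begin(), idx.end(), std::size_t{0});
    std::stable_sort(idx.begin(), idx.end(),
                     [&](std::size_t a, std::size_t b) { return p[a] > p[b]; });
    return idx;
}

// topk mode: keep ceil(topk * num_kblocks) blocks, at least one.
std::vector<std::size_t> select_topk(const std::vector<float>& scores, float topk) {
    auto idx = order_desc(softmax_row(scores));
    auto k = static_cast<std::size_t>(std::ceil(topk * static_cast<float>(idx.size())));
    if (k < 1) k = 1;
    if (k > idx.size()) k = idx.size();
    std::vector<std::size_t> lut(idx.begin(), idx.begin() + static_cast<std::ptrdiff_t>(k));
    std::sort(lut.begin(), lut.end());  // LUT lists selected K blocks in ascending order
    return lut;
}

// cdfthreshd mode: keep the shortest ranked prefix whose cumulative probability reaches the threshold.
std::vector<std::size_t> select_cdf(const std::vector<float>& scores, float cdfthreshd) {
    auto p = softmax_row(scores);
    std::vector<std::size_t> lut;
    float cum = 0.0f;
    for (std::size_t j : order_desc(p)) {
        lut.push_back(j);
        cum += p[j];
        if (cum >= cdfthreshd) break;
    }
    std::sort(lut.begin(), lut.end());
    return lut;
}
```

Per-head parameters then amount to calling the selector with each head's own topk or cdfthreshd value.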

Not yet ported (upstream pinned to commit ae5b629):

PV-skip modes

The pv_threshold per-Q-tile skip in the attention kernel is implemented in three variants, selectable at runtime with -pv_mode={none|warp|block}:

  • none: skip disabled; the baseline, matching the codegen instance built without PV skip.
  • warp (per-wavefront): each wavefront votes locally with a __shfl_xor butterfly AND, and the resulting flag is SGPR-resident. This variant is CK-tile-specific and not in upstream.
  • block (per-block): block-wide consensus vote broadcast through LDS; aligned with upstream sm80 (qk_int_sv_f16_cuda_sm80.cuh:L334).

In all three modes the V loads stay unconditional: the guard wraps only the PV MMA, matching upstream and Algorithm 1 of the paper.
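The difference between the warp and block votes is easiest to see in a CPU simulation of the dataflow. This sketch is hypothetical and models only the reduction shape, not the CK-tile code: a lane predicate of true means "this lane's rows may skip PV", each shuffle round reads the partner lane's value from the previous round, and a shared array plus an implied barrier stand in for LDS.

```cpp
#include <array>
#include <cstddef>
#include <vector>

constexpr std::size_t kWaveSize = 64;  // wave64, as on CDNA GPUs such as MI300X
using Wave = std::array<bool, kWaveSize>;

// XOR-butterfly AND: after log2(64) = 6 rounds every lane holds the AND of all
// 64 lane predicates, i.e. a wavefront-uniform value that can live in an SGPR.
Wave butterfly_and(Wave lanes) {
    for (std::size_t offset = kWaveSize / 2; offset > 0; offset /= 2) {
        Wave prev = lanes;  // a shuffle reads the partner's value from the previous round
        for (std::size_t lane = 0; lane < kWaveSize; ++lane)
            lanes[lane] = prev[lane] && prev[lane ^ offset];
    }
    return lanes;
}

// warp mode: each wavefront decides for itself.
bool warp_vote(const Wave& lanes) { return butterfly_and(lanes)[0]; }

// block mode: each wavefront publishes its vote to a shared slot; after a
// barrier every wavefront ANDs all slots, so the whole block skips or none does.
bool block_vote(const std::vector<Wave>& waves) {
    std::vector<bool> lds(waves.size());
    for (std::size_t w = 0; w < waves.size(); ++w) lds[w] = warp_vote(waves[w]);
    // On the GPU a block-wide barrier would sit here.
    bool all = true;
    for (bool v : lds) all = all && v;
    return all;
}
```

The block vote needs the extra shared-memory round trip and barrier, and one non-skipping lane anywhere in the block vetoes the skip for every wavefront; the warp vote only requires agreement within a wavefront.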

PV-skip mode comparison

MI300X, b=2 h=16 s=8192 d=128 fp16, 5 seeds × 9 sparsity points. All three modes dispatch to the kM0=64 padK=0 tile bucket at this shape.

On this canonical recipe shape, performance ranks none > warp > block at every measured sparsity, with no crossover. On this tile configuration the per-block guard adds 33 to 35 VGPRs (with 6 to 9 spills), which lowers occupancy; warp adds 0 to 4 VGPRs. The default is -pv_mode=warp; switch to none for the no-skip baseline, or to block to exercise the upstream-aligned variant. A shape sweep is needed before block can be recommended as the default: the kM0=128 path shows Δ ≈ 0 VGPRs for block mode and is a candidate.
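A rough occupancy calculator shows why a few dozen extra VGPRs matter. The constants are our assumption about CDNA3 (a 512-entry VGPR budget per SIMD lane shared by resident waves, allocation in granules of 8, at most 8 waves per SIMD) and should be checked against AMD's documentation for the target. The 128-VGPR baseline in the check below is hypothetical, since the measured counts are not listed here, and the model ignores spilling.

```cpp
// Assumed CDNA3-style occupancy model (verify against the ISA and tuning docs).
constexpr unsigned kVgprBudget = 512;     // VGPRs per SIMD lane shared by resident waves
constexpr unsigned kGranule = 8;          // allocation granularity
constexpr unsigned kMaxWavesPerSimd = 8;  // hardware cap on resident waves

// Resident waves per SIMD for a kernel using `vgprs_per_wave` VGPRs.
constexpr unsigned waves_per_simd(unsigned vgprs_per_wave) {
    unsigned alloc = (vgprs_per_wave + kGranule - 1) / kGranule * kGranule;  // round up
    if (alloc == 0) return kMaxWavesPerSimd;
    unsigned waves = kVgprBudget / alloc;
    return waves < kMaxWavesPerSimd ? waves : kMaxWavesPerSimd;
}
```

Under this model a hypothetical 128-VGPR kernel fits 4 waves per SIMD, and adding 34 VGPRs drops it to 3. Small increments matter only when they push the rounded allocation past a threshold such as 128 (4 waves) or 256 (2 waves).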

Usage

ninja tile_example_sparge
./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001

Select a PV-skip variant with -pv_mode={none|warp|block} (default warp). The per-Q-tile skip predicate can only fire with a finite threshold, e.g. -pv_threshold=20.
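One way to read the threshold, sketched as a per-row test that would feed the votes. This is an assumed form for illustration, not code taken from CK or upstream: a row may skip PV for the current K tile when the tile's row max trails the running row max by more than pv_threshold, because every unnormalized probability in that row is then below exp(-pv_threshold). The function name is hypothetical, and treating the no-skip default as an infinite threshold is our reading of "finite" above.

```cpp
#include <cmath>
#include <limits>

// Hypothetical per-row PV-skip test. Under online softmax, P = exp(s - m_running)
// with s <= tile_max, so P < exp(-pv_threshold) whenever this returns true.
bool row_can_skip(float running_max, float tile_max, float pv_threshold) {
    return running_max - tile_max > pv_threshold;
}

// Assumed no-skip setting: an infinite threshold makes the test false for finite maxima.
constexpr float kNoSkipThreshold = std::numeric_limits<float>::infinity();
```

With -pv_threshold=20 the skipped contributions are each below exp(-20), roughly 2.1e-9 relative to the row's running maximum term.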

Add -v=1 for CPU validation, but use a small shape (-b=1 -h=2 -s=512): the CPU reference scales as O(s²), so at full shape it takes 30+ minutes at s=8k and hours at s=16k.
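The O(s²) cost can be made concrete by counting multiply-adds in a naive dense reference: QKᵀ and PV each take s·s·d per (batch, head). The helper name is hypothetical, and the count ignores softmax and memory effects.

```cpp
#include <cstdint>

// Approximate multiply-add count of a naive dense attention reference:
// QK^T and PV each cost s*s*d MACs per (batch, head).
constexpr std::uint64_t reference_macs(std::uint64_t b, std::uint64_t h,
                                       std::uint64_t s, std::uint64_t d) {
    return 2 * b * h * s * s * d;
}
```

The suggested validation shape (-b=1 -h=2 -s=512) is 32 × 32² = 32768 times cheaper than the usage example (-b=2 -h=32 -s=16384), and doubling s from 8k to 16k at fixed b, h, d quadruples the work, which is consistent with 30+ minutes growing to hours.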

References