mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Generalize the unified attention pipeline to support NumWarpGroups=1
(single warp group) with a serial K/V loop, in addition to the existing
NumWarpGroups=2 interleaved pipeline.

New decode kernel traits use 4 warps (sequence<4,1,1>) with kBlockM=128
and kBlockQ=16 for GQA-8, reducing Q tile padding waste from 31/32 to
15/16 for decode workloads (max_seqlen_q=1).

Host-side dispatch (is_decode_shape) routes low-token workloads to the
decode kernel automatically.

Benchmark results on d64 GQA-8 (via aiter):
- 64-seq decode: 2.2x slower -> 1.27x slower (1.73x speedup)
- 512-seq decode: 3.5x slower -> 1.6x slower (2.2x speedup)
- 1-seq decode: 0.83x (CK wins) -> 0.81x (no regression)
- Prefill: unchanged (uses original 8-warp kernel)

Made-with: Cursor
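
The padding figures quoted above follow from simple tile arithmetic, which the sketch below spells out. It is an illustration only, not the kernel's code: `TileShape`, `block_q`, `padded_slots` and `is_decode_shape_sketch` are hypothetical names, the relation kBlockQ = kBlockM / GQA group size is inferred from kBlockM=128, kBlockQ=16 and GQA-8, the prefill kBlockM of 256 is inferred from the 31/32 figure, and the dispatch condition is assumed to be max_seqlen_q == 1 (the real is_decode_shape may use a different rule).

```cpp
// Hypothetical helper types; only kBlockM=128, kBlockQ=16 and GQA-8 come from
// the commit message itself.
struct TileShape {
    int block_m;   // kBlockM: Q-tile rows, counting every query head in a GQA group
    int gqa_group; // query heads that share one K/V head (8 for GQA-8)
};

// Assumed relation: Q-token slots per tile = kBlockM / GQA group size.
constexpr int block_q(TileShape t) { return t.block_m / t.gqa_group; }

// Token slots in one Q tile that hold padding when only `tokens` are real.
constexpr int padded_slots(TileShape t, int tokens) { return block_q(t) - tokens; }

// Assumption: kBlockM=256 for the original kernel, implied by the 31/32 figure.
constexpr TileShape kPrefillTile{256, 8};
constexpr TileShape kDecodeTile{128, 8};

// Decode has max_seqlen_q=1, so exactly one slot per tile carries a real token.
static_assert(block_q(kPrefillTile) == 32, "original kernel: 32 Q slots per tile");
static_assert(padded_slots(kPrefillTile, 1) == 31, "31 of 32 slots are padding");
static_assert(block_q(kDecodeTile) == 16, "decode kernel: 16 Q slots per tile");
static_assert(padded_slots(kDecodeTile, 1) == 15, "15 of 16 slots are padding");

// Hypothetical stand-in for the host-side predicate; the condition is an
// assumption, not the actual is_decode_shape logic.
constexpr bool is_decode_shape_sketch(int max_seqlen_q) { return max_seqlen_q == 1; }

static_assert(is_decode_shape_sketch(1), "single-token queries go to the decode kernel");
static_assert(!is_decode_shape_sketch(512), "long queries stay on the 8-warp kernel");
```

All checks are compile-time `static_assert`s, so compiling the file is enough to confirm the arithmetic. The waste ratio only drops from 31/32 to 15/16, but the decode tile is half the size, so each workgroup computes half as many padded rows.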