mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
kBlockQ (= kBlockM / num_queries_per_kv) was constexpr in
`UnifiedAttentionShape` / the kernel traits, forcing one kernel
instance per (kBlockM, num_qpkv) pair even though the matmul tile is
fully determined by kBlockM and kHeadDim. An audit confirmed that
kBlockQ only feeds:
* arithmetic in `unified_attention_kernel.hpp` (loop bounds, Q-tile
indexing, query_len padding),
* `pad_tensor_view` size tuples for Q/O/LSE DRAM views,
* one `mask.IsEdgeTile(... number<kBlockQ>{} ...)` call inside the
pipeline's per-K-tile mask check.
None of these structurally needs a compile-time value:
* `pad_tensor_view` already accepts mixed runtime/compile-time tuple
elements (e.g. it's passed plain `1` next to `kHeadDimPadded`).
* `IsEdgeTile` only does runtime arithmetic on the tile size; adding a
runtime overload that accepts `index_t` is trivial (the compile-time
one now forwards to it).
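The overload shape can be sketched as below. This is a minimal stand-alone sketch, not the real `block_masking.hpp` signature: the parameter list, the `number<>` stand-in, and the placeholder edge check are all simplified assumptions; only the forwarding pattern mirrors the change.

```cpp
#include <cassert>
#include <cstdint>

using index_t = std::int32_t;

// Simplified stand-in for ck_tile's compile-time integral wrapper.
template <index_t v>
struct number { static constexpr index_t value = v; };

// Runtime overload: the tile size only participates in runtime
// arithmetic, so a plain index_t suffices. (Placeholder edge check,
// not the real masking logic.)
bool IsEdgeTile(index_t i_tile_h, index_t i_tile_w,
                index_t tile_h, index_t tile_w,
                index_t seqlen_q, index_t seqlen_k)
{
    const index_t row_end = (i_tile_h + 1) * tile_h;
    const index_t col_end = (i_tile_w + 1) * tile_w;
    return row_end > seqlen_q || col_end > seqlen_k;
}

// The existing compile-time overload just forwards.
template <index_t TileH, index_t TileW>
bool IsEdgeTile(index_t i_tile_h, index_t i_tile_w,
                number<TileH>, number<TileW>,
                index_t seqlen_q, index_t seqlen_k)
{
    return IsEdgeTile(i_tile_h, i_tile_w, TileH, TileW, seqlen_q, seqlen_k);
}
```

Because the `number<>` overload forwards, both paths stay behaviorally identical by construction.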
Wiring:
* `block_masking.hpp` -- add an `IsEdgeTile(..., index_t tile_h,
index_t tile_w)` overload; the existing `number<>` overload just
forwards to it.
* `unified_attention_pipeline.hpp` -- new optional
`num_queries_per_kv` arg on the pipeline's `operator()` (default 0
keeps existing call sites unchanged). Computes
`kBlockQ_dyn = (num_qpkv > 0) ? (kBlockM / num_qpkv) : kBlockQ`
once at the top, uses it in the IsEdgeTile call.
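A reduced sketch of the default-argument scheme, with made-up trait values (`kBlockM = 128`, static `kBlockQ = 128`) and the real argument list elided; returning `kBlockQ_dyn` here is only for illustration:

```cpp
#include <cstdint>

using index_t = std::int32_t;

// Hypothetical reduced pipeline; real operator() takes tile windows etc.
struct Pipeline
{
    static constexpr index_t kBlockM = 128; // assumed trait value
    static constexpr index_t kBlockQ = 128; // static value of this instance

    // num_queries_per_kv == 0 (the default) keeps old call sites unchanged.
    index_t operator()(index_t num_queries_per_kv = 0) const
    {
        const index_t kBlockQ_dyn =
            (num_queries_per_kv > 0) ? (kBlockM / num_queries_per_kv)
                                     : kBlockQ;
        // ... kBlockQ_dyn then feeds the per-K-tile IsEdgeTile check ...
        return kBlockQ_dyn;
    }
};
```

Existing callers pass nothing and get the static `kBlockQ`; the kernel passes the runtime value and gets the divided tile.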
* `unified_attention_kernel.hpp` -- compute
`const index_t kBlockQ_dyn = kBlockM / kargs.num_queries_per_kv`
once and replace every per-call `kBlockQ` use with `kBlockQ_dyn`.
Pass `kargs.num_queries_per_kv` through to the pipeline. The
debug-only `assert(kBlockQ_dyn == kBlockQ)` keeps the static and
dynamic values in lock-step until we actually collapse variants.
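The kernel-side shape, again as a stand-alone sketch with assumed trait values (`kBlockM = 128`, static `kBlockQ = 16` for a num_qpkv = 8 instance) and a trimmed `Kargs`:

```cpp
#include <cassert>
#include <cstdint>

using index_t = std::int32_t;

constexpr index_t kBlockM = 128; // assumed tile size
constexpr index_t kBlockQ = 16;  // static value baked into this instance

// Trimmed-down kernel arguments for illustration.
struct Kargs { index_t num_queries_per_kv; };

index_t run_kernel(const Kargs& kargs)
{
    // Computed once; replaces every per-call use of kBlockQ.
    const index_t kBlockQ_dyn = kBlockM / kargs.num_queries_per_kv;
    // Lock-step check while static variants still exist.
    assert(kBlockQ_dyn == kBlockQ);
    // ... loop bounds, Q-tile indexing, query_len padding all use
    // kBlockQ_dyn; kargs.num_queries_per_kv is forwarded to the pipeline ...
    return kBlockQ_dyn;
}
```

Once variants collapse, the assert (and the static `kBlockQ`) can go away and only the runtime value remains.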
Perf A/B (b=4..256, sk=120000, MI300):
d=128 MHA (num_qpkv = 1, runtime div is trivial):
BW within +/-0.2% across all batch sizes (noise).
d=64 GQA-8 (num_qpkv = 8, runtime division actually happens):
speedups 1.28x..2.14x vs Triton -- identical to baseline.
Correctness suite stays at 241/245 (same 4 pre-existing int32-overflow
failures in the d=128 prefill rebased-pointer path).
This is a no-op on perf and unlocks a follow-up where we collapse the
two num_qpkv values per (head_dim, kBlockM) -- e.g. the future d=128
GQA-8 variant can reuse the existing decode_d128_mha_* instances by
just passing a different runtime num_queries_per_kv.
Co-authored-by: Cursor <cursoragent@cursor.com>