juuso-oskari 614afea7eb CK-UA: derive kBlockQ at runtime, decouple from variant template
kBlockQ (= kBlockM / num_queries_per_kv) was constexpr in
`UnifiedAttentionShape` / the kernel-traits, forcing one kernel
instance per (kBlockM, num_qpkv) pair even though the matmul tile is
fully determined by kBlockM and kHeadDim. An audit confirmed that
kBlockQ only feeds:

  * arithmetic in `unified_attention_kernel.hpp` (loop bounds, Q-tile
    indexing, query_len padding),
  * `pad_tensor_view` size tuples for Q/O/LSE DRAM views,
  * one `mask.IsEdgeTile(... number<kBlockQ>{} ...)` call inside the
    pipeline's per-K-tile mask check.

None of these structurally need a compile-time value:

* `pad_tensor_view` already accepts mixed runtime/compile-time tuple
  elements (e.g. it's passed plain `1` next to `kHeadDimPadded`).
* `IsEdgeTile` only does runtime arithmetic on the tile size; adding a
  runtime overload that accepts `index_t` is trivial (the compile-time
  one now forwards to it).
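
The runtime/compile-time mixing the first bullet relies on can be sketched with a tiny adaptor. `number<>` here mirrors ck_tile's integral wrapper, but this is an illustrative sketch, not the real `pad_tensor_view` API:

```cpp
#include <cstdint>

using index_t = int32_t;

// Compile-time integral wrapper in the spirit of ck_tile::number<> (assumption).
template <index_t N>
struct number
{
    static constexpr index_t value = N;
};

// A size parameter can arrive either as a runtime index_t or as a
// compile-time number<>; both collapse to the same runtime value, which is
// why passing a runtime kBlockQ into the size tuples is structurally fine.
inline index_t to_index(index_t v) { return v; }

template <index_t N>
inline index_t to_index(number<N>) { return N; }
```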

Wiring:
  * `block_masking.hpp` -- add an `IsEdgeTile(..., index_t tile_h,
    index_t tile_w)` overload; the existing `number<>` overload just
    forwards to it.
  * `unified_attention_pipeline.hpp` -- new optional
    `num_queries_per_kv` arg on the pipeline's `operator()` (default 0
    keeps existing call sites unchanged). Computes
    `kBlockQ_dyn = (num_qpkv > 0) ? (kBlockM / num_qpkv) : kBlockQ`
    once at the top, uses it in the IsEdgeTile call.
  * `unified_attention_kernel.hpp` -- compute
    `const index_t kBlockQ_dyn = kBlockM / kargs.num_queries_per_kv`
    once and replace every per-call `kBlockQ` use with `kBlockQ_dyn`.
    Pass `kargs.num_queries_per_kv` through to the pipeline. A
    debug-only `assert(kBlockQ_dyn == kBlockQ)` keeps the static and
    dynamic values in lock-step until we actually collapse variants.
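
The `block_masking.hpp` change reduces to a plain overload pair. The mask body and extents below are made up; only the forwarding structure is taken from the commit:

```cpp
#include <cstdint>

using index_t = int32_t;

// Compile-time integral wrapper in the spirit of ck_tile::number<> (assumption).
template <index_t N>
struct number
{
    static constexpr index_t value = N;
};

// Illustrative mask with made-up extents; only the overload structure matters.
struct GenericAttentionMask
{
    index_t y_total;
    index_t x_total;

    // New runtime overload: plain arithmetic on the tile extents.
    bool IsEdgeTile(index_t i_y, index_t i_x, index_t tile_h, index_t tile_w) const
    {
        // Placeholder edge test: does the tile touch the padded boundary?
        return (i_y + tile_h > y_total) || (i_x + tile_w > x_total);
    }

    // Existing compile-time overload now just forwards to the runtime one.
    template <index_t TileH, index_t TileW>
    bool IsEdgeTile(index_t i_y, index_t i_x, number<TileH>, number<TileW>) const
    {
        return IsEdgeTile(i_y, i_x, TileH, TileW);
    }
};
```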
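
The pipeline-side wiring is a defaulted argument plus one ternary. Everything except the new argument and the `kBlockQ_dyn` selection is invented here for the sketch:

```cpp
#include <cstdint>

using index_t = int32_t;

// Illustrative pipeline fragment; only the new default argument and the
// kBlockQ_dyn selection come from the commit text, the rest is made up.
template <index_t kBlockM_, index_t kBlockQ_>
struct UnifiedAttentionPipeline
{
    static constexpr index_t kBlockM = kBlockM_;
    static constexpr index_t kBlockQ = kBlockQ_; // static fallback

    // Default 0 keeps existing call sites unchanged; returns the tile height
    // that the per-K-tile IsEdgeTile check would use.
    index_t operator()(index_t num_queries_per_kv = 0) const
    {
        const index_t kBlockQ_dyn =
            (num_queries_per_kv > 0) ? (kBlockM / num_queries_per_kv) : kBlockQ;
        // ... main loop would call mask.IsEdgeTile(..., kBlockQ_dyn, ...) ...
        return kBlockQ_dyn;
    }
};
```

Note that an MHA-shaped instance (`kBlockQ == kBlockM`) answers correctly for a GQA call simply by passing a nonzero `num_queries_per_kv`, which is exactly the variant-collapsing follow-up described below.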
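
On the kernel side, the computation and the lock-step check can be condensed into a small helper. `Kargs` and `compute_block_q` are hypothetical names for this sketch:

```cpp
#include <cassert>
#include <cstdint>

using index_t = int32_t;

// Hypothetical slice of the kernel arguments; only num_queries_per_kv matters here.
struct Kargs
{
    index_t num_queries_per_kv;
};

// Mirrors the kernel-side computation: derive kBlockQ once from the runtime
// kargs, and keep it in lock-step with the static value via a debug assert.
template <index_t kBlockM, index_t kBlockQStatic>
index_t compute_block_q(const Kargs& kargs)
{
    const index_t kBlockQ_dyn = kBlockM / kargs.num_queries_per_kv;
    assert(kBlockQ_dyn == kBlockQStatic); // holds until variants are collapsed
    return kBlockQ_dyn;
}
```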

Perf A/B (b=4..256, sk=120000, MI300):

  d=128 MHA (num_qpkv = 1, runtime div is trivial):
    BW within +/-0.2% across all batch sizes (noise).

  d=64 GQA-8 (num_qpkv = 8, runtime division actually happens):
    speedups 1.28x..2.14x vs Triton -- identical to baseline.

Correctness suite stays at 241/245 (same 4 pre-existing int32-overflow
failures in the d=128 prefill rebased-pointer path).

This is a no-op on perf and unlocks a follow-up where we collapse the
two num_qpkv values per (head_dim, kBlockM) -- e.g. the future d=128
GQA-8 variant can reuse the existing decode_d128_mha_* instances by
just passing a different runtime num_queries_per_kv.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 12:01:59 +00:00