Promote the runtime `page_size` argument to a non-type template parameter
`kPageSize_` on UnifiedAttentionPipeline. Thread it through
unified_attention_kernel_traits and dispatch_variant<V> so the host-side
dispatcher routes on args.page_blk_size ∈ {16, 32, 64} to a constexpr-
pinned prefill instance; values outside that menu (or any decode variant)
fall back to the existing kPageSize_=0 runtime-page-size instance.
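The routing described above can be sketched in plain C++17 (no HIP) as a switch from the runtime page size to a template instance. `PipelineSketch`, `launch`, and `dispatch_sketch` are hypothetical names and are not the repo's `UnifiedAttentionPipeline` / `dispatch_variant<V>` API. Instead of launching a kernel, `launch` returns the `kPageSize_` of the instance it was called on, so the routing can be checked.

```cpp
#include <cassert>

// Hypothetical stand-in for one UnifiedAttentionPipeline instance.
// kPageSize_ == 0 selects the runtime-page-size fallback.
template <int kPageSize_>
struct PipelineSketch
{
    static constexpr int kPageSize = kPageSize_;

    static int launch(int runtime_page_size)
    {
        // A pinned instance must only be reached with a matching runtime value.
        assert(kPageSize_ == 0 || kPageSize_ == runtime_page_size);
        (void)runtime_page_size; // silences unused-parameter warnings under NDEBUG
        return kPageSize_;       // report which instance ran (a real kernel would launch here)
    }
};

// Hypothetical host-side router: page_blk_size in {16, 32, 64} picks a
// constexpr-pinned prefill instance; any other value, or any decode
// variant, falls back to the kPageSize_ = 0 instance.
inline int dispatch_sketch(int page_blk_size, bool is_decode)
{
    if(!is_decode)
    {
        switch(page_blk_size)
        {
        case 16: return PipelineSketch<16>::launch(page_blk_size);
        case 32: return PipelineSketch<32>::launch(page_blk_size);
        case 64: return PipelineSketch<64>::launch(page_blk_size);
        default: break;
        }
    }
    return PipelineSketch<0>::launch(page_blk_size);
}
```

Each `case` forces one extra template instantiation. That is why the menu is kept small and why decode is not pinned.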
Two wins fold together on the prefill tiers:
1. Strength reduction. Every `/ page_size`, `* page_size`, and `% page_size`
in the per-tile address chain collapses to a literal-folded shift /
multiply-by-magic (`/ 32` → shr 5, etc.).
2. Wider Tier-0/Tier-2 gate. The scalar-promote + LDS-cache fast path now
uses the *real* precondition `KY0_step_N <= kPageSize` at compile time
instead of the conservative `KY0_step_N <= 16` hedge — so prefill_d128
bf16/fp16 (KY0_step_N=32), prefill_d64 fp8 (KY0_step_N=32), and
prefill_d64 bf16/fp16 (KY0_step_N=64) also enter the fast path at
their natural page sizes.
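For point 1 (strength reduction), a minimal C++17 sketch is shown below. It assumes the address chain splits a token index into a page index and an in-page offset, then looks the page up in a page table. `page_index`, `page_offset`, and `physical_row` are hypothetical helpers, not functions from the repo. When `kPageSize_ > 0` the divisor is a compile-time literal, so the compiler is free to strength-reduce the operation. The `kPageSize_ == 0` path keeps the runtime divide and modulo.

```cpp
#include <cassert>

// Page that holds a token. With a literal power-of-two divisor
// (e.g. 32), an optimizing compiler can emit a right shift for unsigned '/'.
template <int kPageSize_>
constexpr unsigned page_index(unsigned token, unsigned runtime_page_size)
{
    if constexpr(kPageSize_ > 0)
        return token / static_cast<unsigned>(kPageSize_);
    else
        return token / runtime_page_size; // generic runtime divide
}

// Offset of a token inside its page. A literal power-of-two modulus
// can become a bit mask for unsigned '%'.
template <int kPageSize_>
constexpr unsigned page_offset(unsigned token, unsigned runtime_page_size)
{
    if constexpr(kPageSize_ > 0)
        return token % static_cast<unsigned>(kPageSize_);
    else
        return token % runtime_page_size; // generic runtime modulo
}

// Physical row of a token: page_table[page] * page_size + offset.
// The multiply also sees a literal when kPageSize_ > 0.
template <int kPageSize_>
inline unsigned physical_row(const unsigned* page_table, unsigned token, unsigned runtime_page_size)
{
    const unsigned ps = kPageSize_ > 0 ? static_cast<unsigned>(kPageSize_) : runtime_page_size;
    assert(ps != 0 && "page size must be non-zero");
    return page_table[page_index<kPageSize_>(token, runtime_page_size)] * ps +
           page_offset<kPageSize_>(token, runtime_page_size);
}
```

Both instantiations compute the same address. Only the instruction sequence the compiler is allowed to emit differs.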
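For point 2 (the wider gate), the sketch below models only the page-size part of the Tier-0/Tier-2 condition. Other conditions, such as the warp-count requirement mentioned for decode, are deliberately left out. `fast_path_enabled` is a hypothetical name. A pinned instance tests `KY0_step_N <= kPageSize`. The `kPageSize_ == 0` fallback keeps the old conservative `KY0_step_N <= 16` hedge because its page size is only known at run time.

```cpp
// Hypothetical compile-time gate for the scalar-promote + LDS-cache fast path
// (page-size condition only).
template <int kPageSize_, int KY0_step_N>
constexpr bool fast_path_enabled()
{
    if constexpr(kPageSize_ > 0)
        return KY0_step_N <= kPageSize_; // real precondition, known at compile time
    else
        return KY0_step_N <= 16;         // conservative hedge for runtime page size
}

// Cases from this change: (KY0_step_N, natural page size).
static_assert(fast_path_enabled<32, 32>(), "prefill_d128 bf16/fp16 and prefill_d64 fp8 at ps32");
static_assert(fast_path_enabled<64, 64>(), "prefill_d64 bf16/fp16 at ps64");
static_assert(!fast_path_enabled<0, 32>(), "the runtime fallback still gates KY0_step_N=32 out");
```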
Measured impact (sq=sk=75600, MI355, n=30 iters, GQA-8):
  variant            KY0_step_N  ps  before  after        Δ
  prefill_d128 fp8           16  32   119.0  111.5   -6.3 %
  prefill_d128 bf16          32  32   132.7  130.3   -1.8 %
  prefill_d64 fp8            32  32    80.9   68.1  -15.8 %
  prefill_d64 bf16           64  64    74.4   73.4   -1.3 %
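The Δ column is the relative change (after − before) / before × 100, rounded to one decimal. The commit does not state the metric's unit. Reading negative values as "lower is better" is an assumption based on the change being described as a win. `delta_percent` is a hypothetical helper.

```cpp
#include <cassert>
#include <cmath>

// Relative change in percent: (after - before) / before * 100.
inline double delta_percent(double before, double after)
{
    assert(before != 0.0 && "baseline must be non-zero");
    return (after - before) / before * 100.0;
}
```

For example, prefill_d64 fp8 gives (68.1 − 80.9) / 80.9 × 100 ≈ −15.82, which rounds to −15.8 %.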
Decode variants stay on the kPageSize_=0 instances: the Tier-0 gate
already excludes them (they run with fewer than 8 warps), and the extra
binary size isn't justified. All sweep_fp8.sh shapes plus 21 multi-seed,
multi-sk-length prefill shapes pass correctness checks. The pre-existing
Tier-2 LDS-cache limit (4096 entries) is documented in the pipeline
header; the same constraint applies to the kPageSize_=0 fallback, so
this is not a regression.
36 new prefill instance files: prefill_d{64,128} × {fp16, bf16, fp8} ×
{mask, nmask} × {ps16, ps32, ps64}.
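The instance count is the product of the four axes: 2 head dims × 3 dtypes × 2 mask modes × 3 page sizes = 36. The sketch below enumerates that grid. The `prefill_d64_fp16_mask_ps16` naming format and `prefill_instance_names` are hypothetical illustrations, not the repo's actual file names.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Enumerates prefill_d{64,128} x {fp16, bf16, fp8} x {mask, nmask} x {ps16, ps32, ps64}.
inline std::vector<std::string> prefill_instance_names()
{
    const int head_dims[]   = {64, 128};
    const char* dtypes[]    = {"fp16", "bf16", "fp8"};
    const char* masks[]     = {"mask", "nmask"};
    const int page_sizes[]  = {16, 32, 64};

    std::vector<std::string> names;
    for(int d : head_dims)
        for(const char* dt : dtypes)
            for(const char* m : masks)
                for(int ps : page_sizes)
                    names.push_back("prefill_d" + std::to_string(d) + "_" + dt + "_" + m +
                                    "_ps" + std::to_string(ps));
    assert(names.size() == 36u); // 2 * 3 * 2 * 3
    return names;
}
```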
Co-authored-by: Cursor <cursoragent@cursor.com>