GetAlignmentK / GetAlignmentV previously returned a blanket 4 B/lane
(one dword) for every FP8/BF8 tile, citing the gfx950 LDS-direct load
constraint (only dword / dwordx3 / dwordx4 are supported). That cap is
correct for the 8-warp prefill variants (kBlockSize=512, where NumIssues
drops to 0.5 at 16 B/lane), but it was also applied to every decode
tier, where the 1/2/4-warp tile geometry leaves NumIssues a whole
number >= 1 at 16 B/lane.

Refactor the alignment selector into GetKVAlignmentBytes<>, which picks
dwordx4 whenever NumIssues = kPageBlockSize*kHeadDim/(kBlockSize*16)
is an integer >= 1 and falls back to dword otherwise. BF16/FP16 paths
stay at 16 B/lane on every compiled tile, so existing perf is unchanged.
FP8 prefill_d{64,128} also keep the historical dword path because
NumIssues = 0.5 there. FP8 decode_d{64,128}_m{16,32,64,128} now use
dwordx4: same byte volume per K/V tile but 4x fewer async-load issues
(SQ_INSTS_VMEM 131M -> 33M on b=128 sq=1 sk=128000 d=64).
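
The selection rule above can be sketched as follows. This is a minimal
illustration, not the actual implementation: KVAlignmentBytesSketch and
its parameter list are hypothetical, the formula is applied as written
(FP8/BF8, one byte per element), and the tile shapes in the usage note
are made-up values chosen only to reproduce NumIssues = 0.5 and
NumIssues = 4.

```cpp
#include <cstdint>

// Hypothetical sketch of the rule described in this message; the real
// GetKVAlignmentBytes<> signature may differ.
// Assumes 1 byte per element (FP8/BF8), so element count == byte count.
template <std::int64_t kPageBlockSize, std::int64_t kHeadDim, std::int64_t kBlockSize>
constexpr int KVAlignmentBytesSketch()
{
    constexpr int kDwordBytes   = 4;  // dword:   4 B/lane
    constexpr int kDwordx4Bytes = 16; // dwordx4: 16 B/lane

    // NumIssues = kPageBlockSize * kHeadDim / (kBlockSize * 16), kept as a
    // numerator/denominator pair so that 0.5 is not truncated to 0.
    constexpr std::int64_t tile_bytes      = kPageBlockSize * kHeadDim;
    constexpr std::int64_t bytes_per_issue = kBlockSize * kDwordx4Bytes;

    // dwordx4 only when NumIssues is an integer >= 1; dword otherwise.
    constexpr bool whole_issues =
        tile_bytes >= bytes_per_issue && tile_bytes % bytes_per_issue == 0;

    return whole_issues ? kDwordx4Bytes : kDwordBytes;
}
```

With an illustrative 64 x 64 tile, KVAlignmentBytesSketch<64, 64, 512>()
evaluates to 4 (NumIssues = 0.5), while
KVAlignmentBytesSketch<64, 64, 64>() evaluates to 16 (NumIssues = 4).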

Wall-clock impact on the long-context decode sweep (HIP_VISIBLE_DEVICES=2,
ITERS=20, WARMUP=5, MI355):

  shape                                   dtype    before     after  speedup
  decode  d=64  sq=1     sk=128000 b=128  fp8     7.17 ms   4.57 ms  1.57x
  decode  d=64  sq=1     sk=128000 b=256  fp8    16.24 ms   9.51 ms  1.71x
  decode  d=128 sq=1     sk=128000 b=128  fp8    13.11 ms   7.15 ms  1.83x
  decode  d=128 sq=1     sk=128000 b=256  fp8    31.37 ms   9.78 ms  3.21x
  decode  d=64  sq=1     sk=128000 b=4    fp8     0.42 ms   0.22 ms  1.92x
  decode  d=128 sq=1     sk=128000 b=4    fp8     0.80 ms   0.42 ms  1.93x
  prefill d=64  sq=75600 sk=75600  b=1    fp8     81.4 ms   81.2 ms  1.00x  (dword fallback)
  prefill d=128 sq=75600 sk=75600  b=1    fp8    143.5 ms  143.6 ms  1.00x  (dword fallback)

Correctness was verified across fp8/bf16/fp16, causal/non-causal, and
all 7 compiled tile variants. The full PMC + PC-sample analysis is in
ua-test-scripts/rocprof_analysis/BOTTLENECK_ANALYSIS.md, section 8.

Co-authored-by: Cursor <cursoragent@cursor.com>