mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK][CK_TILE] Add fp8bf16 hdim=256 tile for batch prefill (#5918)

## Motivation

FP8 batch prefill kernels currently only support head_dim=128. Models with head_dim=256 hit the "invalid argument for batch_prefill" error because no matching kernel variant exists in the codegen dispatch.

## Technical Details

Add a hdim=256 tile size entry for fp8bf16 in the batch prefill codegen recipe (`fmha_batch_prefill.py`).

Tile configuration: `FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4,1,1, 4,1,1, 32,32,32, 32,32,32, -1)`

- bm0=128, bn0=128 (Q/K tile sizes)
- bn1=256, bk0max=256 (V head_dim=256)
- Warp MFMA 32x32x32 (fp8 MFMA instructions)

This mirrors the existing bf16/fp16 hdim=256 tile but uses fp8 warp sizes.

## Test Plan

Tested on both MI308X (gfx942) and MI355X (gfx950) via aiter batch prefill test with the following matrix:

- page_size: {1, 16, 1024}
- kv_layout: {linear, vectorized}
- lookup_table: {sglang, vllm}
- causal: {true, false}
- logits_soft_cap: {0.0, 30.0}
- contiguous_kv: {true, false}

## Test Result

- **MI308X (gfx942):** 160 passed, 32 skipped (page_size=1 + vectorized not applicable)
- **MI355X (gfx950):** 120 passed, 72 skipped (pre-existing ROCm 7.2 compiler issue with causal + no softcap)

No register spills on either platform.
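As a reading aid for the tile configuration under Technical Details, the sketch below gives names to the 19 positional values. Only `bm0`, `bn0`, `bn1`, `bk0max`, and the 32x32x32 warp tile are named in this description; `TileSizeSketch`, `decode`, and the remaining field names and meanings (`bk0`, `bk1`, the warp-count triples, `occupancy`) are hypothetical labels for illustration, not the actual `FmhaFwdTileSize` definition.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TileSizeSketch:
    """Hypothetical named view of the 19 positional tile values."""
    bm0: int            # Q tile size (named in the description)
    bn0: int            # K tile size (named in the description)
    bk0: int            # assumed: reduction step of the first GEMM
    bn1: int            # V head_dim tile (named in the description)
    bk1: int            # assumed: reduction step of the second GEMM
    bk0max: int         # head_dim bound (named in the description)
    warps0: tuple       # assumed: warps along (m, n, k), first GEMM
    warps1: tuple       # assumed: warps along (m, n, k), second GEMM
    warp_tile0: tuple   # warp MFMA tile (m, n, k), first GEMM
    warp_tile1: tuple   # warp MFMA tile (m, n, k), second GEMM
    occupancy: int      # assumed meaning: -1 leaves it unspecified


def decode(*args):
    """Group the flat positional values into named fields."""
    if len(args) != 19:
        raise ValueError(f"expected 19 values, got {len(args)}")
    return TileSizeSketch(
        *args[0:6],
        tuple(args[6:9]),
        tuple(args[9:12]),
        tuple(args[12:15]),
        tuple(args[15:18]),
        args[18],
    )


tile = decode(128, 128, 32, 256, 32, 256,
              4, 1, 1, 4, 1, 1,
              32, 32, 32, 32, 32, 32, -1)

# Under the assumed layout, warps along m times the warp tile's m
# dimension covers the Q tile: 4 * 32 == 128.
covers_q_tile = tile.warps0[0] * tile.warp_tile0[0] == tile.bm0
```

If the assumed layout holds, `covers_q_tile` is a quick sanity check that the warp arrangement spans the whole Q tile.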
### Profiling: MI355X (gfx950), FP8 pertensor, hdim=256, seqlen=1024, 8 heads

| page_sz | kv_layout | table | causal | soft_cap | time_us | TFLOPS |
|---------|-----------|-------|--------|----------|---------|--------|
| 1 | linear | sglang | False | 0.00 | 55.01 | 156.16 |
| 1 | linear | vllm | False | 0.00 | 55.12 | 155.84 |
| 1 | linear | sglang | False | 30.00 | 62.63 | 137.16 |
| 1 | linear | vllm | False | 30.00 | 62.16 | 138.20 |
| 1 | linear | sglang | True | 30.00 | 64.09 | 67.01 |
| 1 | linear | vllm | True | 30.00 | 63.85 | 67.27 |
| 16 | linear | sglang | False | 0.00 | 57.00 | 150.69 |
| 16 | vectorized | sglang | False | 0.00 | 57.55 | 149.25 |
| 16 | linear | vllm | False | 0.00 | 56.80 | 151.23 |
| 16 | vectorized | vllm | False | 0.00 | 57.32 | 149.87 |
| 16 | linear | sglang | False | 30.00 | 64.77 | 132.62 |
| 16 | vectorized | vllm | False | 30.00 | 63.54 | 135.18 |
| 16 | linear | sglang | True | 30.00 | 66.84 | 64.26 |
| 16 | vectorized | vllm | True | 30.00 | 66.12 | 64.96 |
| 1024 | linear | sglang | False | 0.00 | 58.25 | 147.46 |
| 1024 | vectorized | sglang | False | 0.00 | 57.53 | 149.31 |
| 1024 | linear | vllm | False | 0.00 | 58.06 | 147.94 |
| 1024 | vectorized | vllm | False | 0.00 | 57.55 | 149.27 |
| 1024 | linear | sglang | False | 30.00 | 65.38 | 131.38 |
| 1024 | vectorized | vllm | False | 30.00 | 63.64 | 134.98 |
| 1024 | linear | sglang | True | 30.00 | 66.85 | 64.25 |
| 1024 | vectorized | vllm | True | 30.00 | 65.26 | 65.81 |

### Profiling: MI308X (gfx942), FP8 pertensor, hdim=256, seqlen=1024, 8 heads

| page_sz | kv_layout | table | causal | soft_cap | time_us | TFLOPS |
|---------|-----------|-------|--------|----------|---------|--------|
| 1 | linear | sglang | False | 0.00 | 110.18 | 77.96 |
| 1 | linear | vllm | True | 30.00 | 134.33 | 31.97 |
| 1 | linear | sglang | True | 30.00 | 134.59 | 31.91 |
| 16 | linear | sglang | False | 0.00 | 115.43 | 74.42 |
| 16 | vectorized | sglang | False | 0.00 | 106.11 | 80.95 |
| 16 | linear | vllm | False | 0.00 | 116.34 | 73.83 |
| 16 | vectorized | vllm | False | 0.00 | 106.17 | 80.91 |
| 16 | linear | sglang | False | 30.00 | 135.61 | 63.34 |
| 16 | vectorized | vllm | False | 30.00 | 122.37 | 70.20 |
| 16 | linear | sglang | True | 0.00 | 117.44 | 36.57 |
| 16 | vectorized | vllm | True | 0.00 | 108.81 | 39.47 |
| 16 | linear | sglang | True | 30.00 | 139.43 | 30.80 |
| 16 | vectorized | vllm | True | 30.00 | 125.87 | 34.12 |
| 1024 | linear | sglang | False | 0.00 | 110.65 | 77.63 |
| 1024 | vectorized | sglang | False | 0.00 | 101.70 | 84.46 |
| 1024 | linear | vllm | False | 0.00 | 111.71 | 76.89 |
| 1024 | vectorized | vllm | False | 0.00 | 101.55 | 84.59 |
| 1024 | linear | sglang | False | 30.00 | 129.33 | 66.42 |
| 1024 | vectorized | vllm | False | 30.00 | 120.95 | 71.02 |
| 1024 | linear | sglang | True | 0.00 | 112.26 | 38.26 |
| 1024 | vectorized | vllm | True | 0.00 | 103.02 | 41.69 |
| 1024 | linear | sglang | True | 30.00 | 133.73 | 32.12 |
| 1024 | vectorized | vllm | True | 30.00 | 124.75 | 34.43 |

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
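
For readers cross-checking the profiling tables, the TFLOPS column appears consistent with the usual attention FLOP count of `4 * seqlen_q * seqlen_k * head_dim` per head, halved for causal runs. This relation is inferred from the reported numbers, not stated in the description; the batch size of 1 and the helper name `attention_tflops` are assumptions for illustration.

```python
def attention_tflops(seqlen_q, seqlen_k, num_heads, head_dim,
                     time_us, causal=False, batch=1):
    """Estimate attention throughput in TFLOPS from a kernel time.

    Counts two GEMMs (Q @ K^T and P @ V) at 2 FLOPs per multiply-add.
    Assumption: causal masking is credited with half of the work.
    """
    flops = 4.0 * batch * num_heads * seqlen_q * seqlen_k * head_dim
    if causal:
        flops /= 2.0
    seconds = time_us * 1e-6
    return flops / seconds / 1e12


# Settings from the profiling headings: seqlen=1024, 8 heads, hdim=256.
non_causal = attention_tflops(1024, 1024, 8, 256, time_us=55.01)
causal = attention_tflops(1024, 1024, 8, 256, time_us=64.09, causal=True)
```

With the MI355X timings of 55.01 us (non-causal) and 64.09 us (causal), this lands within rounding of the reported 156.16 and 67.01 TFLOPS. That would also explain why causal rows show roughly half the throughput at similar kernel times.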