# CK Tile Example Suite
This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data movement policies.
## What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. Pipelines and policies orchestrate data movement (global <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
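To make the tile idea concrete, the hypothetical plain-HIP kernel below copies a matrix by giving each thread block one rectangular tile. CK Tile wraps exactly this decomposition in reusable abstractions (tensor views, tile windows, tile distributions) instead of hand-written index math; see 36_copy for the real API. The kernel itself is illustrative, not taken from the examples.

```cpp
// Hypothetical plain-HIP illustration of the tile decomposition that
// CK Tile abstracts: each thread block owns one TileM x TileN rectangle
// of a row-major M x N matrix and moves it global -> registers -> global.
#include <hip/hip_runtime.h>

template <int TileM, int TileN>
__global__ void tile_copy(const float* __restrict__ in,
                          float* __restrict__ out, int M, int N)
{
    // Top-left corner of this block's tile (a "tile window origin").
    const int tile_m = blockIdx.y * TileM;
    const int tile_n = blockIdx.x * TileN;

    // Each thread handles a strided subset of the tile's elements;
    // in CK Tile, a "tile distribution" encodes this thread mapping.
    for (int i = threadIdx.y; i < TileM; i += blockDim.y) {
        for (int j = threadIdx.x; j < TileN; j += blockDim.x) {
            const int m = tile_m + i;
            const int n = tile_n + j;
            if (m < M && n < N)           // guard partial edge tiles
                out[m * N + n] = in[m * N + n];
        }
    }
}
```

Launched with one block per tile, e.g. `tile_copy<64, 64><<<dim3(N / 64, M / 64), dim3(16, 16)>>>(in, out, M, N)`, this is the shape of computation that CK Tile's windows and distributions generate for you.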
## Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
## Technical Highlights
- Tile Distribution: See `include/ck_tile/tile_program/tile_distribution/` for mapping tiles to thread blocks.
- Block Tile Pipelines: See `include/ck_tile/tile_program/block_tile_pipeline/` for memory/computation pipelines; the sketch after this list shows the pattern they are built around.
- Policies and Utilities: Many examples use custom policies for tile/block sizes and memory access.
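As a rough mental model (not CK Tile API), a block tile pipeline overlaps staging of the next tile into LDS with computation on the current one. The standalone HIP kernel below is a hypothetical sketch of that ping-pong pattern; the real pipelines additionally handle vectorized loads, tile distributions, and scheduling through policies.

```cpp
// Double-buffered tile pipeline in plain HIP: the core pattern that the
// block_tile_pipeline headers abstract behind pipelines and policies.
// Hypothetical standalone sketch, not CK Tile API.
#include <hip/hip_runtime.h>

constexpr int TILE = 256; // threads per block == elements per tile

// Each block sums n_tiles consecutive tiles of its slice of `in` and
// writes one partial sum to out[blockIdx.x].
__global__ void pipelined_sum(const float* __restrict__ in,
                              float* __restrict__ out, int n_tiles)
{
    __shared__ float lds[2][TILE]; // two LDS buffers: ping-pong

    const float* slice = in + (size_t)blockIdx.x * n_tiles * TILE;

    int buf = 0;
    // Prologue: stage tile 0 into buffer 0.
    lds[buf][threadIdx.x] = slice[threadIdx.x];
    __syncthreads();

    float acc = 0.0f;
    for (int t = 0; t < n_tiles; ++t) {
        // Stage tile t+1 into the other buffer while tile t is consumed.
        if (t + 1 < n_tiles)
            lds[buf ^ 1][threadIdx.x] =
                slice[(size_t)(t + 1) * TILE + threadIdx.x];

        acc += lds[buf][threadIdx.x]; // "compute" on the current tile
        __syncthreads();              // staging and compute both done
        buf ^= 1;
    }

    // Tree-reduce the per-thread partials through one LDS buffer.
    lds[0][threadIdx.x] = acc;
    __syncthreads();
    for (int s = TILE / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s)
            lds[0][threadIdx.x] += lds[0][threadIdx.x + s];
        __syncthreads();
    }
    if (threadIdx.x == 0)
        out[blockIdx.x] = lds[0][0];
}
```

Launched as `pipelined_sum<<<num_blocks, TILE>>>(in, out, n_tiles)`, every block emits one partial sum; the point is the loop structure, where the load of tile t+1 is issued before the math on tile t.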
## How to Build & Run
```bash
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
```

Each example produces its own executable in `build/bin/`.
## Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility; the sketch below shows where such compile-time knobs typically live.
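Tile shapes in these examples are compile-time parameters, so experimenting usually means editing a template argument and rebuilding. A minimal sketch, assuming type names in the style of the GEMM example (`ck_tile::TileGemmShape`, `ck_tile::sequence`); the exact types and legal values differ per example and per version, so verify against the source you are editing.

```cpp
// Where tile-size knobs typically live (03_gemm style): three nested
// compile-time shapes. Names assume the GEMM example's conventions;
// check the example source for the exact types in your checkout.
#include <ck_tile/core.hpp>
#include <ck_tile/ops/gemm.hpp>

using GemmShape = ck_tile::TileGemmShape<
    ck_tile::sequence<256, 256, 64>,   // per-block tile: M x N x K
    ck_tile::sequence<2, 2, 1>,        // warps per block along M/N/K
    ck_tile::sequence<32, 32, 16>>;    // per-warp (MFMA) tile: M x N x K
```

Not every combination is valid: block, warp, and per-warp tiles must divide each other and match the MFMA shapes the target GPU supports, so expect compile-time errors when a combination is inconsistent.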