CK Tile Example Suite
This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data-movement policies.
What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. Pipelines and policies orchestrate data movement (global memory <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
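The following is a plain CPU C++ sketch of a tiled GEMM that makes the global -> LDS -> registers flow concrete. It is not CK Tile code and uses none of its APIs. The local `a_tile`/`b_tile` arrays stand in for LDS, the `acc` array stands in for per-thread registers, and each `(tm, tn)` iteration plays the role of one thread block. The tile sizes `kTileM`, `kTileN`, `kTileK` and the function name `tiled_gemm` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tile sizes for illustration only; these are not CK Tile parameters.
constexpr std::size_t kTileM = 4;
constexpr std::size_t kTileN = 4;
constexpr std::size_t kTileK = 4;

// C[M x N] = A[M x K] * B[K x N], all row-major.
// Each (tm, tn) iteration plays the role of one thread block computing one output tile.
std::vector<float> tiled_gemm(const std::vector<float>& a, const std::vector<float>& b,
                              std::size_t M, std::size_t N, std::size_t K) {
    assert(a.size() == M * K && b.size() == K * N);
    std::vector<float> c(M * N, 0.0f);
    for (std::size_t tm = 0; tm < M; tm += kTileM) {
        for (std::size_t tn = 0; tn < N; tn += kTileN) {
            // Stand-in for registers: the per-tile accumulator, zero-initialized.
            float acc[kTileM][kTileN] = {};
            for (std::size_t tk = 0; tk < K; tk += kTileK) {
                // Stand-in for LDS: stage one A tile and one B tile from "global" memory,
                // zero-padding elements that fall outside the matrix edges.
                float a_tile[kTileM][kTileK] = {};
                float b_tile[kTileK][kTileN] = {};
                for (std::size_t i = 0; i < kTileM; ++i)
                    for (std::size_t k = 0; k < kTileK; ++k)
                        if (tm + i < M && tk + k < K) a_tile[i][k] = a[(tm + i) * K + tk + k];
                for (std::size_t k = 0; k < kTileK; ++k)
                    for (std::size_t j = 0; j < kTileN; ++j)
                        if (tk + k < K && tn + j < N) b_tile[k][j] = b[(tk + k) * N + tn + j];
                // Compute on the staged tiles only.
                for (std::size_t i = 0; i < kTileM; ++i)
                    for (std::size_t j = 0; j < kTileN; ++j)
                        for (std::size_t k = 0; k < kTileK; ++k)
                            acc[i][j] += a_tile[i][k] * b_tile[k][j];
            }
            // Epilogue: write the in-bounds part of the accumulator tile back to "global" memory.
            for (std::size_t i = 0; i < kTileM && tm + i < M; ++i)
                for (std::size_t j = 0; j < kTileN && tn + j < N; ++j)
                    c[(tm + i) * N + tn + j] = acc[i][j];
        }
    }
    return c;
}
```

On a GPU, the staging loops would be cooperative loads by all threads of a block, with synchronization between loading and computing.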
Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
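The sketch below is a CPU reference for what a fused row operation such as 11_add_rmsnorm2d_rdquant computes on one row. It performs a residual add, then RMSNorm with a `gamma` weight, then symmetric int8 quantization with one scale per row taken from the row's absolute maximum. This is a common formulation, assumed here rather than taken from the example's source. The `eps` default, the `[-127, 127]` range, round-to-nearest, and the names `QuantRow` and `add_rmsnorm_rdquant_row` are all illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result type: int8 values plus the per-row dequantization scale.
struct QuantRow {
    std::vector<std::int8_t> q;
    float scale;
};

// Reference for one row of (a + b) -> RMSNorm(gamma) -> rowwise dynamic int8 quantization.
// eps and the symmetric [-127, 127] range are assumptions for this sketch.
QuantRow add_rmsnorm_rdquant_row(const std::vector<float>& a, const std::vector<float>& b,
                                 const std::vector<float>& gamma, float eps = 1e-5f) {
    const std::size_t n = a.size();
    assert(n > 0 && b.size() == n && gamma.size() == n);

    // 1) Fused residual add.
    std::vector<float> x(n);
    for (std::size_t i = 0; i < n; ++i) x[i] = a[i] + b[i];

    // 2) RMSNorm: y = x / sqrt(mean(x^2) + eps) * gamma.
    float sum_sq = 0.0f;
    for (float v : x) sum_sq += v * v;
    const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(n) + eps);
    std::vector<float> y(n);
    for (std::size_t i = 0; i < n; ++i) y[i] = x[i] * inv_rms * gamma[i];

    // 3) Rowwise dynamic quantization: one scale per row from the row's absmax.
    float absmax = 0.0f;
    for (float v : y) absmax = std::max(absmax, std::fabs(v));
    const float scale = absmax > 0.0f ? absmax / 127.0f : 1.0f;

    QuantRow out{std::vector<std::int8_t>(n), scale};
    for (std::size_t i = 0; i < n; ++i) {
        const float r = std::round(y[i] / scale);
        out.q[i] = static_cast<std::int8_t>(std::clamp(r, -127.0f, 127.0f));
    }
    return out;
}
```

Dequantizing with `q[i] * scale` recovers each normalized value to within half a quantization step. In a tiled kernel, the sum of squares and the absmax become reductions over the row rather than serial loops.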
Technical Highlights
- Tile Distribution: See include/ck_tile/tile_program/tile_distribution/ for mapping tiles to thread blocks.
- Block Tile Pipelines: See include/ck_tile/tile_program/block_tile_pipeline/ for memory/computation pipelines.
- Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
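The simplest form of tile-to-block mapping assigns one output tile to each thread block and walks the tile grid in row-major order. The sketch below shows only that idea. CK Tile's tile distribution is more general than this: it also describes how a tile's elements are split across the threads of a block. `TileCoord`, `num_tiles`, and `block_to_tile` are hypothetical names, not part of its API.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical 2D tile coordinate; not a CK Tile type.
struct TileCoord {
    std::size_t tile_m;  // tile row index
    std::size_t tile_n;  // tile column index
};

// Number of tiles needed to cover `extent` elements with tiles of size `tile` (ceil division).
constexpr std::size_t num_tiles(std::size_t extent, std::size_t tile) {
    return (extent + tile - 1) / tile;
}

// Map a linear block id to the output tile it owns, walking tiles in row-major order.
TileCoord block_to_tile(std::size_t block_id, std::size_t M, std::size_t N,
                        std::size_t m_per_block, std::size_t n_per_block) {
    const std::size_t tiles_n = num_tiles(N, n_per_block);
    assert(block_id < num_tiles(M, m_per_block) * tiles_n);
    return {block_id / tiles_n, block_id % tiles_n};
}
```

For example, take `M = 256`, `N = 300`, and 128x128 tiles. The grid is 2 x 3 = 6 blocks, and block 4 owns tile (1, 1). The last tile column is only partially filled, so a real kernel must pad or mask it.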
How to Build & Run
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
Each example produces its own executable in build/bin/.
Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.
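Two useful back-of-the-envelope numbers when experimenting with tile sizes are the workgroup count and the LDS footprint. The sketch below computes both for a simplified GEMM in which each workgroup owns one `m_per_block x n_per_block` output tile. It assumes the workgroup stages a single (not double-buffered) A and B tile of depth `k_per_block` in LDS. `TileConfig`, `grid_size`, and `lds_bytes` are hypothetical helpers. Real CK Tile pipelines may use multiple buffers or padded LDS layouts, so treat the result as a rough estimate.

```cpp
#include <cstddef>

// Hypothetical GEMM block-tile configuration; names are illustrative, not CK Tile's.
struct TileConfig {
    std::size_t m_per_block;
    std::size_t n_per_block;
    std::size_t k_per_block;
};

constexpr std::size_t ceil_div(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

// Number of workgroups when each one computes a single M x N output tile.
constexpr std::size_t grid_size(std::size_t M, std::size_t N, TileConfig cfg) {
    return ceil_div(M, cfg.m_per_block) * ceil_div(N, cfg.n_per_block);
}

// LDS bytes for one single-buffered A tile (m x k) plus one B tile (n x k).
constexpr std::size_t lds_bytes(TileConfig cfg, std::size_t elem_bytes) {
    return (cfg.m_per_block * cfg.k_per_block + cfg.n_per_block * cfg.k_per_block) * elem_bytes;
}
```

Consider a 4096 x 4096 fp16 GEMM. 256x256x32 tiles give 256 workgroups and 32 KiB of LDS per workgroup. 128x128x64 tiles give 1024 workgroups with the same 32 KiB. Which one is faster depends on occupancy and the target GPU, which is exactly what the examples let you measure.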