[CK_TILE][FMHA] Support gfx11

## Motivation

Add support for gfx11 architectures (RDNA3) to FMHA.

## Technical Details

Distributions (the mapping of matrix elements to lane registers) of gfx11 WMMA are completely different from those of gfx9 MFMA and gfx12 WMMA. There are two cases in FMHA where this difference matters:

* using the results (matrix C) of one GEMM as input (matrix A) of another GEMM;
* random number generation for dropout (the implementations for gfx9 MFMA, gfx12 WMMA, and host validation all produce the same results).

Both cases are solved by a special remapping implemented using `__builtin_amdgcn_permlanex16` and `__builtin_amdgcn_perm`.

Additional changes:

* FMHA tests are now built and run only for those types for which instances exist (gfx11 supports only fp16 and bf16).
* Two fixes for uninitialized values (`mask.sink` and `do_fp8_static_quant`): they could contain garbage that corrupted the dispatching logic, so tests sometimes reported that no instance was available for the current parameters.
* A small fix removing expcnt(0) from the s_waitcnt instruction on gfx11 when it is not requested (i.e. every time); this likely has no effect on performance but makes the disassembly a bit clearer.

## Test Plan

```
ninja test_ck_tile_fmha
bin/test_ck_tile_fmha_fwd_fp16
bin/test_ck_tile_fmha_fwd_bf16
bin/test_ck_tile_fmha_bwd_fp16
bin/test_ck_tile_fmha_bwd_bf16
```

## Test Result

All tests must pass (some tests may be skipped).

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
# CK Tile Example Suite

This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented with tile-based parallelism, modular pipelines, and data-movement policies.
## What is CK Tile?

CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. Pipelines and policies orchestrate data movement (global <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
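To make the tile decomposition concrete, here is a minimal CPU-only sketch of the idea (this is illustrative pseudocode in plain C++, not the CK Tile API): the output matrix is partitioned into tiles, each "block" owns one tile, and elements within the tile accumulate into registers, mirroring the global -> LDS -> registers staging that CK Tile pipelines perform on the GPU. The tile sizes `TM`/`TN` are an arbitrary choice for the example.

```cpp
// Illustrative tile-based GEMM decomposition on the CPU (not the CK Tile API).
// C (M x N) = A (M x K) * B (K x N), row-major.
constexpr int M = 4, N = 4, K = 4; // problem size
constexpr int TM = 2, TN = 2;      // tile size (hypothetical choice)

void gemm_tiled(const float* A, const float* B, float* C) {
    for (int ti = 0; ti < M; ti += TM)           // each (ti, tj) pair plays the
        for (int tj = 0; tj < N; tj += TN)       // role of one "block" owning a tile
            for (int i = ti; i < ti + TM; ++i)   // elements within the tile
                for (int j = tj; j < tj + TN; ++j) {
                    float acc = 0.f;             // register accumulator
                    for (int k = 0; k < K; ++k)
                        acc += A[i * K + k] * B[k * N + j];
                    C[i * N + j] = acc;
                }
}
```

On the GPU the per-tile loop bodies run in parallel across thread blocks and lanes, and the pipeline overlaps loads of the next K-tile with computation on the current one.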
## Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
## Technical Highlights

- Tile Distribution: See `include/ck_tile/tile_program/tile_distribution/` for mapping tiles to thread blocks.
- Block Tile Pipelines: See `include/ck_tile/tile_program/block_tile_pipeline/` for memory/computation pipelines.
- Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
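As a rough intuition for what a tile distribution does at the block level, the sketch below maps a linear block id onto the 2-D tile of the output it owns (row-major over the tile grid). This is a hypothetical simplification: real CK Tile distributions also describe lane-level register layouts (e.g. the MFMA/WMMA mappings mentioned above), which this does not attempt to model.

```cpp
#include <utility>

// Hypothetical block-level tile distribution: given a linear block id and
// the number of tile columns in the grid, return the (tile_row, tile_col)
// this block owns, in row-major order over the tile grid.
std::pair<int, int> tile_of(int block_id, int grid_n) {
    return { block_id / grid_n, block_id % grid_n };
}
```

Changing this mapping (e.g. column-major, or swizzled for cache locality) changes which data each block touches without changing the per-tile computation, which is why distributions are factored out as policies.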
## How to Build & Run

```shell
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
```

Each example produces its own executable in `build/bin/`.
## Learning and Extending

- Start Simple: Try `03_gemm` or `36_copy` to learn tile basics.
- Explore Fusion: See `11_add_rmsnorm2d_rdquant`, `15_fused_moe`, or `14_moe_smoothquant` for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.