mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 00:39:02 +00:00
[CK TILE] Unification Work – More accurate tests for MmaPipelines (#6212)

## Motivation

This PR solves several issues:

#### More accurate tests for MmaPipelines

The current tests for the MmaPipelines (test_amdgcn_sparse_mma, test_amdgcn_wavewise_mma) use explicit input fragment vectors filled with 1s, and only check the output of a single lane. We should have tests that actually use the MmaPipelines with non-trivial input matrices and verify the complete output.

Some other aspects of the current MmaPipelines tests that I noticed and deserve some attention:

1. There is sometimes iteration over K outside of the pipeline, which is then included in WaveTileK or FragK, which is not correct. We should remove it, move K iteration inside of the pipeline, or be more clear about this outer-K loop size and how it propagates downwards.
2. There is very tight coupling between the kernel, gtest code, and test_pipeline helper, requiring a lot of information and functions to be passed back and forth.
3. The test_pipeline helper is doing a bunch of register-related logic on the host (related to point 1)
4. Without this register logic the only thing it does is check the device, call the kernel, and check the output, but with a lot of boilerplate.

#### Test helper for detecting target arch at HOST runtime

There is a really apparent issue we faced while writing tests:

Scenario:

1. Compile a test that supports both gfx950 and gfx1201 for gfx950
2. Run the test on a server that only has gfx1201 GPU

Actual: Segmentation fault

Expected: The test can correctly detect from HOST runtime that the DEVICE target_id was different and skips the test.

Notes: The only way of detecting the COMPILER_TARGET_ID in the existing "arch" framework is launching a kernel and calling `get_compiler_target()` (so, from a DEVICE code). This will create a segmentation fault if the current arch differs from the target arch. To cope with this issue, we propose to export the compiler target(s) (note they can be many) through `projects/composablekernel/test/ck_tile/core/arch/CMakeLists.txt` and define a test helper to deal with such cases.

#### Add composition support to Transforms

We have a small number of Transforms which act on MmaOp input and output data, before and after the MmaOp call respectively. These are currently implemented to work on an MmaTile level, but in theory they are also supposed to work at a WaveTile level, i.e. after composition of multiple MmaTiles to create larger effective MNK dimensions. Currently the composed MmaTiles look like 2D C-style arrays of the individual MmaTile level register vectors (see WaveWiseMmaPipeline). The transforms should be able to take these and perform the proper transforms to the whole WaveTile at once. This might allow for better performing transformations.

Note: This PR handles the SparseTransform case and if we don't end up doing scale as a transformation, there isn't really much left to do. If we end up having only the sparse transform as a non-trivial transform, then we could also consider removing the Transform framework.
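The host-side target check described in the PR notes can be sketched in plain C++. This is a minimal illustration only, not the helper added by the PR: the function names `base_arch` and `device_matches_compiled_targets` are hypothetical, and the list of compiled targets is assumed to reach the test as a list of strings (for example through a CMake-provided compile definition).

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical helper (not part of CK): strips feature suffixes from a
// device architecture string, e.g. "gfx950:sramecc+:xnack-" -> "gfx950".
std::string base_arch(const std::string& device_arch_name)
{
    // find() returns npos when there is no ':'; substr then keeps the whole string.
    return device_arch_name.substr(0, device_arch_name.find(':'));
}

// Hypothetical helper (not part of CK): returns true when the architecture
// reported by the device is one of the targets the test binary was built for.
// The check runs entirely on the host, so no kernel has to be launched.
bool device_matches_compiled_targets(const std::string& device_arch_name,
                                     const std::vector<std::string>& compiled_targets)
{
    const std::string arch = base_arch(device_arch_name);
    return std::find(compiled_targets.begin(), compiled_targets.end(), arch) !=
           compiled_targets.end();
}
```

In a real test the device string would typically come from the HIP runtime (the `gcnArchName` field filled in by `hipGetDeviceProperties`), and a gtest body could call `GTEST_SKIP()` when the check returns false, so the mismatched-architecture scenario ends in a skipped test rather than a crash.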
CK Tile Example Suite
This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data movement policy.
What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. The pipeline and policy together orchestrate data movement (global <-> LDS <-> registers), computation, and synchronization, which gives kernels both efficiency and flexibility.
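To make the tile idea concrete, here is a minimal host-side sketch of a tiled matrix multiplication. It is plain C++ and does not use the CK Tile API: the tile sizes `kTileM`, `kTileN`, `kTileK` and the function `tiled_gemm` are assumptions made for illustration. The structure mirrors the description above: outer loops pick an output tile, an inner loop streams K-slices (the role a pipeline plays on the GPU), and a local accumulator stands in for registers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tile sizes, chosen for illustration; not CK Tile defaults.
constexpr std::size_t kTileM = 4;
constexpr std::size_t kTileN = 4;
constexpr std::size_t kTileK = 2;

// Computes C (m x n) = A (m x k) * B (k x n); all matrices are row-major.
std::vector<float> tiled_gemm(const std::vector<float>& a,
                              const std::vector<float>& b,
                              std::size_t m,
                              std::size_t n,
                              std::size_t k)
{
    assert(a.size() == m * k);
    assert(b.size() == k * n);
    std::vector<float> c(m * n, 0.0f);

    // Each (m0, n0) pair selects one output tile; on a GPU this would be
    // the work assigned to one thread block.
    for(std::size_t m0 = 0; m0 < m; m0 += kTileM)
    {
        for(std::size_t n0 = 0; n0 < n; n0 += kTileN)
        {
            // Clamp tile bounds so partial tiles at the edges are handled.
            const std::size_t m1 = std::min(m0 + kTileM, m);
            const std::size_t n1 = std::min(n0 + kTileN, n);

            float acc[kTileM][kTileN] = {}; // stand-in for register accumulators

            // Stream over K one slice at a time and accumulate into the tile.
            for(std::size_t k0 = 0; k0 < k; k0 += kTileK)
            {
                const std::size_t k1 = std::min(k0 + kTileK, k);
                for(std::size_t i = m0; i < m1; ++i)
                    for(std::size_t j = n0; j < n1; ++j)
                        for(std::size_t kk = k0; kk < k1; ++kk)
                            acc[i - m0][j - n0] += a[i * k + kk] * b[kk * n + j];
            }

            // Epilogue: write the finished tile back to the output matrix.
            for(std::size_t i = m0; i < m1; ++i)
                for(std::size_t j = n0; j < n1; ++j)
                    c[i * n + j] = acc[i - m0][j - n0];
        }
    }
    return c;
}
```

Changing the tile constants does not change the result, only the order in which data is visited. That separation between what is computed and how it is tiled is the property the examples in this directory explore on real hardware.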
Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
Technical Highlights
- Tile Distribution: See include/ck_tile/tile_program/tile_distribution/ for mapping tiles to thread blocks.
- Block Tile Pipelines: See include/ck_tile/tile_program/block_tile_pipeline/ for memory/computation pipelines.
- Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
How to Build & Run
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
Each example produces its own executable in build/bin/.
Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.