# CK Tile Example Suite
This directory contains a comprehensive suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data movement policies.
## What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. Pipelines and policies orchestrate data movement (global memory <-> LDS <-> registers), computation, and synchronization, enabling both high efficiency and flexibility.
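To make the global memory -> LDS -> registers flow concrete, here is a plain C++ CPU sketch of a tiled GEMM. It is not CK Tile code: the function `tiled_gemm` and its tile-size template parameters are hypothetical, local `std::array` buffers stand in for LDS, and a per-tile accumulator stands in for registers. Each outer `(m0, n0)` iteration plays the role of one thread block.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU sketch (not the CK Tile API): C = A * B with
// row-major A (M x K), B (K x N), C (M x N).
template <std::size_t MTile, std::size_t NTile, std::size_t KTile>
void tiled_gemm(const std::vector<float>& a, const std::vector<float>& b,
                std::vector<float>& c, std::size_t M, std::size_t N, std::size_t K)
{
    assert(a.size() == M * K && b.size() == K * N && c.size() == M * N);
    // Each (m0, n0) iteration stands in for one thread block owning one C tile.
    for (std::size_t m0 = 0; m0 < M; m0 += MTile)
        for (std::size_t n0 = 0; n0 < N; n0 += NTile)
        {
            std::array<float, MTile * NTile> acc{}; // "registers": zeroed accumulator tile
            for (std::size_t k0 = 0; k0 < K; k0 += KTile)
            {
                // "LDS": stage one A tile and one B tile, zero-padding out-of-range elements.
                std::array<float, MTile * KTile> a_tile{};
                std::array<float, KTile * NTile> b_tile{};
                for (std::size_t i = 0; i < MTile; ++i)
                    for (std::size_t k = 0; k < KTile; ++k)
                        if (m0 + i < M && k0 + k < K)
                            a_tile[i * KTile + k] = a[(m0 + i) * K + (k0 + k)];
                for (std::size_t k = 0; k < KTile; ++k)
                    for (std::size_t j = 0; j < NTile; ++j)
                        if (k0 + k < K && n0 + j < N)
                            b_tile[k * NTile + j] = b[(k0 + k) * N + (n0 + j)];
                // Compute on the staged tiles only.
                for (std::size_t i = 0; i < MTile; ++i)
                    for (std::size_t k = 0; k < KTile; ++k)
                        for (std::size_t j = 0; j < NTile; ++j)
                            acc[i * NTile + j] += a_tile[i * KTile + k] * b_tile[k * NTile + j];
            }
            // "Epilogue": write the in-range part of the accumulator tile back to C.
            for (std::size_t i = 0; i < MTile && m0 + i < M; ++i)
                for (std::size_t j = 0; j < NTile && n0 + j < N; ++j)
                    c[(m0 + i) * N + (n0 + j)] = acc[i * NTile + j];
        }
}
```

On a GPU, the staging loops would be cooperative loads by all threads of a block, separated from the compute phase by synchronization; this sketch simply runs them sequentially.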
## Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
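As a concrete reading of one row of the table, the CPU reference below sketches what a fused add + RMSNorm + rowwise dynamic quantization computes for a single row, in the spirit of `11_add_rmsnorm2d_rdquant`. The function name, the symmetric int8 target with a maximum of 127, round-to-nearest, and the default `eps` are assumptions for illustration; the example's actual data types and conventions may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference (not the 11_add_rmsnorm2d_rdquant kernel); processes one row.
struct RowResult {
    std::vector<float> sum;     // x = a + b, often kept as the new residual
    std::vector<std::int8_t> q; // quantized normalized output
    float scale;                // dequantize with y ~= q * scale
};

RowResult add_rmsnorm_rdquant_row(const std::vector<float>& a, const std::vector<float>& b,
                                  const std::vector<float>& gamma, float eps = 1e-5f)
{
    const std::size_t n = a.size();
    assert(n > 0 && b.size() == n && gamma.size() == n);
    RowResult r{std::vector<float>(n), std::vector<std::int8_t>(n), 1.0f};

    // 1) Fused add, accumulating the sum of squares on the fly.
    float sum_sq = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        r.sum[i] = a[i] + b[i];
        sum_sq += r.sum[i] * r.sum[i];
    }
    // 2) RMSNorm: y = x / sqrt(mean(x^2) + eps) * gamma, tracking max |y|.
    const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(n) + eps);
    std::vector<float> y(n);
    float amax = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        y[i] = r.sum[i] * inv_rms * gamma[i];
        amax = std::max(amax, std::fabs(y[i]));
    }
    // 3) Rowwise dynamic quantization: one scale per row, derived from that row's max |y|.
    r.scale = amax > 0.0f ? amax / 127.0f : 1.0f;
    for (std::size_t i = 0; i < n; ++i) {
        const float v = std::round(y[i] / r.scale);
        r.q[i] = static_cast<std::int8_t>(std::clamp(v, -127.0f, 127.0f));
    }
    return r;
}
```

Fusing the three steps lets a kernel keep each row resident while it produces the sum, the normalization statistic, and the row maximum, rather than making separate passes over global memory.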
## Technical Highlights
- Tile Distribution: See `include/ck_tile/tile_program/tile_distribution/` for mapping tiles to thread blocks.
- Block Tile Pipelines: See `include/ck_tile/tile_program/block_tile_pipeline/` for memory/computation pipelines.
- Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
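The sketch below illustrates, in plain C++, the two kinds of mapping these utilities deal with: which output tile a thread block owns, and which elements of a tile each thread owns. It is not the `tile_distribution` API; `block_to_tile`, `thread_to_elements`, the row-major block ordering, and the fixed vector width are simplifying assumptions.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical illustration only, not ck_tile's tile_distribution API.
struct TileCoord { std::size_t m0, n0; };     // top-left element of a block's output tile
struct ThreadSlice { std::size_t row, col; }; // first element a thread owns inside its tile

constexpr std::size_t ceil_div(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

// Number of thread blocks needed to cover an M x N output with MTile x NTile tiles.
constexpr std::size_t grid_size(std::size_t M, std::size_t N, std::size_t MTile, std::size_t NTile)
{
    return ceil_div(M, MTile) * ceil_div(N, NTile);
}

// Row-major block ordering: consecutive block ids walk along N first.
TileCoord block_to_tile(std::size_t block_id, std::size_t N, std::size_t MTile, std::size_t NTile)
{
    const std::size_t n_tiles = ceil_div(N, NTile);
    return {(block_id / n_tiles) * MTile, (block_id % n_tiles) * NTile};
}

// Inside a tile: thread tid owns VecSize contiguous elements of one tile row,
// which is what makes vectorized (wide) loads possible.
ThreadSlice thread_to_elements(std::size_t tid, std::size_t NTile, std::size_t VecSize)
{
    assert(VecSize > 0 && NTile % VecSize == 0);
    const std::size_t threads_per_row = NTile / VecSize;
    return {tid / threads_per_row, (tid % threads_per_row) * VecSize};
}
```

Real distributions are typically more elaborate (for example, several repeats per thread or layouts chosen to avoid LDS bank conflicts), but they answer the same two questions.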
## How to Build & Run
```bash
mkdir build && cd build
# replace <arch> with the target GPU architecture, e.g. gfx942
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
```
Each example produces its own executable in `build/bin/`.
## Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.
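When experimenting with tile sizes, a back-of-envelope model can help rank candidates before benchmarking. The hypothetical helper below estimates two first-order effects of an `MTile x NTile` output tile: arithmetic intensity (flops per element loaded from global memory per K step) and the fraction of padded work when `M` or `N` is not a multiple of the tile size. It deliberately ignores LDS capacity, register pressure, occupancy, and caching, so measured performance remains the final arbiter.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical model for comparing GEMM tile shapes; first-order effects only.
struct TileEstimate {
    double flops_per_element_loaded; // arithmetic intensity from global memory
    double padded_fraction;          // share of computed C elements outside M x N
    std::size_t num_tiles;           // number of thread blocks launched
};

TileEstimate estimate(std::size_t M, std::size_t N, std::size_t MTile, std::size_t NTile)
{
    assert(M > 0 && N > 0 && MTile > 0 && NTile > 0);
    const std::size_t mt = (M + MTile - 1) / MTile;
    const std::size_t nt = (N + NTile - 1) / NTile;
    const double computed = double(mt * MTile) * double(nt * NTile);
    // One K step of depth KTile loads (MTile + NTile) * KTile elements of A and B and
    // performs 2 * MTile * NTile * KTile flops, so KTile cancels out of the ratio.
    const double intensity = 2.0 * double(MTile) * double(NTile) / double(MTile + NTile);
    return {intensity, (computed - double(M) * double(N)) / computed, mt * nt};
}
```

For a 200 x 200 output, for example, 128 x 128 tiles give four times the intensity of 32 x 32 tiles (128 vs. 32 flops per element) but pad about 39% of the computed elements instead of about 20%, and launch only 4 blocks instead of 49, which may leave much of the GPU idle.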