mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
[CKTile] Fix MX GEMM: num_loop==3 dispatch, split-K, unsupported-shape guard (#6663)

Three independent MX GEMM correctness bugs reported against example/ck_tile/42_mx_gemm (fp8xfp8, A=Row/B=Col) on MI350X, plus one host-side atomic-add accumulation bug in the example's repeat loop.

- Pipeline (gemm_pipeline_ag_bg_cr_comp_async.hpp): BlockHasHotloop required num_loop > PrefetchStages, which let num_loop == 3 enter a hot loop that produced 5 gemm accumulations instead of 3 (K == 3*K_Tile, e.g. K=768, deterministically wrong). Require num_loop >= 4 instead: pre-pipeline + TailNumber::Three already totals exactly 3.
- Kernel (gemm_mx_kernel.hpp): split-K was silently broken because GridSize did not thread k_batch into blockIdx.z and the scale tile windows were anchored at K=0 for every k_id. Every k_id >= 1 therefore read the wrong packed scales. Fix:
  * GridSize returns dim3(grid_x, 1, k_batch) (persistent and non-persistent).
  * MakeScaleA/BBlockWindows accept a k_elem_offset and translate it to a packed-scale K offset (also apply pad_tensor_view so OOB scale loads return zero, matching A/B padding).
  * operator() derives k_id from blockIdx.z, uses GetSplitKElemOffset (matches Underlying::SplitKBatchOffset's K1-aligned formula), and dispatches the epilogue with memory_operation_enum::atomic_add for k_batch > 1, set for k_batch == 1. Same fp16/bf16 even-vector-size guard as UniversalGemmKernel.
  * MakeCBlockWindows templated on DstInMemOp; unconditionally applies pad_tensor_view using kPadM/kPadN so partial trailing M/N tiles are handled correctly.
- Compile- and runtime unsupported-shape guards (gemm_mx_kernel.hpp): add IsSupportedArgument and a static_assert for configurations that produce silent wrong results:
  * static_assert(!kPadK): the MX comp-async pipeline uses async_load_tile whose OOB check is per-vector-start, so a vector straddling the K pad boundary reads garbage. Until the async path learns per-element pad masking, reject kPadK at compile time.
  * Runtime: k_batch >= 1; M/N multiples of MPerBlock/NPerBlock when kPadM/kPadN are false; M >= MPerBlock and N >= NPerBlock always (CShuffleEpilogue cannot safely run with a single partial tile); K % (KPerBlock * k_batch) == 0; and for k_batch > 1, K must be a multiple of WarpTile_K * k_batch so every split lands on a packed-scale boundary.
  * All error paths log under CK_TILE_LOGGING with actionable messages.
- Example (example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp):
  * Call Kernel::IsSupportedArgument up front and throw a clear runtime_error for rejected shapes (was silently launching an unsupported kernel).
  * Switch to launch_kernel_time_mask with a clear_gemm_output preprocess that zeroes C between iterations when k_batch > 1 (mirrors universal_gemm_invoker). Without this the default -warmup=50 -repeat=100 accumulated 150 atomic_adds into C after the kernel-side split-K fix.

Tests (test/ck_tile/gemm_mx/):
- Add MXfp8_GemmConfig16_PadMN (kPadM = kPadN = true).
- test_mx_gemm_fp8.cpp: HotLoopTailNumLoopThree (K=768 regression), SplitK (k_batch=2,4 across full_k/partial_k paths), TestMxGemmFp8PadMN::{MNPaddingAligned, MPadding, NPadding, MNPadding} covering trailing partial tiles along M, N, or both.
- Run(...) now takes k_batch.
- packScalesMNxK: guard against OOB (mn, k) reads from src and initialise e8m0 bytes to the zero exponent (0x00) instead of the default-constructed NaN (0xFF), so padded lanes don't poison the packed int32_t shared with in-range lanes.
- test_mx_gemm_instance.hpp: call IsSupportedArgument before launch.

Verification on gfx950, ROCm 7.2.0:
- ctest -R test_ck_tile_mx_gemm -> 100% (2/2).
- Example sweep over the original bug-report shapes: all K-aligned shapes now validate correct (including 4096^3 sk=2 and the K=768 cases); all K=128 shapes cleanly rejected with the new error message instead of producing silent wrong results.

Made-with: Cursor
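The runtime shape rules listed in the commit message above can be sketched as a small host-side check. This is only an illustration under stated assumptions: `MxGemmShapeConfig` and `unsupported_reason` are hypothetical names, not part of CK Tile, and the real check is `Kernel::IsSupportedArgument`, whose signature is not shown in this page. The tile sizes a caller passes in are likewise placeholders.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical stand-in for the kernel's compile-time configuration.
struct MxGemmShapeConfig
{
    std::int64_t m_per_block;
    std::int64_t n_per_block;
    std::int64_t k_per_block;
    std::int64_t warp_tile_k;
    bool pad_m;
    bool pad_n;
};

// Returns an empty string when the shape passes every rule quoted in the
// commit message, otherwise a short reason for the rejection.
std::string unsupported_reason(const MxGemmShapeConfig& cfg,
                               std::int64_t M,
                               std::int64_t N,
                               std::int64_t K,
                               std::int64_t k_batch)
{
    assert(cfg.m_per_block > 0 && cfg.n_per_block > 0);
    assert(cfg.k_per_block > 0 && cfg.warp_tile_k > 0);

    if(k_batch < 1)
        return "k_batch must be >= 1";
    // At least one full tile is required along M and N, padding or not.
    if(M < cfg.m_per_block || N < cfg.n_per_block)
        return "M and N must each cover at least one full block tile";
    if(!cfg.pad_m && M % cfg.m_per_block != 0)
        return "M must be a multiple of MPerBlock when kPadM is false";
    if(!cfg.pad_n && N % cfg.n_per_block != 0)
        return "N must be a multiple of NPerBlock when kPadN is false";
    if(K % (cfg.k_per_block * k_batch) != 0)
        return "K must be a multiple of KPerBlock * k_batch";
    // With split-K, every split must start on a packed-scale boundary.
    if(k_batch > 1 && K % (cfg.warp_tile_k * k_batch) != 0)
        return "K must be a multiple of WarpTile_K * k_batch";
    return "";
}
```

A caller would run such a check before launching and raise an error on a non-empty result, which is the behaviour the commit describes for the example and the test instance.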
CK Tile Example Suite
This directory contains a suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example illustrates a key deep learning or HPC operation, implemented using tile-based parallelism, modular pipelines, and data-movement policies.
What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. Pipelines and policies orchestrate data movement (global <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
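To make the tile idea concrete, here is a minimal CPU-only sketch of a tiled matrix multiplication. It does not use the CK Tile API. The function name `tiled_gemm` and the tile sizes are assumptions chosen for illustration, the local `a_tile`/`b_tile` arrays merely stand in for LDS staging, and `acc` stands in for register accumulators.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tile sizes, picked only for this illustration.
constexpr std::size_t kTileM = 4;
constexpr std::size_t kTileN = 4;
constexpr std::size_t kTileK = 8;

// C (m x n) = A (m x k) * B (k x n), all row-major, computed tile by tile.
std::vector<float> tiled_gemm(const std::vector<float>& a,
                              const std::vector<float>& b,
                              std::size_t m,
                              std::size_t n,
                              std::size_t k)
{
    assert(a.size() == m * k && b.size() == k * n);
    std::vector<float> c(m * n, 0.0f);

    for(std::size_t m0 = 0; m0 < m; m0 += kTileM)
    {
        for(std::size_t n0 = 0; n0 < n; n0 += kTileN)
        {
            float acc[kTileM][kTileN] = {}; // stands in for registers

            for(std::size_t k0 = 0; k0 < k; k0 += kTileK)
            {
                // Stage one tile of A and B; out-of-bounds entries stay zero,
                // which plays the role of padding for partial tiles.
                float a_tile[kTileM][kTileK] = {};
                float b_tile[kTileK][kTileN] = {};
                for(std::size_t i = 0; i < kTileM; ++i)
                    for(std::size_t kk = 0; kk < kTileK; ++kk)
                        if(m0 + i < m && k0 + kk < k)
                            a_tile[i][kk] = a[(m0 + i) * k + (k0 + kk)];
                for(std::size_t kk = 0; kk < kTileK; ++kk)
                    for(std::size_t j = 0; j < kTileN; ++j)
                        if(k0 + kk < k && n0 + j < n)
                            b_tile[kk][j] = b[(k0 + kk) * n + (n0 + j)];

                // Compute on the staged tiles only.
                for(std::size_t i = 0; i < kTileM; ++i)
                    for(std::size_t j = 0; j < kTileN; ++j)
                        for(std::size_t kk = 0; kk < kTileK; ++kk)
                            acc[i][j] += a_tile[i][kk] * b_tile[kk][j];
            }

            // Write the finished output tile back, skipping padded elements.
            for(std::size_t i = 0; i < kTileM; ++i)
                for(std::size_t j = 0; j < kTileN; ++j)
                    if(m0 + i < m && n0 + j < n)
                        c[(m0 + i) * n + (n0 + j)] = acc[i][j];
        }
    }
    return c;
}
```

On a GPU each (m0, n0) output tile would typically be handled by a separate thread block, and the load, compute, and store phases would be overlapped by a pipeline. The sketch keeps them sequential so the structure is easy to follow.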
Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
Technical Highlights
- Tile Distribution: See include/ck_tile/tile_program/tile_distribution/ for mapping tiles to thread blocks.
- Block Tile Pipelines: See include/ck_tile/tile_program/block_tile_pipeline/ for memory/computation pipelines.
- Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
How to Build & Run
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
Replace <arch> with the target GPU architecture (for example, gfx950). Each example produces its own executable in build/bin/.
Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.
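When experimenting with tile sizes, it can help to estimate how many tiles a problem produces and how much of the tiled area is padding. The helper below is a hypothetical back-of-the-envelope sketch (the names `TileGrid` and `tile_grid` are not CK Tile APIs) and says nothing about actual kernel performance.

```cpp
#include <cassert>
#include <cstdint>

// Number of tiles along each dimension and the count of padded (wasted)
// output elements for a given M x N problem and tile shape.
struct TileGrid
{
    std::int64_t tiles_m;
    std::int64_t tiles_n;
    std::int64_t padded_elems;
};

inline std::int64_t ceil_div(std::int64_t a, std::int64_t b)
{
    assert(a >= 0 && b > 0);
    return (a + b - 1) / b;
}

inline TileGrid tile_grid(std::int64_t M, std::int64_t N,
                          std::int64_t tile_m, std::int64_t tile_n)
{
    const std::int64_t tm = ceil_div(M, tile_m);
    const std::int64_t tn = ceil_div(N, tile_n);
    // Elements covered by the tile grid minus the elements that really exist.
    const std::int64_t padded = tm * tile_m * tn * tile_n - M * N;
    return TileGrid{tm, tn, padded};
}
```

For a 300 x 300 output, 128 x 128 tiles give a 3 x 3 grid with 57456 padded elements, while 64 x 64 tiles give a 5 x 5 grid with 12400. Smaller tiles waste less work on partial tiles but launch more of them, which is one of the trade-offs worth measuring.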