mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
CK Tile Example Suite
This directory contains a suite of examples demonstrating the CK Tile programming model for high-performance GPU kernels. Each example implements a key deep learning or HPC operation using tile-based parallelism, modular pipelines, and data-movement policies.
What is CK Tile?
CK Tile is a composable GPU programming API that expresses kernels as a composition of "tiles": rectangular blocks of computation and data movement. Pipelines and policies orchestrate data movement (global memory <-> LDS <-> registers), computation, and synchronization, enabling high efficiency and flexibility.
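To make the tile idea concrete, here is a minimal CPU-only sketch of a tiled GEMM in plain C++. It is not CK Tile code and uses none of its API: the names `tiled_gemm`, `kTileM`, `kTileN`, and `kTileK` are hypothetical, and the small staging arrays only stand in for LDS and registers. It shows the load, compute, store structure that a tile pipeline organizes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tile sizes chosen for illustration only.
constexpr std::size_t kTileM = 4;
constexpr std::size_t kTileN = 4;
constexpr std::size_t kTileK = 8;

// Computes C (m x n) = A (m x k) * B (k x n); all matrices are row-major.
std::vector<float> tiled_gemm(const std::vector<float>& a,
                              const std::vector<float>& b,
                              std::size_t m, std::size_t n, std::size_t k)
{
    assert(a.size() == m * k && b.size() == k * n);
    std::vector<float> c(m * n, 0.0f);

    // Each (m0, n0) pair is one output tile; on a GPU this would typically
    // be the work of one thread block.
    for(std::size_t m0 = 0; m0 < m; m0 += kTileM)
    {
        for(std::size_t n0 = 0; n0 < n; n0 += kTileN)
        {
            const std::size_t tm = std::min(kTileM, m - m0);
            const std::size_t tn = std::min(kTileN, n - n0);
            float acc[kTileM][kTileN] = {}; // stands in for registers

            for(std::size_t k0 = 0; k0 < k; k0 += kTileK)
            {
                const std::size_t tk = std::min(kTileK, k - k0);
                float a_tile[kTileM][kTileK]; // stands in for LDS
                float b_tile[kTileK][kTileN]; // stands in for LDS

                // Stage 1: copy the A and B tiles from "global" memory.
                for(std::size_t i = 0; i < tm; ++i)
                    for(std::size_t kk = 0; kk < tk; ++kk)
                        a_tile[i][kk] = a[(m0 + i) * k + (k0 + kk)];
                for(std::size_t kk = 0; kk < tk; ++kk)
                    for(std::size_t j = 0; j < tn; ++j)
                        b_tile[kk][j] = b[(k0 + kk) * n + (n0 + j)];

                // Stage 2: multiply-accumulate on the staged tiles.
                for(std::size_t i = 0; i < tm; ++i)
                    for(std::size_t j = 0; j < tn; ++j)
                        for(std::size_t kk = 0; kk < tk; ++kk)
                            acc[i][j] += a_tile[i][kk] * b_tile[kk][j];
            }

            // Stage 3: write the finished output tile back.
            for(std::size_t i = 0; i < tm; ++i)
                for(std::size_t j = 0; j < tn; ++j)
                    c[(m0 + i) * n + (n0 + j)] = acc[i][j];
        }
    }
    return c;
}
```

On a GPU the three stages would overlap and require synchronization between them; the sequential loops here only show which data each stage touches.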
Example Index
| Example | Operation | Description |
|---|---|---|
| 01_fmha | Fused Multi-Head Attention | Tile-based FMHA with masking, quantization, and epilogue fusion |
| 02_layernorm2d | LayerNorm2D | Blockwise layer normalization with fusion and quantization |
| 03_gemm | GEMM | Matrix multiplication with tilewise parallelism |
| 04_img2col | im2col | Image-to-column transformation for GEMM-based convolution |
| 05_reduce | Reduction | Tilewise sum, max, mean reductions |
| 06_permute | Permute | Generic tensor permutation (up to rank-8) |
| 09_topk_softmax | TopK-Softmax | Rowwise softmax and top-k selection for MoE gating |
| 10_rmsnorm2d | RMSNorm2D | Root mean square normalization for LLMs |
| 11_add_rmsnorm2d_rdquant | Add + RMSNorm2D + RDQuant | Fused add, RMSNorm, and rowwise dynamic quantization |
| 12_smoothquant | SmoothQuant | Per-channel scaling and quantization for int8 inference |
| 13_moe_sorting | MoE Sorting | Token-to-expert rearrangement for MoE dispatch |
| 14_moe_smoothquant | MoE-SmoothQuant | Expert-dependent quantization fused with top-k selection |
| 15_fused_moe | Fused MoE | End-to-end fused MoE block: sorting, group-GEMM, activation, weighting |
| 16_batched_gemm | Batched GEMM | Parallel computation of multiple GEMMs |
| 17_grouped_gemm | Grouped GEMM | Multiple independent GEMMs with different shapes |
| 18_flatmm | FLATMM | Flattened matrix multiplication for packed layouts |
| 19_gemm_multi_d | Multi-D GEMM | GEMM with multiple side inputs (bias, residual, etc.) |
| 35_batched_transpose | Batched Transpose | NCHW <-> NHWC and other layout conversions |
| 36_copy | Copy | Minimal example for tile-based memory movement |
| 37_transpose | Block Transpose | High-performance tiled transpose for large tensors |
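As one illustration of what a fused entry in the table computes, the following is a scalar CPU reference for the add + RMSNorm + rowwise dynamic quantization pattern. It is a sketch under assumptions, not code from 11_add_rmsnorm2d_rdquant: the function name `add_rmsnorm_rdquant_ref`, the `QuantizedRows` struct, the symmetric int8 range of [-127, 127], and the placement of `eps` inside the square root are common conventions chosen here for illustration. Check the example's own sources for its exact semantics.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result type: int8 values plus one scale per row.
struct QuantizedRows
{
    std::vector<std::int8_t> q; // rows x cols, row-major
    std::vector<float> scale;   // dequantized value = q * scale[row]
};

QuantizedRows add_rmsnorm_rdquant_ref(const std::vector<float>& x,
                                      const std::vector<float>& residual,
                                      const std::vector<float>& gamma,
                                      std::size_t rows, std::size_t cols,
                                      float eps = 1e-6f)
{
    assert(x.size() == rows * cols && residual.size() == rows * cols);
    assert(gamma.size() == cols && cols > 0);

    QuantizedRows out{std::vector<std::int8_t>(rows * cols), std::vector<float>(rows)};
    std::vector<float> y(cols);

    for(std::size_t r = 0; r < rows; ++r)
    {
        // Step 1: elementwise add, accumulating the sum of squares.
        float sum_sq = 0.0f;
        for(std::size_t c = 0; c < cols; ++c)
        {
            y[c] = x[r * cols + c] + residual[r * cols + c];
            sum_sq += y[c] * y[c];
        }

        // Step 2: RMS normalization with a per-column weight.
        const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(cols) + eps);
        float max_abs = 0.0f;
        for(std::size_t c = 0; c < cols; ++c)
        {
            y[c] = y[c] * inv_rms * gamma[c];
            max_abs = std::max(max_abs, std::fabs(y[c]));
        }

        // Step 3: rowwise dynamic quantization; an all-zero row keeps scale 1.
        const float scale = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
        out.scale[r] = scale;
        for(std::size_t c = 0; c < cols; ++c)
        {
            const long v = std::lround(y[c] / scale);
            out.q[r * cols + c] = static_cast<std::int8_t>(std::clamp(v, -127L, 127L));
        }
    }
    return out;
}
```

A fused kernel performs all three steps in one pass over each row, so the intermediate sum and normalized values never return to global memory; the reference above keeps them in a temporary vector for clarity.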
Technical Highlights
- Tile Distribution: See include/ck_tile/tile_program/tile_distribution/ for mapping tiles to thread blocks.
- Block Tile Pipelines: See include/ck_tile/tile_program/block_tile_pipeline/ for memory/computation pipelines.
- Policies and Utilities: Many examples use custom policies for tile/block size and memory access.
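The coarsest level of mapping tiles to thread blocks can be sketched in a few lines of plain C++. This is an illustrative assumption, not the CK Tile tile-distribution API, which also describes how a tile is spread across warps and lanes. `TileCoord`, `ceil_div`, and `tile_for_block` are hypothetical names, and the row-major ordering of block IDs is one possible choice.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical description of the output region owned by one thread block.
struct TileCoord
{
    std::size_t row0; // first row covered by the tile
    std::size_t col0; // first column covered by the tile
    std::size_t rows; // valid rows (smaller than tile_m on the bottom edge)
    std::size_t cols; // valid columns (smaller than tile_n on the right edge)
};

constexpr std::size_t ceil_div(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

// Maps a linear block ID to a tile of an m x n output, enumerating tiles row-major.
TileCoord tile_for_block(std::size_t block_id,
                         std::size_t m, std::size_t n,
                         std::size_t tile_m, std::size_t tile_n)
{
    assert(tile_m > 0 && tile_n > 0);
    const std::size_t tiles_m = ceil_div(m, tile_m);
    const std::size_t tiles_n = ceil_div(n, tile_n);
    assert(block_id < tiles_m * tiles_n);

    TileCoord t{};
    t.row0 = (block_id / tiles_n) * tile_m;
    t.col0 = (block_id % tiles_n) * tile_n;
    t.rows = std::min(tile_m, m - t.row0); // clip partial tiles at the edges
    t.cols = std::min(tile_n, n - t.col0);
    return t;
}
```

For a 100 x 70 output with 32 x 32 tiles this gives a 4 x 3 grid of 12 blocks; block 5 owns the tile starting at row 32, column 64, and only 6 of its 32 columns are valid.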
How to Build & Run
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make -j
Each example produces its own executable in build/bin/.
Learning and Extending
- Start Simple: Try 03_gemm or 36_copy to learn tile basics.
- Explore Fusion: See 11_add_rmsnorm2d_rdquant, 15_fused_moe, or 14_moe_smoothquant for advanced fusion.
- Experiment: Modify tile sizes, layouts, or pipelines to explore performance and flexibility.