Quick Tour for New Users
The Grouped GEMM operators are versions of GEMM that run multiple GEMM operations within a single kernel call. Each GEMM operation performs a matrix multiplication. In batched GEMM, every problem in the batch must have the same sizes and configuration. Grouped GEMM instead lets each group use its own matrix sizes and configuration, which makes it more flexible for diverse workloads.
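As a rough illustration of the semantics (not of how the kernel is implemented), the following naive C++ reference computes a grouped GEMM on the CPU. The `GemmProblem` struct and `grouped_gemm_reference` function are hypothetical names. The sketch assumes float data and the default Row/Col/Row layouts listed under the example arguments.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical description of one GEMM problem in a group.
// A is M x K row-major, B is K x N column-major, C is M x N row-major.
struct GemmProblem {
    std::size_t M, N, K;
    std::size_t stride_A, stride_B, stride_C;
    const float* A;
    const float* B;
    float* C;
};

// Naive CPU reference: each group may have its own M, N, K and strides,
// which is what distinguishes grouped GEMM from batched GEMM.
void grouped_gemm_reference(const std::vector<GemmProblem>& groups) {
    for (const GemmProblem& g : groups) {
        assert(g.stride_A >= g.K && g.stride_B >= g.K && g.stride_C >= g.N);
        for (std::size_t m = 0; m < g.M; ++m) {
            for (std::size_t n = 0; n < g.N; ++n) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < g.K; ++k) {
                    // Row-major A: (m, k) at m * stride_A + k.
                    // Column-major B: (k, n) at n * stride_B + k.
                    acc += g.A[m * g.stride_A + k] * g.B[n * g.stride_B + k];
                }
                // Row-major C: (m, n) at m * stride_C + n.
                g.C[m * g.stride_C + n] = acc;
            }
        }
    }
}
```

A single call processes groups of different shapes; the GPU kernel does the same in one launch.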
Preshuffle and Persistence
The grouped GEMM examples include the following advanced optimization features:
Weight Preshuffle
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory ahead of time into the order in which the kernel reads it. This improves data access patterns and reduces memory bandwidth requirements. It is particularly beneficial for inference workloads, where the same weights are reused across many batches and the one-time cost of shuffling them is therefore amortized.
- Implementation: Available in grouped_gemm_preshuffle.cpp
- Configuration: Uses the GemmConfigPreshuffleDecode and GemmConfigPreshufflePrefill template configurations
- Constraints: Currently supports only A (row-major) + B (column-major) → C (row-major) layouts
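To make the idea concrete, here is a minimal block-reordering sketch. It is not the layout produced by ck_tile's own shuffle utilities, whose exact order depends on the kernel's tile configuration. The `preshuffle_b` function, the `KTile`/`NTile` parameters, and the float element type are hypothetical. The sketch assumes B is stored column-major, as the layout constraint above requires, and that K and N are multiples of the tile sizes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch, not ck_tile's preshuffle layout.
// Input: B is K x N, column-major with leading dimension K.
// Output: (KTile x NTile) blocks stored contiguously, ordered by n-tile and
// then k-tile; inside a block, element (kk, nn) sits at nn * KTile + kk.
std::vector<float> preshuffle_b(const std::vector<float>& b,
                                std::size_t K, std::size_t N,
                                std::size_t KTile, std::size_t NTile) {
    assert(b.size() == K * N);
    assert(K % KTile == 0 && N % NTile == 0);  // sketch assumes no padding
    std::vector<float> out(K * N);
    std::size_t dst = 0;
    for (std::size_t n0 = 0; n0 < N; n0 += NTile) {
        for (std::size_t k0 = 0; k0 < K; k0 += KTile) {
            for (std::size_t nn = 0; nn < NTile; ++nn) {
                for (std::size_t kk = 0; kk < KTile; ++kk) {
                    // Column-major source index of element (k0 + kk, n0 + nn).
                    out[dst++] = b[(n0 + nn) * K + (k0 + kk)];
                }
            }
        }
    }
    return out;
}
```

Because B holds weights, a reordering like this would run once before inference, and every later GEMM could then read each tile as one contiguous span.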
Persistence Mode
Persistence mode is a GPU optimization in which a fixed set of thread blocks stays active on the compute units and processes multiple work items (such as output tiles) sequentially, rather than launching one block per work item. This reduces launch overhead and improves occupancy.
- Template Parameter: Controlled by the Persistent boolean template parameter in invoke_gemm
- Usage: invoke_gemm<ALayout, BLayout, CLayout, true> enables persistence
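The scheduling difference can be illustrated with a small CPU simulation. This is a hypothetical sketch: `persistent_schedule` is not a ck_tile function. It assumes a simple grid-stride assignment of output tiles to a fixed number of workgroups, and the real kernel's tile ordering may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical simulation of a persistent-kernel schedule (not ck_tile code).
// A fixed number of workgroups is launched and each one strides through the
// flattened list of output tiles. Returns the workgroup id that handled each tile.
std::vector<std::size_t> persistent_schedule(std::size_t total_tiles,
                                             std::size_t num_workgroups) {
    assert(num_workgroups > 0);
    // num_workgroups is used as a sentinel meaning "not yet claimed".
    std::vector<std::size_t> owner(total_tiles, num_workgroups);
    for (std::size_t wg = 0; wg < num_workgroups; ++wg) {
        // Grid-stride loop: workgroup wg handles tiles wg, wg + G, wg + 2G, ...
        for (std::size_t tile = wg; tile < total_tiles; tile += num_workgroups) {
            owner[tile] = wg;
        }
    }
    return owner;
}
```

For example, `persistent_schedule(10, 4)` assigns tiles 0, 4, and 8 to workgroup 0 and tiles 1, 5, and 9 to workgroup 1. Every tile is still processed exactly once, but only four workgroups are ever launched.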
Multi-D Operations
Multi-D operations extend the standard GEMM operation by supporting additional elementwise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output.
- Implementation: Available in grouped_gemm_multi_d.cpp
- Operation: E = C × D₀ × D₁, applied elementwise (where C = A × B is the standard GEMM result)
- Configuration: Uses the GemmConfigV3, GemmConfigV4, and GemmConfigMemory template configurations with 2 D tensors
- Data Types: Supports fp16, bf16, fp8
- Benefits: Enables complex operations like scaling, activation functions, or other elementwise transformations in a single kernel call
- Build Target: make tile_example_grouped_gemm_multi_d -j
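A scalar reference for the Multi-D epilogue might look like the following. It is a sketch under simplifying assumptions. The × between C, D₀, and D₁ in the Operation formula is read as elementwise multiplication. All three tensors are assumed to be M × N with identical row-major layouts. `multi_d_epilogue` is a hypothetical name, not part of the example's source.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference for E = C * D0 * D1 (elementwise), where C = A x B
// is the GEMM result for one group. All tensors share one M x N layout here.
std::vector<float> multi_d_epilogue(const std::vector<float>& c,
                                    const std::vector<float>& d0,
                                    const std::vector<float>& d1) {
    assert(c.size() == d0.size() && c.size() == d1.size());
    std::vector<float> e(c.size());
    for (std::size_t i = 0; i < c.size(); ++i) {
        // One output element at a time: in a fused kernel this step can run on
        // the accumulator before E is written, so C never goes to memory.
        e[i] = c[i] * d0[i] * d1[i];
    }
    return e;
}
```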
Multi-D operations support both persistent and non-persistent modes. Weight preshuffle supports only non-persistent mode.
Build
# in the root of ck_tile
mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_grouped_gemm -j
# The preshuffle example
make tile_example_grouped_gemm_preshuffle -j
# The multi-D operations example
make tile_example_grouped_gemm_multi_d -j
# The quant grouped gemm fp8 example
make tile_example_quant_grouped_gemm -j
Each target builds a corresponding executable: build/bin/tile_example_grouped_gemm, build/bin/tile_example_grouped_gemm_preshuffle, build/bin/tile_example_grouped_gemm_multi_d, and build/bin/tile_example_quant_grouped_gemm.
example
args:
-Ms M dimensions - (Default: empty).
-Ns N dimensions - (Default: empty).
-Ks K dimensions - (Default: empty).
-stride_As Tensor A strides - (Default: empty).
-stride_Bs Tensor B strides - (Default: empty).
-stride_Cs Tensor C strides - (Default: empty).
-a_layout A tensor data layout - (Default: Row).
-b_layout B tensor data layout - (Default: Col).
-c_layout C tensor data layout - (Default: Row).
-prec data type. fp16/bf16/fp8 - (Default: fp16).
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
-warmup Number of iterations before benchmarking the kernel. (Default: 10).
-repeat Number of iterations to benchmark the kernel. (Default: 100).
-group_count Group count. (Default: 16).
-kbatch kbatch for SplitK (Default: 1).
-json 0: No Json, 1: Dump Results in Json format (Default: 0).
-jsonfile json file name to dump results (Default: grouped_gemm.json).
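The -kbatch option controls split-K. As a hypothetical illustration (`split_k_dot` is not part of ck_tile), the sketch below splits a single dot product over K into kbatch contiguous chunks and then reduces the partial sums. On the GPU, each chunk would be handled by a separate workgroup and the partial results combined afterwards.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical split-K illustration (not ck_tile code): the K dimension of one
// dot product is cut into kbatch chunks, each producing a partial sum.
float split_k_dot(const std::vector<float>& a_row,
                  const std::vector<float>& b_col,
                  std::size_t kbatch) {
    assert(a_row.size() == b_col.size() && kbatch > 0);
    const std::size_t K = a_row.size();
    const std::size_t chunk = (K + kbatch - 1) / kbatch;  // ceil(K / kbatch)
    std::vector<float> partial(kbatch, 0.0f);
    for (std::size_t s = 0; s < kbatch; ++s) {
        const std::size_t begin = s * chunk;
        const std::size_t end = std::min(begin + chunk, K);
        for (std::size_t k = begin; k < end; ++k) {  // empty if begin >= K
            partial[s] += a_row[k] * b_col[k];
        }
    }
    float sum = 0.0f;
    for (float p : partial) sum += p;  // reduce the kbatch partial results
    return sum;
}
```

Because the partial sums are added in a different order than with kbatch = 1, floating-point results can differ slightly in the last bits.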
If any of Ms, Ns, Ks, stride_As, stride_Bs, or stride_Cs are missing or their sizes
don't match group_count, the example generates defaults per group index i (0-based):
M[i] = 256 + 256 * i
N[i] = 256 + 512 * i
K[i] = 512 + 384 * i
stride_A[i] = K[i]
stride_B[i] = K[i]
stride_C[i] = N[i]
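The default-shape rule above can be written out directly. The `GroupShape` struct and `default_shapes` function are hypothetical helpers, not the example's actual code. Only the formulas and stride choices are taken from this README.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the README's default-shape formulas.
struct GroupShape {
    std::size_t M, N, K, stride_A, stride_B, stride_C;
};

std::vector<GroupShape> default_shapes(std::size_t group_count) {
    assert(group_count > 0);
    std::vector<GroupShape> shapes;
    shapes.reserve(group_count);
    for (std::size_t i = 0; i < group_count; ++i) {
        const std::size_t M = 256 + 256 * i;
        const std::size_t N = 256 + 512 * i;
        const std::size_t K = 512 + 384 * i;
        // Leading dimensions for the default layouts: row-major A (M x K) -> K,
        // column-major B (K x N) -> K, row-major C (M x N) -> N.
        shapes.push_back({M, N, K, K, K, N});
    }
    return shapes;
}
```

With the default group_count of 16, the last group is 4096 × 7936 with K = 6272.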
Source Structure
- Kernel: grouped_gemm.hpp (tile-programming kernel template)
- Executables: grouped_gemm.cpp
- Build: CMakeLists.txt, run_grouped_gemm_example.inc
Related CK Tile Examples
- 16_batched_gemm: Batched GEMM with tiles
- 15_fused_moe: Fused MoE block (uses grouped GEMM)
- 03_gemm: Single GEMM with tiles
For details on tile distribution, see include/ck_tile/tile_program/tile_distribution/.