GEMM with K-Axis Splitting (Split-K GEMM)
This example demonstrates a General Matrix-Matrix Multiplication (GEMM) implemented with a Split-K algorithm. This is a technique used to increase the available parallelism for a single, large GEMM operation, which can lead to higher performance, especially on GPUs with a very large number of compute units.
Mathematical Formulation
A standard GEMM computes the matrix product C = A \times B, where A has shape [M, K] and B has shape [K, N]. The computation is:
C_{ij} = \sum_{k=0}^{K-1} A_{ik} B_{kj}
In a Split-K algorithm, the K dimension is split into S chunks of size K_split = K / S. The GEMM is then broken down into S smaller, partial GEMMs.
For each split s from 0 to S-1:
- Let `A_s` be the s-th slice of `A` along the K-axis (shape `[M, K_split]`).
- Let `B_s` be the s-th slice of `B` along the K-axis (shape `[K_split, N]`).
- A partial product is computed: C_s = A_s \times B_s.
The final result C is the sum of all the partial products:
C = \sum_{s=0}^{S-1} C_s = C_0 + C_1 + \dots + C_{S-1}
Algorithmic Strategy: Parallel Reduction of Partial GEMMs
The Split-K algorithm turns a single large GEMM into multiple smaller GEMMs whose results must be reduced (summed). This introduces a new axis of parallelism.
- Splitting the K-Dimension: The `K` dimension of the input matrices `A` and `B` is logically split into `S` parts. The value of `S` is chosen by the kernel based on the problem size and hardware characteristics to expose a suitable amount of parallelism.
- Parallel Partial GEMMs: The `S` partial GEMMs are executed in parallel. The GPU's grid of thread blocks now maps not only to the M and N dimensions of the output matrix `C`, but also to the `S` splits of the K dimension. Each thread block is assigned a tile of one partial product `C_s`.
- Reduction of Partial Results: The key challenge is how to sum the partial products `C_s` efficiently. Two common strategies are:
  - Atomic Add: The simplest method is for each block to compute its tile of `C_s` and then use atomic add operations to accumulate the result directly into the final output matrix `C` in global memory. This is easy to implement but can suffer from contention when many splits update the same memory locations.
  - Two-Stage Reduction: A more robust approach involves two stages:
    - Stage 1 (Partial Products): Each of the `S` parallel GEMMs writes its full partial product `C_s` to a temporary workspace in global memory.
    - Stage 2 (Final Reduction): A separate reduction kernel is launched to sum the `S` partial products from the workspace into the final output matrix `C`.
Composable Kernel's implementation abstracts this complexity. The `DeviceGemmSplitK` interface handles the selection of the split factor `S`, the launch of the parallel partial GEMMs, and the final reduction step.
Source Code Organization
- `splitk_gemm_xdl.cpp`: The main example file. It sets up a standard GEMM problem and instantiates the `DeviceGemmSplitK` operation.
- `../../include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp`: The high-level device interface for the Split-K GEMM. It takes an additional `k_batch` parameter which controls the number of splits.
- The underlying grid-wise kernel is modified to accept a `k_batch` index, so that each thread block knows which slice of the `A` and `B` matrices it is responsible for. It also includes the logic for the reduction (e.g., using atomic adds).
Build and Run
Prerequisites
Ensure the Composable Kernel library is built and installed.
```bash
cd /path/to/composable_kernel/build
make -j install
```
Build the Example
```bash
cd /path/to/composable_kernel/example/35_splitK_gemm
mkdir build && cd build
cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..
make -j
```
Run the Example
```bash
# Run the example with default settings
./splitk_gemm_xdl

# Run with verification, data initialization, and timing
./splitk_gemm_xdl 1 2 1
```
When is Split-K Useful?
Split-K is not always faster than a standard GEMM. It is most beneficial in specific scenarios:
- "Skinny" GEMMs: For GEMMs where `M` and `N` are small but `K` is very large (e.g., `M=64, N=64, K=65536`), a standard GEMM might not generate enough parallel work to fill a large GPU. By splitting the large `K` dimension, we create many more independent work items, improving hardware utilization.
- Limited Shared Memory: If a standard GEMM requires a very large tile size (and thus a large amount of shared memory) to be efficient, Split-K can be an alternative. It can use smaller tiles for the partial GEMMs, reducing the shared memory footprint per block.
- Load Balancing: It can help with load balancing on heterogeneous hardware or in complex fused scenarios.
The trade-off is the overhead of the reduction step. The performance gain from increased parallelism must outweigh the cost of either atomic operations or writing and re-reading intermediate results.