# GEMM with K-Axis Splitting (Split-K GEMM)
This example demonstrates a General Matrix-Matrix Multiplication (GEMM) implemented with a **Split-K** algorithm. This is a technique used to increase the available parallelism for a single, large GEMM operation, which can lead to higher performance, especially on GPUs with a very large number of compute units.
## Mathematical Formulation
A standard GEMM computes the matrix product $C = A \times B$, where `A` has shape `[M, K]` and `B` has shape `[K, N]`. The computation is:
$C_{ij} = \sum_{k=0}^{K-1} A_{ik} B_{kj}$
In a Split-K algorithm, the `K` dimension is split into `S` chunks of size `K_split = K / S` (assume for simplicity that `K` is evenly divisible by `S`). The GEMM is then broken down into `S` smaller, partial GEMMs.
For each split `s` from `0` to `S-1`:
- Let $A_s$ be the s-th slice of `A` along the K-axis (shape `[M, K_split]`).
- Let $B_s$ be the s-th slice of `B` along the K-axis (shape `[K_split, N]`).
- A partial product is computed: $C_s = A_s \times B_s$.
The final result `C` is the sum of all the partial products:
$C = \sum_{s=0}^{S-1} C_s = C_0 + C_1 + \dots + C_{S-1}$
## Algorithmic Strategy: Parallel Reduction of Partial GEMMs
The Split-K algorithm turns a single large GEMM into multiple smaller GEMMs whose results must be reduced (summed). This introduces a new axis of parallelism.
1. **Splitting the K-Dimension**: The `K` dimension of the input matrices `A` and `B` is logically split into `S` parts. The split count `S` is chosen based on the problem size and hardware characteristics to expose a suitable amount of parallelism; in this example it is exposed to the caller as the `k_batch` parameter.
2. **Parallel Partial GEMMs**: The `S` partial GEMMs are executed in parallel. The GPU's grid of thread blocks gains an extra dimension, mapping not only to the M and N tiles of the output matrix `C` but also to the `S` splits of the K dimension.
- A thread block is assigned to compute a tile of a *partial* product $C_s$.
3. **Reduction of Partial Results**: The key challenge is how to sum the partial products $C_s$ efficiently.
- **Atomic Add**: The simplest method is for each block to compute its tile of $C_s$ and then use atomic add operations to accumulate its result directly into the final output matrix `C` in global memory. This is easy to implement but can suffer from high contention on the atomic operations, especially if many splits are trying to update the same memory location.
- **Two-Stage Reduction**: A more robust approach involves two stages:
- **Stage 1 (Partial Products)**: Each of the `S` parallel GEMMs writes its full partial product $C_s$ to a temporary workspace in global memory.
- **Stage 2 (Final Reduction)**: A separate reduction kernel is launched to sum the `S` partial products from the workspace into the final output matrix `C`.
Composable Kernel's implementation abstracts this complexity. The `DeviceGemmSplitK` interface handles the selection of the split factor `S`, the launch of the parallel partial GEMMs, and the final reduction step.
## Source Code Organization
- [`splitk_gemm_xdl.cpp`](./splitk_gemm_xdl.cpp): The main example file. It sets up a standard GEMM problem and instantiates the `DeviceGemmSplitK` operation.
- [`../../include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp`](../../include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp): The high-level device interface for the Split-K GEMM. It takes an additional `k_batch` parameter which controls the number of splits.
- The underlying grid-wise kernel is modified to accept a `k_batch` index, so that each thread block knows which slice of the `A` and `B` matrices it is responsible for. It also includes the logic for the reduction (e.g., using atomic adds).
## Build and Run
### Prerequisites
Ensure the Composable Kernel library is built and installed.
```bash
cd /path/to/composable_kernel/build
make -j install
```
### Build the Example
```bash
cd /path/to/composable_kernel/example/35_splitK_gemm
mkdir build && cd build
cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..
make -j
```
### Run the Example
```bash
# Run the example with default settings
./splitk_gemm_xdl
# arg1: verification (0=no, 1=yes); arg2: initialization (0=none, 1=integer, 2=decimal); arg3: time kernel (0=no, 1=yes)
./splitk_gemm_xdl 1 2 1
```
## When is Split-K Useful?
Split-K is not always faster than a standard GEMM. It is most beneficial in specific scenarios:
- **"Skinny" GEMMs**: For GEMMs where `M` and `N` are small but `K` is very large (e.g., `M=64, N=64, K=65536`). A standard GEMM might not generate enough parallel work to fill a large GPU. By splitting the large `K` dimension, we create many more independent work items, improving hardware utilization.
- **Limited Shared Memory**: If a standard GEMM requires a very large tile size (and thus a large amount of shared memory) to be efficient, Split-K can be an alternative. It can use smaller tiles for the partial GEMMs, reducing the shared memory footprint per block.
- **Load Balancing**: It can help with load balancing on heterogeneous hardware or in complex fused scenarios.
The trade-off is the overhead of the reduction step. The performance gain from increased parallelism must outweigh the cost of either atomic operations or writing and re-reading intermediate results.