
GEMM with K-Axis Splitting (Split-K GEMM)

This example demonstrates a General Matrix-Matrix Multiplication (GEMM) implemented with a Split-K algorithm. This is a technique used to increase the available parallelism for a single, large GEMM operation, which can lead to higher performance, especially on GPUs with a very large number of compute units.

Mathematical Formulation

A standard GEMM computes the matrix product C = A \times B, where A has shape [M, K] and B has shape [K, N]. The computation is: C_{ij} = \sum_{k=0}^{K-1} A_{ik} B_{kj}

In a Split-K algorithm, the K dimension is split into S chunks of size K_split = K / S (when S does not divide K evenly, one chunk absorbs the remainder). The GEMM is then broken down into S smaller, partial GEMMs.

For each split s from 0 to S-1:

  • Let A_s be the s-th slice of A along the K-axis (shape [M, K_split]).
  • Let B_s be the s-th slice of B along the K-axis (shape [K_split, N]).
  • A partial product is computed: C_s = A_s \times B_s.

The final result C is the sum of all the partial products: C = \sum_{s=0}^{S-1} C_s = C_0 + C_1 + \dots + C_{S-1}

Algorithmic Strategy: Parallel Reduction of Partial GEMMs

The Split-K algorithm turns a single large GEMM into multiple smaller GEMMs whose results must be reduced (summed). This introduces a new axis of parallelism.

  1. Splitting the K-Dimension: The K dimension of the input matrices A and B is logically split into S parts. The split factor S is chosen (by the caller, or by a heuristic) based on the problem size and hardware characteristics to expose a suitable amount of parallelism.

  2. Parallel Partial GEMMs: The S partial GEMMs are executed in parallel. The GPU's grid of thread blocks gains an extra dimension, mapping not only to the M and N tiles of the output matrix C but also to the S splits of the K dimension.

    • A thread block is assigned to compute a tile of a partial product C_s.
  3. Reduction of Partial Results: The key challenge is how to sum the partial products C_s efficiently.

    • Atomic Add: The simplest method is for each block to compute its tile of C_s and then use atomic add operations to accumulate its result directly into the final output matrix C in global memory. This is easy to implement but can suffer from high contention when many splits update the same memory locations. Because floating-point addition is not associative, the order-dependent atomic accumulation can also make results vary slightly between runs.
    • Two-Stage Reduction: A more robust approach involves two stages:
      • Stage 1 (Partial Products): Each of the S parallel GEMMs writes its full partial product C_s to a temporary workspace in global memory.
      • Stage 2 (Final Reduction): A separate reduction kernel is launched to sum the S partial products from the workspace into the final output matrix C.

Composable Kernel's implementation abstracts this complexity. The DeviceGemmSplitK interface accepts the split factor S (the k_batch argument), launches the parallel partial GEMMs, and performs the final reduction step.

Source Code Organization

  • splitk_gemm_xdl.cpp: The main example file. It sets up a standard GEMM problem and instantiates the DeviceGemmSplitK operation.
  • ../../include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp: The high-level device interface for the Split-K GEMM. It takes an additional k_batch parameter which controls the number of splits.
  • The underlying grid-wise kernel is modified to accept a k_batch index, so that each thread block knows which slice of the A and B matrices it is responsible for. It also includes the logic for the reduction (e.g., using atomic adds).

Build and Run

Prerequisites

Ensure the Composable Kernel library is built and installed.

cd /path/to/composable_kernel/build
make -j install

Build the Example

cd /path/to/composable_kernel/example/35_splitK_gemm
mkdir build && cd build

cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..

make -j

Run the Example

# Run the example with default settings
./splitk_gemm_xdl

# Run with verification enabled (1), initialization method 2, and kernel timing enabled (1)
./splitk_gemm_xdl 1 2 1

When is Split-K Useful?

Split-K is not always faster than a standard GEMM. It is most beneficial in specific scenarios:

  • "Skinny" GEMMs: For GEMMs where M and N are small but K is very large (e.g., M=64, N=64, K=65536). A standard GEMM might not generate enough parallel work to fill a large GPU. By splitting the large K dimension, we create many more independent work items, improving hardware utilization.
  • Limited Shared Memory: If a standard GEMM requires a very large tile size (and thus a large amount of shared memory) to be efficient, Split-K can be an alternative. It can use smaller tiles for the partial GEMMs, reducing the shared memory footprint per block.
  • Load Balancing: It can help with load balancing on heterogeneous hardware or in complex fused scenarios.

The trade-off is the overhead of the reduction step. The performance gain from increased parallelism must outweigh the cost of either atomic operations or writing and re-reading intermediate results.