GEMM with K-Axis Splitting (Split-K GEMM)

This example demonstrates a General Matrix-Matrix Multiplication (GEMM) implemented with a Split-K algorithm. This is a technique used to increase the available parallelism for a single, large GEMM operation, which can lead to higher performance, especially on GPUs with a very large number of compute units.

Mathematical Formulation

A standard GEMM computes the matrix product C = A \times B, where A has shape [M, K] and B has shape [K, N]. The computation is: C_{ij} = \sum_{k=0}^{K-1} A_{ik} B_{kj}

In a Split-K algorithm, the K dimension is split into S chunks of size K_split = K / S. The GEMM is then broken down into S smaller, partial GEMMs.

For each split s from 0 to S-1:

  • Let A_s be the s-th slice of A along the K-axis (shape [M, K_split]).
  • Let B_s be the s-th slice of B along the K-axis (shape [K_split, N]).
  • A partial product is computed: C_s = A_s \times B_s.

The final result C is the sum of all the partial products: C = \sum_{s=0}^{S-1} C_s = C_0 + C_1 + \dots + C_{S-1}

Algorithmic Strategy: Parallel Reduction of Partial GEMMs

The Split-K algorithm turns a single large GEMM into multiple smaller GEMMs whose results must be reduced (summed). This introduces a new axis of parallelism.

  1. Splitting the K-Dimension: The K dimension of the input matrices A and B is logically split into S parts. The split factor S is chosen based on the problem size and hardware characteristics to expose a suitable amount of parallelism.

  2. Parallel Partial GEMMs: The S partial GEMMs are executed in parallel. The GPU's grid of thread blocks gains an extra dimension: it maps not only to the M and N tiles of the output matrix C, but also to the S splits of the K dimension.

    • A thread block is assigned to compute a tile of a partial product C_s.
  3. Reduction of Partial Results: The key challenge is how to sum the partial products C_s efficiently.

    • Atomic Add: The simplest method is for each block to compute its tile of C_s and then use atomic add operations to accumulate its result directly into the final output matrix C in global memory. This is easy to implement but can suffer from high contention on the atomic operations, especially if many splits are trying to update the same memory location.
    • Two-Stage Reduction: A more robust approach involves two stages:
      • Stage 1 (Partial Products): Each of the S parallel GEMMs writes its full partial product C_s to a temporary workspace in global memory.
      • Stage 2 (Final Reduction): A separate reduction kernel is launched to sum the S partial products from the workspace into the final output matrix C.

Composable Kernel's implementation abstracts this complexity. The DeviceGemmSplitK interface handles the selection of the split factor S, the launch of the parallel partial GEMMs, and the final reduction step.

Source Code Organization

  • splitk_gemm_xdl.cpp: The main example file. It sets up a standard GEMM problem and instantiates the DeviceGemmSplitK operation.
  • ../../include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp: The high-level device interface for the Split-K GEMM. It takes an additional k_batch parameter which controls the number of splits.
  • The underlying grid-wise kernel is modified to accept a k_batch index, so that each thread block knows which slice of the A and B matrices it is responsible for. It also includes the logic for the reduction (e.g., using atomic adds).

Build and Run

Prerequisites

Ensure the Composable Kernel library is built and installed.

cd /path/to/composable_kernel/build
make -j install

Build the Example

cd /path/to/composable_kernel/example/35_splitK_gemm
mkdir build && cd build

cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..

make -j

Run the Example

# Run the example with default settings
./splitk_gemm_xdl

# Run with verification, data initialization, and timing
./splitk_gemm_xdl 1 2 1

When is Split-K Useful?

Split-K is not always faster than a standard GEMM. It is most beneficial in specific scenarios:

  • "Skinny" GEMMs: For GEMMs where M and N are small but K is very large (e.g., M=64, N=64, K=65536). A standard GEMM might not generate enough parallel work to fill a large GPU. By splitting the large K dimension, we create many more independent work items, improving hardware utilization.
  • Limited Shared Memory: If a standard GEMM requires a very large tile size (and thus a large amount of shared memory) to be efficient, Split-K can be an alternative. It can use smaller tiles for the partial GEMMs, reducing the shared memory footprint per block.
  • Load Balancing: It can help with load balancing on heterogeneous hardware or in complex fused scenarios.

The trade-off is the overhead of the reduction step. The performance gain from increased parallelism must outweigh the cost of either atomic operations or writing and re-reading intermediate results.