Split-K GEMM with Bias, Elementwise Operation, and Permutation
This example demonstrates a highly complex fusion: a Split-K GEMM where the final result is fused with a bias addition, a second elementwise operation, and a final permutation. This kernel combines the parallelism-enhancing Split-K strategy with a multi-stage epilogue, making it suitable for accelerating very large or "skinny" GEMMs that are part of a more complex computational graph.
Mathematical Formulation
The operation first computes a GEMM using the Split-K algorithm and then applies a sequence of fused operations.
- Split-K GEMM Stage: The matrix multiplication $C_{temp1} = A \times B$ is computed by splitting the K dimension into S chunks and summing the partial products: $C_{temp1} = \sum_{s=0}^{S-1} (A_s \times B_s)$
- Bias Addition Stage: A bias vector D is broadcast and added: $C_{temp2} = C_{temp1} + D$
- Elementwise Stage: A second elementwise operation is performed with tensor E: $C_{temp3} = C_{temp2} \odot E$
- Permutation Stage: The final result is permuted: $F = \text{permute}(C_{temp3})$
The key is that the reduction (summation) of the partial GEMM products is fused with the entire epilogue chain (Bias, E-wise, Permute).
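The four stages can be checked against a small host-side reference. The sketch below is not part of the example's source: the function names are illustrative, the bias D is assumed to be a length-N vector broadcast across rows, the elementwise operation is taken to be the multiplication written above, and the permutation is assumed to be a plain 2D transpose (the actual example may permute a higher-rank layout).

```python
# Host-side reference using plain Python lists and integer data.
# matmul, splitk_gemm, and epilogue are illustrative names, not CK APIs.

def matmul(A, B):
    """Plain GEMM: A is M x K, B is K x N, result is M x N."""
    K, N = len(B), len(B[0])
    return [[sum(row[k] * B[k][n] for k in range(K)) for n in range(N)]
            for row in A]


def splitk_gemm(A, B, S):
    """C_temp1 as the sum of S partial GEMMs over contiguous chunks of K."""
    M, K, N = len(A), len(B), len(B[0])
    chunk = (K + S - 1) // S  # ceil(K / S); trailing chunks may be short or empty
    C = [[0] * N for _ in range(M)]
    for s in range(S):
        k0, k1 = s * chunk, min((s + 1) * chunk, K)
        if k0 >= k1:
            continue  # nothing left for this split
        partial = matmul([row[k0:k1] for row in A], B[k0:k1])
        for m in range(M):
            for n in range(N):
                C[m][n] += partial[m][n]
    return C


def epilogue(C, D, E):
    """Bias add, elementwise multiply, then permute (assumed: 2D transpose)."""
    M, N = len(C), len(C[0])
    # Assumption: D has length N and is broadcast across rows.
    C3 = [[(C[m][n] + D[n]) * E[m][n] for n in range(N)] for m in range(M)]
    return [[C3[m][n] for m in range(M)] for n in range(N)]


A = [[1, 2, 3, 4], [5, 6, 7, 8]]
B = [[1, 0, 2], [0, 1, 1], [3, 1, 0], [1, 2, 1]]
D = [1, 1, 1]
E = [[1, 2, 1], [2, 1, 2]]
F = epilogue(splitk_gemm(A, B, S=2), D, E)
print(F)
```

With integer inputs the split-K sum matches the unsplit GEMM exactly for any S. With floating-point data the summation order differs, so a verification step would normally compare within a tolerance.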
Algorithmic Strategy: Split-K with a Fused Reduction Epilogue
The implementation combines the Split-K algorithm with the multi-stage fused epilogue seen in previous examples.
- Splitting the K-Dimension: The K dimension is logically split into S parts to create S parallel partial GEMM problems.
- Parallel Partial GEMMs: The S partial GEMMs are executed in parallel across the GPU's thread blocks. A thread block is assigned to compute a tile of a partial product C_s.
- Fused Reduction and Epilogue: The method for reducing the partial sums and applying the epilogue is critical.
  - Workspace Approach: A common strategy is to use a temporary workspace in global memory.
    - Stage 1 (Partial Products): Each of the S parallel GEMMs computes its partial product C_s and writes it to a unique slice of a temporary workspace tensor.
    - Stage 2 (Reduce + Epilogue): A second, specialized kernel is launched. This kernel reads the S partial products from the workspace, reduces (sums) them on-the-fly, and then immediately applies the full Bias-E-Permute epilogue before writing the final result F to memory.
  - Atomic-based Approach: For some data types and operations, it's possible to perform the reduction using atomic operations. The first block to arrive at an output element would compute its partial result, apply the epilogue, and write it out. Subsequent blocks would compute their partial results, read the intermediate value from the output buffer, add their contribution, and then atomically write the new sum back. This is more complex and often less performant due to atomic contention.
Composable Kernel's implementation abstracts this complexity, providing a single device-level operation that manages the workspace, the two stages, and the complex epilogue.
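The workspace approach can be sketched as two host-side functions, one per stage. This is a simplified model rather than the library's kernels: the function names and the S x M x N workspace layout are assumptions, the loops over s stand in for thread blocks that would run concurrently on a GPU, and the permutation is again assumed to be a 2D transpose.

```python
# Two-stage workspace model of Split-K with a fused reduce + epilogue.
# Function names and the S x M x N workspace layout are illustrative assumptions.

def stage1_partial_products(A, B, S):
    """Stage 1: split s writes its partial product C_s into workspace[s]."""
    M, K, N = len(A), len(B), len(B[0])
    chunk = (K + S - 1) // S  # ceil(K / S)
    workspace = [[[0] * N for _ in range(M)] for _ in range(S)]
    for s in range(S):  # on a GPU these splits run in parallel
        k0, k1 = s * chunk, min((s + 1) * chunk, K)
        for m in range(M):
            for n in range(N):
                # range(k0, k1) is empty when the split has no K elements left
                workspace[s][m][n] = sum(A[m][k] * B[k][n] for k in range(k0, k1))
    return workspace


def stage2_reduce_epilogue(workspace, D, E):
    """Stage 2: sum the S slices, then apply bias, elementwise multiply, permute."""
    S, M, N = len(workspace), len(workspace[0]), len(workspace[0][0])
    F = [[0] * M for _ in range(N)]  # output is written directly in permuted layout
    for m in range(M):
        for n in range(N):
            acc = sum(workspace[s][m][n] for s in range(S))  # reduction over splits
            F[n][m] = (acc + D[n]) * E[m][n]  # epilogue fused with the reduction
    return F


A = [[1, 2, 3, 4], [5, 6, 7, 8]]
B = [[1, 0, 2], [0, 1, 1], [3, 1, 0], [1, 2, 1]]
D = [1, 1, 1]
E = [[1, 2, 1], [2, 1, 2]]
workspace = stage1_partial_products(A, B, S=2)
F = stage2_reduce_epilogue(workspace, D, E)
print(F)
```

The point of the structure is that the intermediate results C_temp1 through C_temp3 never reach memory as separate tensors: stage 2 reads only the workspace slices and writes only F.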
Source Code Organization
- splitk_gemm_bias_e_permute_xdl.cpp: The main example file. It sets up the GEMM problem, the bias and elementwise tensors, the permutation, and instantiates the DeviceSplitkGemmBiasEPermute operation.
- The device-level interface and underlying kernels are highly specialized. They manage the Split-K parameter, the workspace allocation (if needed), and the two-stage execution process, combining the logic from DeviceGemmSplitK and DeviceGemmBiasEPermute.
Build and Run
Prerequisites
Ensure the Composable Kernel library is built and installed.
cd /path/to/composable_kernel/build
make -j install
Build the Example
cd /path/to/composable_kernel/example/43_splitk_gemm_bias_e_permute
mkdir build && cd build
cmake \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
..
make -j
Run the Example
# Run the example with default settings
./splitk_gemm_bias_e_permute_xdl
# Run with verification, data initialization, and timing
./splitk_gemm_bias_e_permute_xdl 1 2 1
Applications
This highly specialized kernel is useful when a very large GEMM (that would benefit from Split-K) is immediately followed by a series of operations that can be fused.
- Large Feed-Forward Networks: In a Transformer with a very large hidden dimension, the GEMMs in the FFN block might become "skinny" (large K, smaller M/N). If this FFN is also fused with residual connections (bias/add) and layout permutations, this kernel could be a perfect fit, offering both the parallelism benefits of Split-K and the memory bandwidth savings of the fused epilogue.
- Final Classifier Layers: The final layer of a large classification model is often a very large GEMM. If this layer's output needs to be reshaped or post-processed, this kernel could fuse those operations directly into the Split-K GEMM.
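A rough block count shows why a "skinny" GEMM benefits from Split-K. The sketch below is back-of-the-envelope arithmetic only: the 256 x 128 output tile, the problem size, and the helper names are hypothetical and are not taken from the example's configuration.

```python
# Rough occupancy arithmetic for Split-K; all sizes below are hypothetical.

def num_thread_blocks(M, N, tile_m, tile_n, S=1):
    """Output tiles times the number of K splits (one block per tile per split)."""
    tiles_m = (M + tile_m - 1) // tile_m  # ceil(M / tile_m)
    tiles_n = (N + tile_n - 1) // tile_n  # ceil(N / tile_n)
    return tiles_m * tiles_n * S


def workspace_elements(M, N, S):
    """Elements needed when each split stores a full M x N partial product."""
    return S * M * N


M, N, K = 512, 512, 65536  # large K, small M and N
blocks_plain = num_thread_blocks(M, N, 256, 128)         # no Split-K
blocks_split = num_thread_blocks(M, N, 256, 128, S=16)   # K split 16 ways
print(blocks_plain, blocks_split, workspace_elements(M, N, 16))
```

Without Split-K this problem yields only 8 thread blocks, far fewer than a GPU with on the order of a hundred compute units can run at once. Splitting K 16 ways raises that to 128, at the cost of a workspace that grows linearly with S in the workspace approach.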
This example showcases the extreme composability of the library, allowing for the creation of highly tailored, high-performance kernels that combine different algorithmic strategies (like Split-K) with deep fusion.