Batched GEMM with Reduction
This example demonstrates a Batched General Matrix-Matrix Multiplication (Batched GEMM) where the result of each individual GEMM in the batch is then reduced along one of its dimensions. This is a specialized fusion pattern that combines a compute-intensive operation (GEMM) with a memory-intensive one (reduction), offering significant performance benefits for specific workloads.
Mathematical Formulation
The operation performs a standard GEMM for each item in a batch, and then reduces the resulting matrix to a vector. For each batch item b from 0 to BatchCount-1:
- GEMM Stage: A standard matrix multiplication is performed:
  $C_{[b]} = A_{[b]} \times B_{[b]}$
- Reduction Stage: The resulting matrix $C_{[b]}$ is reduced along one of its dimensions (e.g., the M dimension) to produce an output vector $D_{[b]}$:
  $D_{[b], j} = \bigoplus_{i=0}^{M-1} C_{[b], i, j}$
Where:
- $A_{[b]}$ is an $M \times K$ matrix.
- $B_{[b]}$ is a $K \times N$ matrix.
- $C_{[b]}$ is the intermediate $M \times N$ result matrix for batch $b$.
- $D_{[b]}$ is the final $1 \times N$ output vector for batch $b$.
- $\bigoplus$ is a binary, associative reduction operator such as sum, max, or min.
The key optimization is that the intermediate matrix $C_{[b]}$ is never written to global memory. The reduction is fused directly into the GEMM kernel.
Algorithmic Strategy: Fused GEMM and Reduction
The implementation fuses the reduction into the epilogue of a batched GEMM kernel. The batch dimension provides a natural axis for parallelism.
- Batch Scheduling: The BatchCount GEMM problems are distributed across the GPU's thread blocks. A single GEMM from the batch can be split across several blocks, each responsible for a different tile of its output.
- Tiled GEMM Core: For each assigned tile, the thread block runs a standard tiled GEMM algorithm to compute its part of the product $A_{[b]} \times B_{[b]}$. The result for each tile of $C_{[b]}$ is accumulated in the threads' private registers.
- Fused Reduction Epilogue: This is where the fusion occurs. Instead of writing the computed tile of $C_{[b]}$ to global memory, the threads use it as input for a parallel reduction.
  - Intra-Block Reduction: The threads within a block, which collectively hold the values for a tile of $C_{[b]}$, perform a local reduction. For example, to reduce along the M dimension, threads responsible for different M-rows but the same N-column cooperate through fast shared memory to combine their partial results with $\bigoplus$.
  - Inter-Block Reduction: Because multiple thread blocks may work on different M-tiles of the same batch item, their partial results must be combined. Each block merges its partial result into the corresponding location of the output vector $D_{[b]}$, using atomic operations (such as atomicAdd for a sum) to accumulate the final result safely.
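The two reduction levels can be modelled on the CPU to show how they fit together. This is a simulation under stated assumptions, not the composable_kernel kernel: each std::thread stands in for a thread block that owns one M-tile of a single batch item, the operator is sum, the tile height tile_m is assumed to be a power of two (the stride-halving tree requires it), and fused_gemm_reduce_m and atomic_add are hypothetical names. The float accumulation uses a compare-exchange loop, a portable way to emulate a float atomic add in C++11 (std::atomic<float>::fetch_add only exists from C++20).

```cpp
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical helper: emulates a float atomicAdd with compare-exchange.
inline void atomic_add(std::atomic<float>& target, float value) {
    float expected = target.load();
    // On failure, 'expected' is refreshed with the current value; retry.
    while (!target.compare_exchange_weak(expected, expected + value)) {
    }
}

// Hypothetical simulation for ONE batch item: D[j] = sum_i (A x B)[i][j].
// A is M x K and B is K x N, both row-major. tile_m must be a power of two.
std::vector<float> fused_gemm_reduce_m(const std::vector<float>& A,
                                       const std::vector<float>& B,
                                       std::size_t M, std::size_t N,
                                       std::size_t K, std::size_t tile_m) {
    std::vector<std::atomic<float>> D(N);
    for (auto& d : D) d.store(0.0f);

    // One "thread block": computes an M-tile of C and reduces it in place.
    auto block = [&](std::size_t m0) {
        std::size_t rows = std::min(tile_m, M - m0);
        // Tile of C kept local to the block; it is never written to D as a matrix.
        // Padding rows past 'rows' stay 0, the identity for sum.
        std::vector<float> tile(tile_m * N, 0.0f);
        for (std::size_t i = 0; i < rows; ++i)
            for (std::size_t j = 0; j < N; ++j)
                for (std::size_t k = 0; k < K; ++k)
                    tile[i * N + j] += A[(m0 + i) * K + k] * B[k * N + j];
        // Intra-block reduction: stride-halving tree over the tile's rows,
        // the pattern a GPU block would run in shared memory.
        for (std::size_t stride = tile_m / 2; stride > 0; stride /= 2)
            for (std::size_t i = 0; i < stride; ++i)
                for (std::size_t j = 0; j < N; ++j)
                    tile[i * N + j] += tile[(i + stride) * N + j];
        // Inter-block reduction: row 0 holds this tile's partial sums.
        for (std::size_t j = 0; j < N; ++j) atomic_add(D[j], tile[j]);
    };

    std::vector<std::thread> blocks;
    for (std::size_t m0 = 0; m0 < M; m0 += tile_m) blocks.emplace_back(block, m0);
    for (auto& t : blocks) t.join();

    std::vector<float> out(N);
    for (std::size_t j = 0; j < N; ++j) out[j] = D[j].load();
    return out;
}
```

Because the blocks finish in arbitrary order, a floating-point sum accumulated this way can differ in the last bits from run to run. Atomic accumulation on the GPU is subject to the same ordering nondeterminism.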
This strategy eliminates the global-memory traffic associated with the intermediate matrix C, which is often the largest tensor in the operation: C is neither written out by the GEMM nor read back by a separate reduction kernel. This leads to substantial savings in memory bandwidth and improved performance.
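As a rough illustration of the savings, assume an unfused pipeline in which the GEMM writes C to global memory once and a separate reduction kernel reads it back once, with hypothetical sizes BatchCount = 16, M = N = 1024, and fp16 (2-byte) elements for C. The avoided traffic for C alone is then:

$$
\underbrace{2}_{\text{write + read}} \times \text{BatchCount} \times M \times N \times \text{sizeof}(C)
= 2 \times 16 \times 1024 \times 1024 \times 2\ \text{B}
= 67{,}108{,}864\ \text{B} = 64\ \text{MiB}
$$

The fused output is only the BatchCount vectors $D_{[b]}$ of length N, so the traffic it adds is smaller than that of C by a factor of M.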
Source Code Organization
- batched_gemm_reduce_xdl.cpp: The main example file. It sets up the batched GEMM problem and instantiates the DeviceBatchedGemmReduce operation, specifying the reduction dimension and operator.
- ../../include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp: The high-level device interface for this fused operation.
- ../../include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_reduce_xdl_cshuffle.hpp: The grid-wise kernel that implements the fused logic. It handles the batch scheduling, the tiled GEMM, and the fused reduction epilogue with atomic operations for inter-block communication.
Build and Run
Prerequisites
Please follow the instructions in the main Build Guide section as a prerequisite to building and running this example.
Build the Example
```bash
cd /path/to/composable_kernel/example/18_batched_gemm_reduce
mkdir build && cd build
cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..
make -j
```
Run the Example
```bash
# Run the example with default settings
./batched_gemm_reduce_xdl

# Run with verification (1), data initialization method (2), and timing (1)
./batched_gemm_reduce_xdl 1 2 1
```
Applications
This fused pattern is less common than simple GEMM+Bias but is highly effective for specific algorithms.
- Gradient Computations: In some complex neural network layers, the gradient calculation might involve a matrix product followed by a summation. For example, computing the gradient with respect to a bias term often involves summing the output gradients over the batch and spatial dimensions. If the output gradient itself is the result of a GEMM, this fused kernel could be applicable.
- Custom Attention Mechanisms: While standard attention involves a softmax, some research explores attention-like mechanisms that might use a simple sum or max reduction instead. If the query-key interaction is formulated as a batched GEMM, this kernel could compute the attention weights in a single, fused step.
- Scientific Computing: Certain numerical methods, particularly in physics or signal processing, may involve performing a linear transform (GEMM) on a set of signals (a batch) and then integrating the result (a reduction).