# Batched GEMM with Reduction

This example demonstrates a Batched General Matrix-Matrix Multiplication (Batched GEMM) in which the result of each individual GEMM in the batch is then reduced along one of its dimensions. This specialized fusion pattern combines a compute-intensive operation (GEMM) with a memory-intensive one (reduction), and it offers significant performance benefits for specific workloads.
## Mathematical Formulation

The operation performs a standard GEMM for each item in the batch and then reduces the resulting matrix to a vector. For each batch item `b` from `0` to `BatchCount-1`:

1. **GEMM Stage**: A standard matrix multiplication is performed.

   $C_{[b]} = A_{[b]} \times B_{[b]}$

2. **Reduction Stage**: The resulting matrix $C_{[b]}$ is reduced along one of its dimensions (in the formula below, the M dimension) to produce an output vector $D_{[b]}$.

   $D_{[b], j} = \bigoplus_{i=0}^{M-1} C_{[b], i, j}$

Where:

- $A_{[b]}$ is an $M \times K$ matrix.
- $B_{[b]}$ is a $K \times N$ matrix.
- $C_{[b]}$ is the intermediate $M \times N$ result matrix for batch `b`.
- $D_{[b]}$ is the final $1 \times N$ output vector for batch `b`.
- $\bigoplus$ is a binary, associative reduction operator such as sum, max, or min.

The key optimization is that the intermediate matrix $C_{[b]}$ is never written to global memory: the reduction is fused directly into the GEMM kernel.
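A plain CPU reference makes the definition concrete and is handy for checking results. It deliberately materializes $C_{[b]}$ in full, which is exactly what the fused kernel avoids. This is a minimal sketch under stated assumptions: row-major fp32 data, `D` laid out as `[BatchCount][N]`, and a caller-supplied operator together with its identity element. The function name `batched_gemm_reduce_ref` is hypothetical and is not part of CK.

```cpp
// Minimal CPU reference for batched GEMM + reduction along M (sketch).
// Assumptions (illustrative, not CK's API): row-major A [batch][M][K],
// row-major B [batch][K][N], fp32 everywhere, D laid out as [batch][N].
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Unfused reference: computes C_[b] in full, then reduces each column.
// `op` must be associative; `identity` is its neutral element
// (0 for sum, -infinity for max, +infinity for min).
std::vector<float> batched_gemm_reduce_ref(
    const std::vector<float>& A, const std::vector<float>& B,
    std::size_t batch, std::size_t M, std::size_t N, std::size_t K,
    const std::function<float(float, float)>& op, float identity)
{
    assert(A.size() == batch * M * K && B.size() == batch * K * N);
    std::vector<float> C(M * N);               // intermediate, reused per batch item
    std::vector<float> D(batch * N, identity); // one 1 x N vector per batch item
    for (std::size_t b = 0; b < batch; ++b) {
        const float* a  = A.data() + b * M * K;
        const float* bm = B.data() + b * K * N;
        // GEMM stage: C_[b] = A_[b] x B_[b]
        for (std::size_t i = 0; i < M; ++i)
            for (std::size_t j = 0; j < N; ++j) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < K; ++k)
                    acc += a[i * K + k] * bm[k * N + j];
                C[i * N + j] = acc;
            }
        // Reduction stage: D_[b],j = op over i of C_[b],i,j
        for (std::size_t i = 0; i < M; ++i)
            for (std::size_t j = 0; j < N; ++j)
                D[b * N + j] = op(D[b * N + j], C[i * N + j]);
    }
    return D;
}
```

Passing `[](float x, float y) { return x + y; }` with identity `0.0f` gives column sums; a max lambda with a very negative identity gives column maxima.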
## Algorithmic Strategy: Fused GEMM and Reduction

The implementation fuses the reduction into the epilogue of a batched GEMM kernel. The batch dimension provides a natural axis for parallelism, because the `BatchCount` GEMMs are independent of one another.
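One plausible way to expose that parallelism is to launch one thread block per output tile of every GEMM and decode the flat block index into a (batch item, M-tile, N-tile) triple. The sketch below assumes this layout; `MPerBlock`, `NPerBlock`, `TileCoord`, and `block_to_tile` are illustrative names, and CK's own block-to-tile mapping may differ.

```cpp
// Hypothetical grid decomposition for a batched, tiled GEMM: one thread block
// per (batch item, M-tile, N-tile). Tile sizes are illustrative parameters.
#include <cassert>
#include <cstddef>

struct TileCoord {
    std::size_t batch;  // which GEMM in the batch
    std::size_t m_tile; // which MPerBlock-row slab of C_[b]
    std::size_t n_tile; // which NPerBlock-column slab of C_[b]
};

constexpr std::size_t ceil_div(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

// Total number of thread blocks needed to cover every tile of every GEMM.
constexpr std::size_t grid_size(std::size_t batch, std::size_t M, std::size_t N,
                                std::size_t MPerBlock, std::size_t NPerBlock) {
    return batch * ceil_div(M, MPerBlock) * ceil_div(N, NPerBlock);
}

// Map a flat block index to its tile. The batch index varies slowest, so
// consecutive blocks work on tiles of the same GEMM.
TileCoord block_to_tile(std::size_t block_id, std::size_t M, std::size_t N,
                        std::size_t MPerBlock, std::size_t NPerBlock) {
    const std::size_t m_tiles = ceil_div(M, MPerBlock);
    const std::size_t n_tiles = ceil_div(N, NPerBlock);
    const std::size_t tiles_per_batch = m_tiles * n_tiles;
    TileCoord t;
    t.batch = block_id / tiles_per_batch;
    const std::size_t local = block_id % tiles_per_batch;
    t.m_tile = local / n_tiles;
    t.n_tile = local % n_tiles;
    return t;
}
```

With `M = 300`, `N = 200`, and 128 x 64 tiles, each GEMM needs 3 x 4 = 12 blocks, so a batch of 3 launches 36 blocks. Several of those blocks share a batch item and a column range, which is why the reduction along M needs an inter-block step.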
1. **Batch Scheduling**: The `BatchCount` GEMM problems are distributed across the GPU's thread blocks. Each GEMM is split into output tiles and each block computes one tile of one batch item, so a single GEMM is typically spread over several blocks.

2. **Tiled GEMM Core**: For its assigned tile, the thread block runs a standard tiled GEMM algorithm to compute the corresponding part of the product $A_{[b]} \times B_{[b]}$. The result for the tile of $C_{[b]}$ is accumulated in the threads' private registers.

3. **Fused Reduction Epilogue**: This is where the fusion occurs. Instead of writing the computed tile of $C_{[b]}$ to global memory, the threads use it as the input to a parallel reduction.
   - **Intra-Block Reduction**: The threads within a block, which collectively hold the values for a tile of $C_{[b]}$, perform a local reduction. For example, to reduce along the M dimension, threads responsible for different M-rows but the same N-column cooperate through fast shared memory to combine their partial results.
   - **Inter-Block Reduction**: Because several thread blocks may work on different M-tiles of the same batch item, their partial results must be combined. Each block merges its partial result into the corresponding entries of the output vector `D` with an atomic operation (like `atomicAdd` for a sum; other operators need a matching atomic or a compare-and-swap loop). Since every block accumulates into `D`, `D` must be initialized to the operator's identity element (0 for a sum, $-\infty$ for a max) before the kernel runs.

This strategy eliminates the global memory traffic associated with the intermediate matrix `C`, which is often the largest tensor in the operation. An unfused implementation writes the $M \times N$ elements of $C_{[b]}$ for every batch item and then reads them back to reduce them, whereas the fused kernel only updates the $N$ entries of $D_{[b]}$, once per M-tile. This leads to substantial savings in memory bandwidth and improved performance.
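The three steps of the fused strategy can be simulated on the CPU to show how the pieces fit together. This is a hypothetical sketch, not CK code, under these assumptions: row-major fp32 data, one loop iteration per "block" (blocks run one after another here, whereas on a GPU they run concurrently), a pairwise tree reduction over the rows of a tile standing in for the shared-memory reduction, and `std::atomic<float>` updated through a compare-and-swap loop standing in for GPU atomics. `fused_batched_gemm_reduce`, `atomic_add`, and `atomic_max` are illustrative names.

```cpp
// CPU simulation of the fused GEMM + reduce-along-M epilogue (sketch, not CK code).
// Each "block" computes one MPerBlock x NPerBlock tile of C_[b] into a local
// buffer, reduces it along M, and merges the result into D atomically.
// C_[b] is never stored in full.
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <vector>

// Atomic max for float via compare-and-swap; the same loop can emulate any
// associative operator that has no native atomic.
inline void atomic_max(std::atomic<float>& target, float value) {
    float cur = target.load();
    while (cur < value && !target.compare_exchange_weak(cur, value)) {
        // on failure, cur now holds the latest value; retry
    }
}

inline void atomic_add(std::atomic<float>& target, float value) {
    float cur = target.load();
    while (!target.compare_exchange_weak(cur, cur + value)) {}
}

// D ([batch][N]) must be pre-initialized to the identity (0 for sum, -inf for max).
template <bool IsMax>
void fused_batched_gemm_reduce(const std::vector<float>& A, const std::vector<float>& B,
                               std::vector<std::atomic<float>>& D,
                               std::size_t batch, std::size_t M, std::size_t N, std::size_t K,
                               std::size_t MPerBlock, std::size_t NPerBlock) {
    for (std::size_t b = 0; b < batch; ++b)
        for (std::size_t m0 = 0; m0 < M; m0 += MPerBlock)     // one "block" per tile
            for (std::size_t n0 = 0; n0 < N; n0 += NPerBlock) {
                const std::size_t mt = std::min(MPerBlock, M - m0);
                const std::size_t nt = std::min(NPerBlock, N - n0);
                std::vector<float> tile(mt * nt, 0.0f);        // stands in for registers
                // Tiled GEMM core for this tile of C_[b]
                for (std::size_t i = 0; i < mt; ++i)
                    for (std::size_t j = 0; j < nt; ++j)
                        for (std::size_t k = 0; k < K; ++k)
                            tile[i * nt + j] += A[(b * M + m0 + i) * K + k] *
                                                B[(b * K + k) * N + n0 + j];
                // Intra-block: pairwise tree reduction along M within each column
                for (std::size_t stride = 1; stride < mt; stride *= 2)
                    for (std::size_t i = 0; i + stride < mt; i += 2 * stride)
                        for (std::size_t j = 0; j < nt; ++j) {
                            float& x = tile[i * nt + j];
                            const float y = tile[(i + stride) * nt + j];
                            x = IsMax ? std::max(x, y) : x + y;
                        }
                // Inter-block: row 0 holds this tile's partial result; merge atomically
                for (std::size_t j = 0; j < nt; ++j) {
                    std::atomic<float>& dst = D[b * N + n0 + j];
                    if (IsMax) atomic_max(dst, tile[j]);
                    else       atomic_add(dst, tile[j]);
                }
            }
}
```

Because the merge is associative, the order in which blocks reach `D` does not change the mathematical result; for floating-point sums it can still change the last bits, which is a known property of atomic accumulation.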
## Source Code Organization

- [`batched_gemm_reduce_xdl.cpp`](./batched_gemm_reduce_xdl.cpp): The main example file. It sets up the batched GEMM problem and instantiates the `DeviceBatchedGemmReduce` operation, specifying the reduction dimension and operator.
- [`../../include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp`](../../include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp): The high-level device interface for this fused operation.
- [`../../include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_reduce_xdl_cshuffle.hpp`](../../include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_reduce_xdl_cshuffle.hpp): The grid-wise kernel that implements the fused logic. It handles the batch scheduling, the tiled GEMM, and the fused reduction epilogue, using atomic operations for inter-block communication.
## Build and Run

### Prerequisites

Please follow the instructions in the main [Build Guide](../../README.md#building-ck) section as a prerequisite to building and running this example.
### Build the Example

```bash
cd /path/to/composable_kernel/example/18_batched_gemm_reduce
mkdir build && cd build

cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..

make -j
```
### Run the Example

```bash
# Run the example with default settings
./batched_gemm_reduce_xdl

# Run with verification, data initialization, and timing
./batched_gemm_reduce_xdl 1 2 1
```
## Applications

This fused pattern is less common than a simple GEMM+Bias, but it is highly effective for specific algorithms.

- **Gradient Computations**: In some complex neural network layers, the gradient calculation involves a matrix product followed by a summation. For example, computing the gradient with respect to a bias term means summing the output gradients over the batch and spatial dimensions. If the output gradient is itself the result of a GEMM, this fused kernel could be applicable.
- **Custom Attention Mechanisms**: While standard attention involves a `softmax`, some research explores attention-like mechanisms that use a simple sum or max reduction instead. If the query-key interaction is formulated as a batched GEMM, this kernel could compute the reduced query-key scores in a single, fused step.
- **Scientific Computing**: Certain numerical methods, particularly in physics or signal processing, perform a linear transform (GEMM) on a set of signals (a batch) and then integrate the result (a reduction).
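To make the bias-gradient case concrete, consider a linear layer with a bias term. The derivation below is a sketch: $X$, $W$, $P$, and $Q$ are illustrative names, $\mathbf{1}$ is an all-ones column vector of length $M$, and the $M$ rows of $Y$ are assumed to index the flattened batch and spatial positions. Because each $b_j$ enters every row of column $j$ with coefficient 1, the chain rule collapses to a column sum:

$$
Y = X W + \mathbf{1}\, b^{\top}, \qquad Y_{i,j} = \sum_{k} X_{i,k} W_{k,j} + b_j
\quad\Longrightarrow\quad
\frac{\partial L}{\partial b_j} = \sum_{i=0}^{M-1} \frac{\partial L}{\partial Y_{i,j}}
$$

If the output gradient is itself a GEMM product, $\partial L / \partial Y = P \times Q$, the bias gradient is exactly the vector $D$ from the Mathematical Formulation with $\bigoplus = +$:

$$
\frac{\partial L}{\partial b_j} = \sum_{i=0}^{M-1} \left(P \times Q\right)_{i,j}
$$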