# Batched GEMM with Reduction
This example demonstrates a Batched General Matrix-Matrix Multiplication (Batched GEMM) where the result of each individual GEMM in the batch is then reduced along one of its dimensions. This is a specialized fusion pattern that combines a compute-intensive operation (GEMM) with a memory-intensive one (reduction), offering significant performance benefits for specific workloads.
## Mathematical Formulation
The operation performs a standard GEMM for each item in a batch and then reduces the resulting matrix to a vector. For each batch item $b$ from $0$ to $\text{BatchCount}-1$:

1. **GEMM Stage**: A standard matrix multiplication is performed.

   $$C_{[b]} = A_{[b]} \times B_{[b]}$$

2. **Reduction Stage**: The resulting matrix $C_{[b]}$ is reduced along one of its dimensions (e.g., the M dimension) to produce an output vector $D_{[b]}$.

   $$D_{[b], j} = \bigoplus_{i=0}^{M-1} C_{[b], i, j}$$

Where:

- $A_{[b]}$ is an $M \times K$ matrix.
- $B_{[b]}$ is a $K \times N$ matrix.
- $C_{[b]}$ is the intermediate $M \times N$ result matrix for batch $b$.
- $D_{[b]}$ is the final $1 \times N$ output vector for batch $b$.
- $\bigoplus$ is a binary, associative reduction operator such as sum, max, or min.
The key optimization is that the intermediate matrix $C_{[b]}$ is never written to global memory. The reduction is fused directly into the GEMM kernel.
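The two stages can be checked against a straightforward host-side reference. The sketch below (plain C++ with a hypothetical function name, mirroring the formulas above rather than the library's API) computes $C_{[b]} = A_{[b]} \times B_{[b]}$ and sum-reduces it along the M dimension:

```cpp
#include <cstddef>
#include <vector>

// Reference batched GEMM + sum-reduction along M, matching the formulas above.
// A: BatchCount x M x K, B: BatchCount x K x N (row-major, flattened).
// Returns D: BatchCount x N, where D[b][j] = sum_i (A[b] x B[b])[i][j].
std::vector<std::vector<float>> batched_gemm_reduce_ref(
    const std::vector<float>& A, const std::vector<float>& B,
    std::size_t BatchCount, std::size_t M, std::size_t K, std::size_t N)
{
    std::vector<std::vector<float>> D(BatchCount, std::vector<float>(N, 0.0f));
    for(std::size_t b = 0; b < BatchCount; ++b)
        for(std::size_t i = 0; i < M; ++i)
            for(std::size_t j = 0; j < N; ++j)
            {
                float c_ij = 0.0f; // element of the intermediate C[b]
                for(std::size_t k = 0; k < K; ++k)
                    c_ij += A[(b * M + i) * K + k] * B[(b * K + k) * N + j];
                D[b][j] += c_ij; // reduction over the M dimension
            }
    return D;
}
```

Note that even this reference never materializes $C_{[b]}$ as a whole; each element is folded into $D_{[b]}$ as soon as it is computed, which is exactly the property the fused kernel exploits.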
## Algorithmic Strategy: Fused GEMM and Reduction
The implementation fuses the reduction into the epilogue of a batched GEMM kernel. The batch dimension provides a natural axis for parallelism.
1. **Batch Scheduling**: The `BatchCount` GEMM problems are distributed across the GPU's thread blocks. Each block is assigned one or more GEMMs from the batch to compute.

2. **Tiled GEMM Core**: For each assigned GEMM, the thread block runs a standard tiled GEMM algorithm to compute the product $A_{[b]} \times B_{[b]}$. The result for each tile of $C_{[b]}$ is accumulated in the private registers of the threads.

3. **Fused Reduction Epilogue**: This is where the fusion occurs. Instead of writing the computed tile of $C_{[b]}$ to global memory, the threads use it as input for a parallel reduction.
   - **Intra-Block Reduction**: The threads within a block, which collectively hold the values for a tile of $C_{[b]}$, perform a local reduction. For example, to reduce along the M dimension, threads responsible for different M-rows but the same N-column cooperate, using fast shared memory to combine their partial results.
   - **Inter-Block Reduction**: Since multiple thread blocks may be working on different M-tiles for the same batch item, their partial reduction results must be combined. Each block writes its partial result to a designated location in the output vector $D$, using atomic operations (like `atomicAdd`) to safely accumulate the final result.
This strategy completely eliminates the global memory traffic associated with the intermediate matrix $C$, which is often the largest tensor in the operation. This leads to substantial savings in memory bandwidth and improved performance.
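The tile-level structure of this strategy can be illustrated on the host. The sketch below (hypothetical names; a sequential stand-in for the GPU kernel, not the library's implementation) processes one batch item: each M-tile plays the role of a thread block, the tile's slice of $C$ exists only in local variables, and the per-tile partial sums are folded into $D$ the way `atomicAdd` would combine inter-block results:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Host-side sketch of the fused epilogue for one batch item. M is split into
// tiles; each tile computes its slice of C in local storage ("registers") and
// immediately folds the partial column sums into D. C is never stored whole.
std::vector<float> fused_reduce_m_tiled(const std::vector<float>& A, // M x K
                                        const std::vector<float>& B, // K x N
                                        std::size_t M, std::size_t K,
                                        std::size_t N, std::size_t MTile)
{
    std::vector<float> D(N, 0.0f); // output vector shared by all "blocks"
    for(std::size_t i0 = 0; i0 < M; i0 += MTile) // one "thread block" per M-tile
    {
        std::vector<float> partial(N, 0.0f); // intra-block reduction result
        for(std::size_t i = i0; i < std::min(i0 + MTile, M); ++i)
            for(std::size_t j = 0; j < N; ++j)
            {
                float c_ij = 0.0f; // tile of C lives only in local storage
                for(std::size_t k = 0; k < K; ++k)
                    c_ij += A[i * K + k] * B[k * N + j];
                partial[j] += c_ij;
            }
        for(std::size_t j = 0; j < N; ++j)
            D[j] += partial[j]; // stands in for atomicAdd on global D
    }
    return D;
}
```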
## Source Code Organization
- `batched_gemm_reduce_xdl.cpp`: The main example file. It sets up the batched GEMM problem and instantiates the `DeviceBatchedGemmReduce` operation, specifying the reduction dimension and operator.
- `../../include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp`: The high-level device interface for this fused operation.
- `../../include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_reduce_xdl_cshuffle.hpp`: The grid-wise kernel that implements the fused logic. It handles batch scheduling, the tiled GEMM, and the fused reduction epilogue with atomic operations for inter-block communication.
## Build and Run

### Prerequisites

Please follow the instructions in the main Build Guide section as a prerequisite to building and running this example.

### Build the Example
```shell
cd /path/to/composable_kernel/example/18_batched_gemm_reduce
mkdir build && cd build
cmake \
    -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
    -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
    ..
make -j
```
### Run the Example

```shell
# Run the example with default settings
./batched_gemm_reduce_xdl

# Run with verification, data initialization, and timing
./batched_gemm_reduce_xdl 1 2 1
```
## Applications
This fused pattern is less common than simple GEMM+Bias but is highly effective for specific algorithms.
- **Gradient Computations**: In some complex neural network layers, the gradient calculation might involve a matrix product followed by a summation. For example, computing the gradient with respect to a bias term often involves summing the output gradients over the batch and spatial dimensions. If the output gradient itself is the result of a GEMM, this fused kernel could be applicable.
- **Custom Attention Mechanisms**: While standard attention involves a softmax, some research explores attention-like mechanisms that might use a simple sum or max reduction instead. If the query-key interaction is formulated as a batched GEMM, this kernel could compute the attention weights in a single, fused step.
- **Scientific Computing**: Certain numerical methods, particularly in physics or signal processing, may involve performing a linear transform (GEMM) on a set of signals (a batch) and then integrating the result (a reduction).
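The bias-gradient case above maps directly onto this pattern. If the upstream gradient is itself a GEMM result $G = P \times Q$ (shapes $M \times K$ and $K \times N$), then the bias gradient $\partial L/\partial b_j = \sum_i G_{i,j}$ is exactly a GEMM followed by a reduction over M. A minimal host-side sketch (hypothetical names, not the library's API):

```cpp
#include <cstddef>
#include <vector>

// Bias-gradient sketch: G = P x Q is the upstream gradient (M x N), and
// db[j] = sum_i G[i][j]. Folding G element-wise into db reproduces the
// fused GEMM + reduce-over-M pattern without ever storing G.
std::vector<float> bias_grad_via_gemm_reduce(const std::vector<float>& P, // M x K
                                             const std::vector<float>& Q, // K x N
                                             std::size_t M, std::size_t K,
                                             std::size_t N)
{
    std::vector<float> db(N, 0.0f);
    for(std::size_t i = 0; i < M; ++i)
        for(std::size_t j = 0; j < N; ++j)
            for(std::size_t k = 0; k < K; ++k)
                db[j] += P[i * K + k] * Q[k * N + j]; // fold G[i][j] into db[j]
    return db;
}
```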