# Batched GEMM with Reduction

This example demonstrates a Batched General Matrix-Matrix Multiplication (Batched GEMM) where the result of each individual GEMM in the batch is then reduced along one of its dimensions. This is a specialized fusion pattern that combines a compute-intensive operation (GEMM) with a memory-intensive one (reduction), offering significant performance benefits for specific workloads.

## Mathematical Formulation

The operation performs a standard GEMM for each item in a batch and then reduces the resulting matrix to a vector. For each batch item $b = 0, \ldots, \mathrm{BatchCount}-1$:

1. **GEMM Stage:** A standard matrix multiplication is performed:
   $$C_{[b]} = A_{[b]} \times B_{[b]}$$

2. **Reduction Stage:** The resulting matrix $C_{[b]}$ is reduced along one of its dimensions (e.g., the M dimension) to produce an output vector $D_{[b]}$:
   $$D_{[b],j} = \bigoplus_{i=0}^{M-1} C_{[b],i,j}$$

Where:

- $A_{[b]}$ is an $M \times K$ matrix.
- $B_{[b]}$ is a $K \times N$ matrix.
- $C_{[b]}$ is the intermediate $M \times N$ result matrix for batch $b$.
- $D_{[b]}$ is the final $1 \times N$ output vector for batch $b$.
- $\bigoplus$ is a binary, associative reduction operator such as sum, max, or min.

The key optimization is that the intermediate matrix $C_{[b]}$ is never written to global memory; the reduction is fused directly into the GEMM kernel.
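As a concrete reference for the formulation above, the operation can be written as a naive host-side loop with a sum reduction over the M dimension. This is only a correctness sketch (the function name and row-major layout are illustrative, not part of the library API); the actual kernel never materializes $C_{[b]}$:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive reference: for each batch b, D[b][j] = sum over i of (A[b] * B[b])[i][j].
// A is BatchCount x M x K, B is BatchCount x K x N, D is BatchCount x N (row-major).
std::vector<float> batched_gemm_reduce_ref(const std::vector<float>& A,
                                           const std::vector<float>& B,
                                           std::size_t batch, std::size_t M,
                                           std::size_t K, std::size_t N)
{
    std::vector<float> D(batch * N, 0.0f);
    for(std::size_t b = 0; b < batch; ++b)
        for(std::size_t i = 0; i < M; ++i)
            for(std::size_t j = 0; j < N; ++j)
            {
                float c = 0.0f; // one element of the intermediate C[b]
                for(std::size_t k = 0; k < K; ++k)
                    c += A[(b * M + i) * K + k] * B[(b * K + k) * N + j];
                D[b * N + j] += c; // reduce over the M dimension
            }
    return D;
}
```

Swapping the `+=` accumulation for `std::max` or `std::min` yields the other reduction operators mentioned above.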

## Algorithmic Strategy: Fused GEMM and Reduction

The implementation fuses the reduction into the epilogue of a batched GEMM kernel. The batch dimension provides a natural axis for parallelism.

1. **Batch Scheduling:** The `BatchCount` GEMM problems are distributed across the GPU's thread blocks. Each block is assigned one or more GEMMs from the batch to compute.

2. **Tiled GEMM Core:** For each assigned GEMM, the thread block runs a standard tiled GEMM algorithm to compute the product $A_{[b]} \times B_{[b]}$. The result for each tile of $C_{[b]}$ is accumulated in the private registers of the threads.

3. **Fused Reduction Epilogue:** This is where the fusion occurs. Instead of writing the computed tile of $C_{[b]}$ to global memory, the threads use it as input for a parallel reduction.

    - **Intra-Block Reduction:** The threads within a block, which collectively hold the values for a tile of $C_{[b]}$, perform a local reduction. For example, to reduce along the M dimension, threads responsible for different M-rows but the same N-column cooperate, using fast shared memory to sum their partial results.
    - **Inter-Block Reduction:** Since multiple thread blocks may be working on different M-tiles of the same batch item, their partial reduction results must be combined. Each block writes its partial sum to a designated location in the output vector $D$, using atomic operations (such as `atomicAdd`) to safely accumulate the final result.

This strategy completely eliminates the global memory traffic associated with the intermediate matrix C, which is often the largest tensor in the operation. This leads to substantial savings in memory bandwidth and improved performance.
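The schedule above can be modeled in miniature on the host, with each M-tile playing the role of one thread block. This is a conceptual sketch (the function and tiling parameter are hypothetical, not the library's kernel): each tile computes its slice of the product, reduces it locally, and folds the partial result into $D$; on the GPU that final accumulation is the `atomicAdd`, and the full $C$ matrix never exists in memory:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Host-side model of the fused schedule (illustrative only).
// Each iteration of the m0 loop stands in for one thread block:
// it computes the partial sum of its tile's rows and accumulates
// into D. On the GPU the `+=` on D is an atomicAdd, because several
// blocks update the same D[b][j].
void fused_gemm_reduce_sketch(const std::vector<float>& A,
                              const std::vector<float>& B,
                              std::vector<float>& D, std::size_t batch,
                              std::size_t M, std::size_t K, std::size_t N,
                              std::size_t tileM)
{
    for(std::size_t b = 0; b < batch; ++b)
        for(std::size_t m0 = 0; m0 < M; m0 += tileM) // one "thread block" per M-tile
            for(std::size_t j = 0; j < N; ++j)
            {
                float partial = 0.0f; // intra-block reduction over this tile's rows
                for(std::size_t i = m0; i < std::min(m0 + tileM, M); ++i)
                    for(std::size_t k = 0; k < K; ++k)
                        partial += A[(b * M + i) * K + k] * B[(b * K + k) * N + j];
                D[b * N + j] += partial; // inter-block combine: atomicAdd on the GPU
            }
}
```

Note that no `M x N` buffer for $C$ appears anywhere: only the running `partial` (registers on the GPU) and the final vector $D$ are stored.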

## Source Code Organization

## Build and Run

### Prerequisites

Follow the instructions in the main Build Guide before building and running this example.

### Build the Example

```bash
cd /path/to/composable_kernel/example/18_batched_gemm_reduce
mkdir build && cd build

cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..

make -j
```

### Run the Example

```bash
# Run the example with default settings
./batched_gemm_reduce_xdl

# Run with verification, data initialization, and timing
./batched_gemm_reduce_xdl 1 2 1
```

## Applications

This fused pattern is less common than simple GEMM+Bias but is highly effective for specific algorithms.

- **Gradient Computations:** In some complex neural network layers, the gradient calculation involves a matrix product followed by a summation. For example, computing the gradient with respect to a bias term often involves summing the output gradients over the batch and spatial dimensions. If the output gradient itself is the result of a GEMM, this fused kernel could be applicable.
- **Custom Attention Mechanisms:** While standard attention involves a softmax, some research explores attention-like mechanisms that use a simple sum or max reduction instead. If the query-key interaction is formulated as a batched GEMM, this kernel could compute the attention weights in a single, fused step.
- **Scientific Computing:** Certain numerical methods, particularly in physics or signal processing, involve performing a linear transform (GEMM) on a set of signals (a batch) and then integrating the result (a reduction).