mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
# Batched GEMM

## Theory

This example demonstrates **batched GEMM**: performing multiple independent matrix multiplications (all with the same shape) in a single kernel launch. Batched GEMM is used in multi-head attention, RNNs, and other models requiring parallel matrix multiplications.

**Mathematical Formulation:**
For $B$ batches:
$$
C_b = A_b \times B_b \quad \text{for} \quad b = 1, 2, \ldots, B
$$
- $A_b$: [M, K] input matrix for batch $b$
- $B_b$: [K, N] weight matrix for batch $b$
- $C_b$: [M, N] output matrix for batch $b$

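The formulation above can be checked against a plain host-side loop. The following is a minimal reference sketch, not code from this example: the function name `batched_gemm_ref` is hypothetical, and it assumes row-major, densely packed matrices with the batches stored back to back in a single buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side reference for C_b = A_b * B_b over all batches.
// Assumption (illustration only): row-major, densely packed matrices,
// with consecutive batches stored back to back in one buffer.
std::vector<float> batched_gemm_ref(const std::vector<float>& a,
                                    const std::vector<float>& b,
                                    std::size_t batch_count,
                                    std::size_t M,
                                    std::size_t N,
                                    std::size_t K)
{
    assert(a.size() == batch_count * M * K);
    assert(b.size() == batch_count * K * N);

    std::vector<float> c(batch_count * M * N, 0.0f);
    for(std::size_t g = 0; g < batch_count; ++g)
    {
        const float* a_g = a.data() + g * M * K; // A_g is [M, K]
        const float* b_g = b.data() + g * K * N; // B_g is [K, N]
        float* c_g       = c.data() + g * M * N; // C_g is [M, N]

        for(std::size_t m = 0; m < M; ++m)
        {
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.0f;
                for(std::size_t k = 0; k < K; ++k)
                {
                    acc += a_g[m * K + k] * b_g[k * N + n];
                }
                c_g[m * N + n] = acc;
            }
        }
    }
    return c;
}
```

A loop like this is only useful for verifying small problems; each batch is handled sequentially, whereas the GPU kernel processes all batches in one launch.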
**Algorithmic Background:**
- All matrices in the batch have the same shape and strides.
- The kernel is launched once with a grid that covers all batches; each block is assigned to exactly one batch.
- Used for multi-head attention, parallel MLPs, and more.

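One way to picture the block-to-batch assignment described above is the index arithmetic below. This is a sketch under stated assumptions, not the indexing used by the CK kernel: the names `BlockAssignment`, `grid_size`, and `assign_block` are hypothetical, each batch is assumed to be split into fixed-size output tiles, and matrices are assumed to be densely packed so that the batch stride equals the matrix size.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical description of the work given to one block.
struct BlockAssignment
{
    std::size_t batch;    // batch index this block works on
    std::size_t tile_row; // tile row inside that batch's C matrix
    std::size_t tile_col; // tile column inside that batch's C matrix
    std::size_t a_offset; // element offset of A_batch in the A buffer
    std::size_t b_offset; // element offset of B_batch in the B buffer
    std::size_t c_offset; // element offset of C_batch in the C buffer
};

constexpr std::size_t ceil_div(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

// Number of output tiles (and therefore blocks) needed for one batch.
constexpr std::size_t blocks_per_batch(std::size_t M, std::size_t N,
                                       std::size_t tile_m, std::size_t tile_n)
{
    return ceil_div(M, tile_m) * ceil_div(N, tile_n);
}

// Total grid size: one group of tiles per batch.
constexpr std::size_t grid_size(std::size_t batch_count, std::size_t M, std::size_t N,
                                std::size_t tile_m, std::size_t tile_n)
{
    return batch_count * blocks_per_batch(M, N, tile_m, tile_n);
}

// Map a linear block id to a batch and a tile. Because all batches share
// one shape, the batch strides are M*K, K*N, and M*N (dense packing assumed).
BlockAssignment assign_block(std::size_t block_id, std::size_t M, std::size_t N,
                             std::size_t K, std::size_t tile_m, std::size_t tile_n)
{
    assert(tile_m > 0 && tile_n > 0);
    const std::size_t tiles_n   = ceil_div(N, tile_n);
    const std::size_t per_batch = blocks_per_batch(M, N, tile_m, tile_n);
    const std::size_t batch     = block_id / per_batch;
    const std::size_t local     = block_id % per_batch;
    return {batch, local / tiles_n, local % tiles_n,
            batch * M * K, batch * K * N, batch * M * N};
}
```

The requirement that every batch has the same shape and strides is what makes this mapping cheap: a block recovers its batch with one division and reaches its data with one multiply per tensor.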
## How to Run

### Prerequisites

Please follow the instructions in the main [Build Guide](../../README.md#building-ck) section as a prerequisite to building and running this example.

### Build and run
```bash
cd composable_kernel/example/24_batched_gemm
mkdir build && cd build
cmake -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc ..
make -j

# Example run
./batched_gemm_xdl --verify=1 --time=1
```

## Source Code Structure

### Directory Layout
```
example/24_batched_gemm/
└── batched_gemm_xdl.cpp              # Main example: sets up, runs, and verifies batched GEMM
include/ck/tensor_operation/gpu/device/
└── device_batched_gemm_xdl.hpp       # Device-level batched GEMM API
include/ck/tensor_operation/gpu/grid/
└── gridwise_batched_gemm_xdl.hpp     # Grid-level batched GEMM kernel
```

### Key Classes and Functions

- **DeviceBatchedGemmXdl** (in `device_batched_gemm_xdl.hpp`):
  Device API for batched GEMM.
- **gridwise_batched_gemm_xdl** (in `gridwise_batched_gemm_xdl.hpp`):
  Implements the tiled/blocking batched GEMM kernel.

This example demonstrates how Composable Kernel supports efficient parallel matrix multiplication for batched and multi-head workloads.