[Back to supported operations](../../../include/ck/README.md)
# Composable Kernel GEMM Example
## Introduction
GEMM (General Matrix Multiplication) is a fundamental operation in linear algebra and deep learning. It computes the product of two matrices, optionally adds a bias or residual, and is the core of many neural network layers (MLPs, attention, convolutions via im2col). This example demonstrates the flexible and high-performance GEMM API provided by Composable Kernel.
---
## Theory
**Mathematical Formulation:**
$$
C = \alpha (A \times B) + \beta D
$$
- $A$: [M, K] input matrix
- $B$: [K, N] weight matrix
- $D$: [M, N] optional bias/residual
- $C$: [M, N] output
- $\alpha, \beta$: scalars (often 1.0, 0.0)
GEMM is implemented using a tiled/blocking strategy to maximize data reuse and memory bandwidth. Modern GPU implementations use matrix core/XDL/MFMA instructions for high throughput. The operation is the computational backbone for transformer attention, MLPs, CNNs (via lowering), and more.
---
## CK GEMM API Overview
CK provides a highly composable GEMM API via the `DeviceGemm` family of device operations. These operations are heavily templated to support a wide range of data types, layouts, and fused operations.
### Template Parameters
- **ALayout** - A matrix layout (RowMajor/ColumnMajor)
- **BLayout** - B matrix layout (RowMajor/ColumnMajor)
- **CLayout** - C matrix layout (RowMajor/ColumnMajor)
- **ADataType** - A matrix data type
- **BDataType** - B matrix data type
- **CDataType** - C matrix data type
- **AElementwiseOperation** - Fused operation on tensor A before GEMM
- **BElementwiseOperation** - Fused operation on tensor B before GEMM
- **CElementwiseOperation** - Fused operation on tensor C after GEMM
For GEMMs with a large K dimension, use `DeviceGemmSplitK`, which splits the K loop across workgroups. The output buffer must be zeroed beforehand because partial results are accumulated with AtomicAdd.
For fused operations with additional tensors, use `DeviceGemmMultipleABD` or `DeviceGemmMultipleD`:
- **DsLayout** - layouts for additional tensors
- **DsDataType** - data types for additional tensors
For `DeviceGemmMultipleABD`, pass **ALayout**, **BLayout**, **ADataType**, **BDataType** as tuples.
---
## Supported GEMM Variants
- **DeviceGemm**: Standard GEMM
- **DeviceGemmSplitK**: Split-K GEMM for large K
- **DeviceGemmMultipleABD**: Fused GEMM with multiple A/B/D tensors
- **DeviceGemmMultipleD**: Fused GEMM with multiple D tensors
---
## Supported Device Operations
- **DeviceGemmDl**: DL instructions
- **DeviceGemmDpp**: DL instructions with DPP during data load
- **DeviceGemmWmma_CShuffle**: WMMA instructions with CShuffle optimization
- **DeviceGemm_Xdl_CShuffle_LdsDirectLoad**: XDL instructions, CShuffle, direct global-to-shared load
- **DeviceGemm_Xdl_CShuffle**: XDL instructions with CShuffle
- **DeviceGemm_Xdl_CShuffleV2**: XDL instructions, optimized pipeline vs. V1
- **DeviceGemmXdlSkipBLds**: XDL, skips shared memory load for B
- **DeviceGemm_Xdl_WaveletModel_CShuffle**: XDL, CShuffle, wavelet producer/consumer
- **DeviceGemmXdl**: XDL instructions
---
## Supported Data Types and Layouts
### XDL Instruction
|Data type|Is supported|
|-------|---|
|bf16 |✔️|
|fp16 |✔️|
|fp32 |✔️|
|int8 |✔️|
|fp8 |✔️|
### WMMA Instruction
|Data type|Is supported|
|-------|---|
|bf16 |✔️|
|fp16 |✔️|
|fp32 |❌|
|int8 |✔️|
|fp8 |❌|
### DL Instruction
|Data type|Is supported|
|-------|---|
|bf16 |❌|
|fp16 |✔️|
|fp32 |✔️|
|int8 |✔️|
|fp8 |❌|
---
## Supported Fused Elementwise Operations
Layout triples such as Row/Column/Row refer to the A/B/C matrix layouts.
- **B Matrix Multiply + Add + Gelu** - bf16 (int8 for B matrix)
- **B Matrix Multiply + Add** - bf16 (int8 for B matrix)
- **B Matrix Multiply + Gelu** - bf16 (int8 for B matrix)
- **B Matrix Multiply** - bf16 (int8 for B matrix)
- **Add + Add + Gelu** - fp16
- **Add + Gelu** - fp16, bf16 (int8 for B matrix) for Row/Column/Row
- **Multiply** - fp16
- **Add + Multiply** - fp16
- **Add + Relu** - fp16 and bf16 (int8 for B matrix) for Row/Column/Row
- **Add + Silu** - fp16 and bf16 (int8 for B matrix) for Row/Column/Row
- **Add** - fp16 and bf16 (int8 for B matrix) for Row/Column/Row
- **Bilinear** - fp16, int8
- **Gelu** - fp16
- **Multiply + Add** - fp16 for Row/Column/Row and Row/Row/Row, fp16 (int8 for B matrix, fp32 for Bias) for Row/Column/Row and Row/Row/Row
- **Quantization** - int8
---
## GEMM V2 (Universal GEMM)
Optimized for the MI300 series. The operation is called `DeviceGemmV2` and uses template parameters similar to those above.
- **ALayout**, **BLayout**, **CLayout**
- **ADataType**, **BDataType**, **CDataType**
- **AElementwiseOperation**, **BElementwiseOperation**, **CElementwiseOperation**
Split-K is supported (requires zeroing output buffer if splitK > 1).
### Device Operations
- **DeviceGemm_Xdl_CShuffleV3**: XDL with CShuffle optimization
- **DeviceGemm_Xdl_CShuffleV3R1**: XDL with CShuffle, reduction on split-K after GEMM
### Supported Types
|Data type|Is supported|
|-------|---|
|bf16 |✔️|
|fp16 |✔️|
|fp32 |❌|
|int8 |❌|
|fp8 (C bf16)|✔️|
|fp16 (A fp8)|✔️|
|fp16 (B fp8)|✔️|
---
## Other GEMM Extensions
- **DeviceGemm_dequantB**: GEMM with dequantization (WMMA)
- **DeviceGemmMultipleD_ABScale**: GEMM with scale for A and B
- **DeviceGemmMultipleDLayernorm**: GEMM fused with layernorm
- **DeviceGemmMultipleDMultipleR**: GEMM fused with reductions and custom global reductions
- **DeviceGemmReduce**: GEMM fused with reduction
- **DeviceGemm_Streamk_V2**: Stream K with reduction instead of AtomicAdd
- **DeviceGemmStreamK**: Stream K using AtomicAdd
---
## How to Run
### Prerequisites
Please follow the instructions in the main [Build Guide](../../README.md#building-ck) section as a prerequisite to building and running this example.
### Build and run
```bash
cd composable_kernel/example/01_gemm
mkdir build && cd build
cmake -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc ..
make -j
# Example run (FP16)
./gemm_xdl_fp16 -M 4096 -N 4096 -K 4096 -v 1 -t 1
```
---
## Source Code Structure
```
example/01_gemm/
├── gemm_xdl_fp16.cpp            # Main example: sets up, runs, and verifies GEMM (FP16)
└── gemm_xdl_fp32.cpp            # Main example: FP32 variant

include/ck/tensor_operation/gpu/device/
└── device_gemm.hpp              # Device-level GEMM API (templated)

include/ck/tensor_operation/gpu/device/impl/
└── device_gemm_xdl.hpp          # XDL-based GEMM implementation

include/ck/tensor_operation/gpu/grid/
└── gridwise_gemm_xdl.hpp        # Grid-level tiled GEMM kernel

include/ck/tensor_operation/gpu/block/
└── blockwise_gemm_xdl.hpp       # Block-level tiled GEMM

library/reference_tensor_operation/cpu/
└── reference_gemm.hpp           # CPU reference GEMM for correctness checking
```
### Key Classes and Functions
- **DeviceGemmXdl** (in `device_gemm_xdl.hpp`):
Main device API for launching GEMM kernels.
- **GridwiseGemmXdl** (in `gridwise_gemm_xdl.hpp`):
Implements the tiled/blocking GEMM kernel for the GPU grid.
- **BlockwiseGemmXdl** (in `blockwise_gemm_xdl.hpp`):
Handles block-level computation and shared memory tiling.
- **reference_gemm** (in `reference_gemm.hpp`):
CPU implementation for result verification.
---
This example is the foundation for all matrix operations in Composable Kernel and is the basis for more advanced fused and batched operations.