[Back to supported operations](../../../include/ck/README.md)
# Composable Kernel GEMM Example

## Introduction

GEMM (General Matrix Multiplication) is a fundamental operation in linear algebra and deep learning. It computes the product of two matrices, optionally adds a bias or residual, and is the core of many neural network layers (MLPs, attention, convolutions via im2col). This example demonstrates the flexible and high-performance GEMM API provided by Composable Kernel.

---

## Theory

**Mathematical Formulation:**

$$
C = \alpha (A \times B) + \beta D
$$

- $A$: [M, K] input matrix
- $B$: [K, N] weight matrix
- $D$: [M, N] optional bias/residual tensor
- $C$: [M, N] output matrix
- $\alpha, \beta$: scalar coefficients (often 1.0 and 0.0)

GEMM is implemented using a tiled/blocking strategy to maximize data reuse and memory bandwidth. Modern GPU implementations use matrix core (XDL/MFMA) instructions for high throughput. The operation is the computational backbone of transformer attention, MLPs, CNNs (via lowering to GEMM), and more.

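As a concrete reference for the formulation above, the computation can be written as a naive triple loop. This is a CPU sketch for clarity only; CK's GPU implementation is tiled and uses matrix core instructions:

```cpp
#include <cassert>
#include <vector>

// Naive row-major GEMM: C = alpha * (A x B) + beta * D
// A: [M, K], B: [K, N], D and C: [M, N]
void reference_gemm(int M, int N, int K, float alpha, float beta,
                    const std::vector<float>& A, const std::vector<float>& B,
                    const std::vector<float>& D, std::vector<float>& C) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k)  // inner product over the K dimension
                acc += A[m * K + k] * B[k * N + n];
            C[m * N + n] = alpha * acc + beta * D[m * N + n];
        }
    }
}
```

The tiled GPU strategy computes the same result, but processes C in blocks so that tiles of A and B are reused from shared memory and registers instead of being re-read from global memory.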
---

## CK GEMM API Overview

CK provides a composable GEMM API via the `DeviceGemm` family of device operations. These are heavily templated to support a wide range of data types, layouts, and fused elementwise operations.

### Template Parameters

- **ALayout** - A matrix layout (RowMajor/ColumnMajor)
- **BLayout** - B matrix layout (RowMajor/ColumnMajor)
- **CLayout** - C matrix layout (RowMajor/ColumnMajor)
- **ADataType** - A matrix data type
- **BDataType** - B matrix data type
- **CDataType** - C matrix data type
- **AElementwiseOperation** - fused elementwise operation applied to tensor A before the GEMM
- **BElementwiseOperation** - fused elementwise operation applied to tensor B before the GEMM
- **CElementwiseOperation** - fused elementwise operation applied to tensor C after the GEMM

For a large K dimension, use `DeviceGemmSplitK`, which splits the K loop across workgroups. Because partial results are accumulated with AtomicAdd, the output buffer must be zero-initialized before the kernel is launched.

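The zero-initialization requirement can be seen in a small CPU sketch of the split-K idea (a hypothetical simplification; on the GPU each K-slice is handled by a separate workgroup and the accumulation is an AtomicAdd):

```cpp
#include <cassert>
#include <vector>

// Split-K sketch: each of `splits` "workgroups" reduces one slice of K and
// accumulates its partial result into C. Because every slice does C += partial,
// C must start at zero or the result is corrupted by stale data.
void splitk_gemm(int M, int N, int K, int splits,
                 const std::vector<float>& A, const std::vector<float>& B,
                 std::vector<float>& C) {
    int slice = K / splits;  // assume K is divisible by splits for simplicity
    for (int s = 0; s < splits; ++s) {  // one iteration per "workgroup"
        int k0 = s * slice, k1 = k0 + slice;
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                float partial = 0.f;
                for (int k = k0; k < k1; ++k)
                    partial += A[m * K + k] * B[k * N + n];
                C[m * N + n] += partial;  // AtomicAdd on the GPU
            }
    }
}
```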
For fused operations with additional tensors, use `DeviceGemmMultipleABD` or `DeviceGemmMultipleD`, which add:

- **DsLayout** - layouts for the additional tensors
- **DsDataType** - data types for the additional tensors

For `DeviceGemmMultipleABD`, pass **ALayout**, **BLayout**, **ADataType**, and **BDataType** as tuples.

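As a schematic illustration only (the names below are simplified stand-ins, and the real CK class templates carry many additional tuning parameters), the parameter groups above compose roughly like this:

```cpp
#include <cassert>

// Hypothetical, simplified tags mimicking CK's layout and elementwise-op
// parameters; not the actual CK types or template signature.
struct RowMajor {};
struct ColumnMajor {};
struct PassThrough {  // identity elementwise op, applied per element
    template <typename T>
    T operator()(T v) const { return v; }
};

template <typename ALayout, typename BLayout, typename CLayout,
          typename ADataType, typename BDataType, typename CDataType,
          typename AElementwiseOp, typename BElementwiseOp,
          typename CElementwiseOp>
struct DeviceGemmSketch {
    // A real CK device op exposes MakeArgument(...), MakeInvoker(), and
    // IsSupportedArgument(...); here we only model the fused C epilogue.
    CElementwiseOp c_op;
    CDataType Epilogue(CDataType acc) const { return c_op(acc); }
};
```

Instantiating the sketch with a layout/type/op triple per tensor mirrors how the template parameter list is organized in the real API.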
---

## Supported GEMM Variants

- **DeviceGemm**: standard GEMM
- **DeviceGemmSplitK**: split-K GEMM for a large K dimension
- **DeviceGemmMultipleABD**: fused GEMM with multiple A/B/D tensors
- **DeviceGemmMultipleD**: fused GEMM with multiple D tensors

---

## Supported Device Operations

- **DeviceGemmDl**: DL instructions
- **DeviceGemmDpp**: DL instructions with DPP during data load
- **DeviceGemmWmma_CShuffle**: WMMA instructions with the CShuffle optimization
- **DeviceGemm_Xdl_CShuffle_LdsDirectLoad**: XDL instructions, CShuffle, and direct global-to-shared-memory loads
- **DeviceGemm_Xdl_CShuffle**: XDL instructions with the CShuffle optimization
- **DeviceGemm_Xdl_CShuffleV2**: XDL instructions with CShuffle and an optimized pipeline relative to V1
- **DeviceGemmXdlSkipBLds**: XDL instructions, skipping the shared-memory load for the B matrix
- **DeviceGemm_Xdl_WaveletModel_CShuffle**: XDL instructions, CShuffle, and wavelet producer/consumer cooperation
- **DeviceGemmXdl**: XDL instructions

---

## Supported Data Types and Layouts

### XDL Instruction

| Data type | Supported |
|-----------|-----------|
| bf16      | ✔️        |
| fp16      | ✔️        |
| fp32      | ✔️        |
| int8      | ✔️        |
| fp8       | ✔️        |

### WMMA Instruction

| Data type | Supported |
|-----------|-----------|
| bf16      | ✔️        |
| fp16      | ✔️        |
| fp32      | ❌        |
| int8      | ✔️        |
| fp8       | ❌        |

### DL Instruction

| Data type | Supported |
|-----------|-----------|
| bf16      | ❌        |
| fp16      | ✔️        |
| fp32      | ✔️        |
| int8      | ✔️        |
| fp8       | ❌        |

---

## Supported Fused Elementwise Operations

- **B Matrix Multiply + Add + Gelu** - bf16 (int8 for B matrix)
- **B Matrix Multiply + Add** - bf16 (int8 for B matrix)
- **B Matrix Multiply + Gelu** - bf16 (int8 for B matrix)
- **B Matrix Multiply** - bf16 (int8 for B matrix)
- **Add + Add + Gelu** - fp16
- **Add + Gelu** - fp16, bf16 (int8 for B matrix) for Row/Column/Row layouts
- **Multiply** - fp16
- **Add + Multiply** - fp16
- **Add + Relu** - fp16 (int8 for B matrix) and bf16 (int8 for B matrix), each for Row/Column/Row layouts
- **Add + Silu** - fp16 (int8 for B matrix) and bf16 (int8 for B matrix), each for Row/Column/Row layouts
- **Add** - fp16 (int8 for B matrix) and bf16 (int8 for B matrix), each for Row/Column/Row layouts
- **Bilinear** - fp16, int8
- **Gelu** - fp16
- **Multiply + Add** - fp16 for Row/Column/Row and Row/Row/Row layouts, fp16 (int8 for B matrix, fp32 for Bias) for Row/Column/Row and Row/Row/Row layouts
- **Quantization** - int8

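For intuition, a fused epilogue such as Add + Gelu applies the bias add and the activation per output element as the GEMM result is written out, avoiding a separate pass over memory. A CPU sketch of the per-element math, using the common tanh approximation of GELU (CK's exact formulation may differ):

```cpp
#include <cassert>
#include <cmath>

// Add + Gelu epilogue applied to one output element:
// e = gelu(c + d), where c is the GEMM accumulator and d the bias/residual.
// Uses the tanh approximation of GELU.
float add_gelu(float c, float d) {
    float x = c + d;
    return 0.5f * x *
           (1.f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}
```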
---

## GEMM V2 (Universal GEMM)

Optimized for the MI300 series. The operation is named `DeviceGemmV2` and uses template parameters similar to those above:

- **ALayout**, **BLayout**, **CLayout**
- **ADataType**, **BDataType**, **CDataType**
- **AElementwiseOperation**, **BElementwiseOperation**, **CElementwiseOperation**

Split-K is supported; the output buffer must be zero-initialized when splitK > 1.

### Device Operations

- **DeviceGemm_Xdl_CShuffleV3**: XDL instructions with the CShuffle optimization
- **DeviceGemm_Xdl_CShuffleV3R1**: XDL instructions with CShuffle, performing the split-K reduction after the GEMM

### Supported Types

| Data type    | Supported |
|--------------|-----------|
| bf16         | ✔️        |
| fp16         | ✔️        |
| fp32         | ❌        |
| int8         | ❌        |
| fp8 (C bf16) | ✔️        |
| fp16 (A fp8) | ✔️        |
| fp16 (B fp8) | ✔️        |

---

## Other GEMM Extensions

- **DeviceGemm_dequantB**: GEMM with B-matrix dequantization (WMMA)
- **DeviceGemmMultipleD_ABScale**: GEMM with scales for A and B
- **DeviceGemmMultipleDLayernorm**: GEMM fused with layernorm
- **DeviceGemmMultipleDMultipleR**: GEMM fused with reductions and custom global reductions
- **DeviceGemmReduce**: GEMM fused with a reduction
- **DeviceGemm_Streamk_V2**: Stream-K with a reduction step instead of AtomicAdd
- **DeviceGemmStreamK**: Stream-K using AtomicAdd

---

## How to Run

### Prerequisites

Please follow the instructions in the main [Build Guide](../../README.md#building-ck) as a prerequisite to building and running this example.

### Build and run

```bash
cd composable_kernel/example/01_gemm
mkdir build && cd build
cmake -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc ..
make -j

# Example run (FP16)
./gemm_xdl_fp16 -M 4096 -N 4096 -K 4096 -v 1 -t 1
```

---

## Source Code Structure

```
example/01_gemm/
├── gemm_xdl_fp16.cpp            # Main example: sets up, runs, and verifies GEMM (FP16)
├── gemm_xdl_fp32.cpp            # Main example: FP32 variant
include/ck/tensor_operation/gpu/device/
└── device_gemm.hpp              # Device-level GEMM API (templated)
include/ck/tensor_operation/gpu/device/impl/
└── device_gemm_xdl.hpp          # XDL-based GEMM implementation
include/ck/tensor_operation/gpu/grid/
└── gridwise_gemm_xdl.hpp        # Grid-level tiled GEMM kernel
include/ck/tensor_operation/gpu/block/
└── blockwise_gemm_xdl.hpp       # Block-level tiled GEMM
library/reference_tensor_operation/cpu/
└── reference_gemm.hpp           # CPU reference GEMM for correctness checking
```

### Key Classes and Functions

- **DeviceGemmXdl** (in `device_gemm.hpp`): main device-level API for launching GEMM kernels.
- **GridwiseGemmXdl** (in `gridwise_gemm_xdl.hpp`): implements the tiled/blocking GEMM kernel at the GPU grid level.
- **BlockwiseGemmXdl** (in `blockwise_gemm_xdl.hpp`): handles block-level computation and shared-memory tiling.
- **reference_gemm** (in `reference_gemm.hpp`): CPU implementation used for result verification.

---

This example is the foundation for all matrix operations in Composable Kernel and the basis for more advanced fused and batched operations.