Batched GEMM-Scale-Softmax-GEMM: Fused Attention
Theory
This example demonstrates the fused attention mechanism used in transformer models, implementing the sequence: batched Q×K^T → scaling → softmax → ×V in a single kernel. This pattern is critical for efficient transformer inference and training.
Mathematical Formulation:
- Attention (expanded element-wise below):
  O = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
- Q: [B, H, N, d_k] queries
- K: [B, H, N, d_k] keys
- V: [B, H, N, d_v] values
- O: [B, H, N, d_v] output
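Written out element-wise per batch and head (with i, j indexing sequence positions), using the row-wise, max-subtracted form of softmax that keeps the computation numerically stable:

S = \frac{QK^T}{\sqrt{d_k}}, \qquad
P_{ij} = \frac{\exp\left(S_{ij} - \max_k S_{ik}\right)}{\sum_k \exp\left(S_{ik} - \max_k S_{ik}\right)}, \qquad
O = PV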
Algorithmic Background:
- Computes Q×K^T, scales by 1/\sqrt{d_k}, applies softmax, then multiplies by V.
- Uses a numerically stable online softmax and memory-efficient tiling (see the sketch after this list).
- Used in multi-head attention and transformer blocks.
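The online softmax never needs a full row of scores at once: it keeps a running row maximum m and running denominator l, and rescales the earlier partial sum whenever a new tile raises the maximum. A minimal, single-threaded C++ sketch of that update rule (the function name and tile layout are illustrative only, not the CK API):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Process one tile of logits, updating the running max `m` and running
// denominator `l` so that earlier partial sums stay correctly scaled.
void online_softmax_update(const std::vector<float>& tile, float& m, float& l)
{
    float m_tile = *std::max_element(tile.begin(), tile.end());
    float m_new  = std::max(m, m_tile);
    // Rescale the old denominator to the new max, then add this tile's terms.
    l = l * std::exp(m - m_new);
    for(float x : tile)
        l += std::exp(x - m_new);
    m = m_new;
}

int main()
{
    // One logit row processed in two tiles; the final (m, l) match what a
    // one-pass softmax over the whole row would produce.
    std::vector<float> row = {1.0f, 3.0f, 2.0f, 0.5f};
    float m = -INFINITY, l = 0.0f;
    online_softmax_update({row[0], row[1]}, m, l);
    online_softmax_update({row[2], row[3]}, m, l);
    // softmax(x_i) = exp(x_i - m) / l
    for(float x : row)
        std::printf("%f\n", std::exp(x - m) / l);
}
```

Because the correction factor exp(m - m_new) folds the old maximum into the new one, tiles can arrive in any order and each probability is recovered at the end as exp(x - m) / l.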
How to Run
Prerequisites
Please follow the instructions in the main Build Guide section as a prerequisite to building and running this example.
Build and run
cd composable_kernel/example/32_batched_gemm_scale_softmax_gemm
mkdir build && cd build
cmake -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc ..
make -j
# Example run
./batched_gemm_scale_softmax_gemm_xdl --batch=32 --heads=12 --seq_len=512 --head_dim=64 --verify=1 --time=1
Source Code Structure
Directory Layout
example/32_batched_gemm_scale_softmax_gemm/
└── batched_gemm_scale_softmax_gemm_xdl.cpp        # Main example: sets up, runs, and verifies fused attention

include/ck/tensor_operation/gpu/device/
└── device_batched_gemm_scale_softmax_gemm.hpp     # Device-level fused attention API

include/ck/tensor_operation/gpu/device/impl/
├── device_batched_attention_impl.hpp              # Attention-specific implementation
└── device_online_softmax_impl.hpp                 # Online softmax implementation

include/ck/tensor_operation/gpu/grid/
├── gridwise_batched_gemm_softmax.hpp              # Grid-level fused attention kernel
└── gridwise_online_softmax.hpp                    # Grid-level online softmax
Key Classes and Functions
- DeviceBatchedGemmScaleSoftmaxGemm (in device_batched_gemm_scale_softmax_gemm.hpp): device-level API for fused attention.
- gridwise_batched_gemm_softmax (in gridwise_batched_gemm_softmax.hpp): implements the tiled/blocked fused attention kernel (sketched below).
- gridwise_online_softmax (in gridwise_online_softmax.hpp): implements the numerically stable, memory-efficient online softmax.
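To make the fusion concrete, here is a schematic, single-threaded C++ sketch of the per-query-row computation that a kernel like gridwise_batched_gemm_softmax performs: the first GEMM, the 1/\sqrt{d_k} scaling, the online softmax, and the second GEMM are all applied per K/V tile, so the full score and probability matrices are never written to memory. The shapes, names, and row-at-a-time loop are simplifications for illustration, not the actual CK implementation, which distributes block tiles across a workgroup and uses XDL (matrix-core) instructions:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// One query row q[d_k], K as N x d_k rows, V as N x d_v rows, output o[d_v].
void fused_attention_row(const std::vector<float>& q,
                         const std::vector<std::vector<float>>& K,
                         const std::vector<std::vector<float>>& V,
                         std::vector<float>& o,
                         int tile_n)
{
    const int N   = static_cast<int>(K.size());
    const int d_k = static_cast<int>(q.size());
    const int d_v = static_cast<int>(V[0].size());
    const float scale = 1.0f / std::sqrt(static_cast<float>(d_k));

    float m = -INFINITY, l = 0.0f; // running max and softmax denominator
    o.assign(d_v, 0.0f);           // running (unnormalized) output

    for(int n0 = 0; n0 < N; n0 += tile_n) // stream over K/V tiles
    {
        const int n1 = std::min(n0 + tile_n, N);

        // First GEMM + scale for this tile: s_j = (q . k_j) / sqrt(d_k).
        std::vector<float> s(n1 - n0, 0.0f);
        float m_tile = -INFINITY;
        for(int j = n0; j < n1; ++j)
        {
            float dot = 0.0f;
            for(int t = 0; t < d_k; ++t)
                dot += q[t] * K[j][t];
            s[j - n0] = dot * scale;
            m_tile    = std::max(m_tile, s[j - n0]);
        }

        // Online softmax: rescale previous accumulators to the new max.
        const float m_new = std::max(m, m_tile);
        const float corr  = std::exp(m - m_new);
        l *= corr;
        for(float& ov : o)
            ov *= corr;

        // Second GEMM fused in: accumulate exp-weighted rows of V.
        for(int j = n0; j < n1; ++j)
        {
            const float p = std::exp(s[j - n0] - m_new);
            l += p;
            for(int t = 0; t < d_v; ++t)
                o[t] += p * V[j][t];
        }
        m = m_new;
    }

    for(float& ov : o) // final normalization by the softmax denominator
        ov /= l;
}
```

The key design point is the correction factor exp(m - m_new): it lets exp-weighted partial outputs computed under an earlier maximum be reused after the running maximum grows, which is what allows a single pass over K and V without ever materializing S or P.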
This example demonstrates how Composable Kernel implements efficient fused attention for transformer and large language models.