# Batched GEMM with Bias, Elementwise Operation, and Permutation
This example demonstrates a Batched GEMM where each individual GEMM operation is fused with a bias addition, a second elementwise operation, and a final permutation of the output. This kernel is designed to accelerate layers that have a batch-parallel structure, such as the dense layers in a Transformer's feed-forward network, when they are part of a larger fused computational graph.
## Mathematical Formulation
This operation performs `B` independent fused GEMM operations in parallel, where `B` is the batch count. For each batch item `b` from `0` to `B-1`:

1. **GEMM Stage**: A standard matrix multiplication.

   $C_{temp1[b]} = A_{[b]} \times B_{[b]}$

2. **Bias Addition Stage**: A bias vector $D_{[b]}$ is broadcast and added.

   $C_{temp2[b]} = C_{temp1[b]} + D_{[b]}$

3. **Elementwise Stage**: A second elementwise operation is performed with tensor $E_{[b]}$.

   $C_{temp3[b]} = C_{temp2[b]} \odot E_{[b]}$

4. **Permutation Stage**: The final result for the batch item is permuted.

   $F_{[b]} = \text{permute}(C_{temp3[b]})$
All four stages for all B batch items are executed within a single kernel launch. The intermediate results are kept in registers and never written to global memory.
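The four stages above can be sketched as a pure-Python reference. This is an illustrative model only, not CK's implementation: the elementwise operation is shown as a multiply, and the permutation as an (M, N) → (N, M) transpose, both of which are assumptions for the sake of a concrete example.

```python
# Hypothetical reference of the four fused stages for a single batch item b.
# Shapes, the elementwise op (multiply), and the permutation (transpose)
# are illustrative assumptions, not CK's actual configuration.

def gemm(A, B):
    """C[m][n] = sum_k A[m][k] * B[k][n]."""
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[m][k] * B[k][n] for k in range(K)) for n in range(N)]
            for m in range(M)]

def fused_gemm_bias_e_permute(A, B, D, E):
    # GEMM stage: C_temp1 = A x B
    C1 = gemm(A, B)
    # Bias stage: broadcast the bias row vector D over the M rows
    C2 = [[C1[m][n] + D[n] for n in range(len(D))] for m in range(len(C1))]
    # Elementwise stage: shown here as an elementwise multiply with E
    C3 = [[C2[m][n] * E[m][n] for n in range(len(C2[0]))]
          for m in range(len(C2))]
    # Permutation stage: shown here as an (M, N) -> (N, M) transpose
    return [[C3[m][n] for m in range(len(C3))] for n in range(len(C3[0]))]

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[1.0, 0.0], [0.0, 1.0]]  # identity, so C_temp1 == A
D = [1.0, -1.0]
E = [[2.0, 2.0], [2.0, 2.0]]
F = fused_gemm_bias_e_permute(A, B, D, E)
```

In the actual kernel, `C1` through `C3` never exist in global memory; they correspond to register-resident accumulator values.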
**Distinction from the Grouped Version:**

- In this **Batched** version, all `B` problems are uniform: they share the same dimensions (M, N, K), layouts, and permutations. The input/output tensors are accessed with a constant batch stride.
- In the **Grouped** version (`28_grouped_gemm_bias_e_permute`), each of the `G` problems can have different dimensions, layouts, and strides, offering more flexibility.
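The constant-batch-stride addressing that distinguishes the batched version can be sketched as follows; the function name and stride values are illustrative, not CK's argument layout.

```python
# Sketch of constant-batch-stride addressing used by the batched version:
# every batch item's tensors start at a fixed multiple of the batch stride.
# batch_base_offsets and the stride values are illustrative assumptions.

def batch_base_offsets(batch_count, stride_a, stride_b, stride_f):
    """Base element offsets of A[b], B[b], F[b] for each batch item b."""
    return [(b * stride_a, b * stride_b, b * stride_f)
            for b in range(batch_count)]

# e.g. M=4, K=8, N=4: A is 4x8 (32 elements), B is 8x4, F is 4x4 per item
offsets = batch_base_offsets(batch_count=2, stride_a=32, stride_b=32,
                             stride_f=16)
```

A grouped kernel cannot use this shortcut: it must carry a per-group descriptor (dimensions, layouts, strides) instead of a single stride triple.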
## Algorithmic Strategy: Batch-Parallel GEMM with Fused Epilogue
The implementation combines the scheduling strategy of Batched GEMM with the multi-stage fused epilogue.
1. **Batch Scheduling**: The `B` independent problems are distributed across the GPU's thread blocks. The grid-wise kernel is designed such that each thread block is assigned to compute one of the `B` fused operations.
2. **Fused GEMM Execution**: Once a thread block is assigned a batch item `b`, it executes a complete fused GEMM for that item's specific data. This involves:
   - Calculating the base memory addresses for $A_{[b]}$, $B_{[b]}$, $D_{[b]}$, $E_{[b]}$, and $F_{[b]}$ using the batch index and the constant batch stride.
   - Executing a standard tiled GEMM for $A_{[b]} \times B_{[b]}$, accumulating the result in registers.
   - Executing the fused epilogue:
     - Load the bias $D_{[b]}$ and add it.
     - Load the elementwise tensor $E_{[b]}$ and apply the operation.
     - Calculate the permuted destination coordinates and write the final result to $F_{[b]}$.
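The scheduling above can be modeled schematically on the host. This is a toy sequential model, not the kernel: each "thread block" handles one batch item end to end, tensors are reduced to scalars, and names like `run_batched_kernel` and `block_id` are illustrative assumptions.

```python
# Schematic host-side model of batch scheduling plus the fused epilogue.
# One "block" per batch item; scalars stand in for whole tiles. All names
# (run_batched_kernel, block_id, grid_size) are illustrative, not CK's API.

def run_batched_kernel(A, B, D, E, grid_size):
    """A, B, D, E are lists indexed by batch item; returns F per item."""
    F = [None] * len(A)
    for block_id in range(grid_size):   # grid of thread blocks
        b = block_id                    # block-to-batch mapping
        if b >= len(A):
            continue                    # spare blocks exit early
        # in the real kernel, base addresses are b * batch_stride;
        # here the per-batch operands are simply indexed by b
        acc = A[b] * B[b]               # tiled GEMM (scalar stand-in)
        acc = acc + D[b]                # fused bias add
        acc = acc * E[b]                # fused elementwise op
        F[b] = acc                      # permuted store (identity here)
    return F

F = run_batched_kernel([1.0, 2.0], [3.0, 4.0],
                       [0.5, 0.5], [2.0, 2.0], grid_size=2)
```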
This approach is highly efficient when the batch size `B` is large enough to saturate the GPU's parallelism.
## Source Code Organization
- `batched_gemm_bias_e_permute_xdl.cpp`: The main example file. It sets up the batched problem, defining the batch size, strides, and the single permutation rule that applies to all batch items. It then instantiates the `DeviceBatchedGemmBiasEPermute` operation.
- `../../include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_bias_e_permute_impl.hpp`: The high-level device interface for this specific fused operation.
- The underlying grid-wise kernel contains the logic to map thread blocks to batch items (`block_to_batch`) and then execute the full fused GEMM pipeline for the assigned item.
## Build and Run
### Prerequisites
Ensure the Composable Kernel library is built and installed.
```bash
cd /path/to/composable_kernel/build
make -j install
```
### Build the Example
```bash
cd /path/to/composable_kernel/example/29_batched_gemm_bias_e_permute
mkdir build && cd build
cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..
make -j
```
### Run the Example
```bash
# Run the example with default settings
./batched_gemm_bias_e_permute_xdl

# Run with verification, data initialization, and timing
./batched_gemm_bias_e_permute_xdl 1 2 1
```
## Applications
This kernel is ideal for optimizing the feed-forward network (FFN) block in a Transformer, especially when layout transformations are needed between layers.
A typical Transformer FFN block is:
$FFN(X) = Linear_2(ReLU(Linear_1(X)))$

- `Linear_1` is a GEMM.
- `ReLU` is an elementwise activation.
- `Linear_2` is another GEMM.
Sometimes, for performance reasons (e.g., to align with a subsequent layer's expected input layout), the output of the FFN needs to be permuted. This kernel could fuse the Linear_2 GEMM with its bias, a subsequent elementwise operation (if any), and the final permutation, all while operating on a batch of input sequences. This avoids multiple kernel launches and saves significant memory bandwidth, leading to faster model execution.
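The FFN pattern can be sketched in pure Python to show which part of the computation this kernel would absorb. The shapes, the choice of transpose as the permutation, and the function names are illustrative assumptions, not a real framework integration.

```python
# Illustrative sketch of the FFN tail described above. In this example's
# kernel, the second linear's GEMM, its bias add, and the final output
# permutation would execute in a single launch. ffn_fused_tail and all
# shapes here are hypothetical for illustration.

def relu(x):
    return x if x > 0.0 else 0.0

def linear(X, W, bias):
    """Y[m][n] = sum_k X[m][k] * W[k][n] + bias[n]."""
    M, K, N = len(X), len(W), len(W[0])
    return [[sum(X[m][k] * W[k][n] for k in range(K)) + bias[n]
             for n in range(N)] for m in range(M)]

def ffn_fused_tail(X, W1, b1, W2, b2):
    # Linear_1 + ReLU (a separate fused GEMM in practice)
    H = [[relu(v) for v in row] for row in linear(X, W1, b1)]
    # Linear_2 + bias + output permutation: the part this kernel fuses
    Y = linear(H, W2, b2)
    return [[Y[m][n] for m in range(len(Y))]
            for n in range(len(Y[0]))]  # permute, shown as a transpose

Y = ffn_fused_tail([[1.0, 2.0]],
                   [[1.0, -1.0], [0.0, 1.0]], [0.0, 0.0],
                   [[1.0], [1.0]], [0.5])
```

Without fusion, the bias add, elementwise op, and permutation would each round-trip the `M x N` output through global memory; the fused kernel writes it exactly once.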