* chore(copyright) update library wide CMakeLists.txt files copyright header template * Fix build --------- Co-authored-by: Sami Remes <samremes@amd.com>
Grouped GEMM with Bias, Elementwise Operation, and Permutation
This example demonstrates a highly complex and specialized fusion: a Grouped GEMM where each individual GEMM operation is fused with a bias addition, a second elementwise operation, and a final permutation of the output. This kernel is designed to accelerate layers that have a group-parallel structure, such as depthwise separable convolutions or multi-head attention, when they are part of a larger fused computational graph.
Mathematical Formulation
This operation performs G independent fused GEMM operations in parallel, where G is the group count. For each group g from 0 to G-1:
-
GEMM Stage: A standard matrix multiplication.
C_{temp1[g]} = A_{[g]} \times B_{[g]} -
Bias Addition Stage: A bias vector
D_[g]is broadcast and added.C_{temp2[g]} = C_{temp1[g]} + D_{[g]} -
Elementwise Stage: A second elementwise operation is performed with tensor
E_[g].C_{temp3[g]} = C_{temp2[g]} \odot E_{[g]} -
Permutation Stage: The final result for the group is permuted.
F_{[g]} = \text{permute}(C_{temp3[g]})
All four stages for all G groups are executed within a single kernel launch. The intermediate results are kept in registers and never written to global memory.
Algorithmic Strategy: Group-Parallel GEMM with Fused Epilogue
The implementation combines the scheduling strategy of Grouped GEMM with the multi-stage fused epilogue seen in 25_gemm_bias_e_permute.
-
Group Scheduling: The
Gindependent problems are distributed across the GPU's thread blocks. The grid-wise kernel is designed such that each thread block is assigned to compute one of theGfused operations. -
Fused GEMM Execution: Once a thread block is assigned a group
g, it executes a complete fused GEMM for that group's specific data. This involves:- Calculating the base memory addresses for
A_{[g]}, B_{[g]}, D_{[g]}, E_{[g]}, andF_{[g]}using the group index and the problem description for that group. - Executing a standard tiled GEMM for
A_{[g]} \times B_{[g]}, accumulating the result in registers. - Executing the fused epilogue:
- Load the bias
D_[g]and add it. - Load the elementwise tensor
E_[g]and apply the operation. - Calculate the permuted destination coordinates and write the final result to
F_[g].
- Load the bias
- Calculating the base memory addresses for
This approach maximizes parallelism at two levels: the coarse-grained parallelism across the G groups, and the fine-grained data parallelism within each individual GEMM operation.
Source Code Organization
grouped_gemm_bias_e_permute_xdl.cpp: The main example file. It demonstrates the complex setup for a grouped problem, defining theGsets of input tensors and the permutation. It then instantiates theDeviceGroupedGemmBiasEPermuteoperation.../../include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_bias_e_permute_impl.hpp: The high-level device interface for this specific fused operation. It takes arrays of tensor descriptors, one for each group.- The underlying grid-wise kernel contains the logic to map thread blocks to groups and then execute the full fused GEMM pipeline for the assigned group.
Build and Run
Prerequisites
Ensure the Composable Kernel library is built and installed.
cd /path/to/composable_kernel/build
make -j install
Build the Example
cd /path/to/composable_kernel/example/28_grouped_gemm_bias_e_permute
mkdir build && cd build
cmake \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
..
make -j
Run the Example
# Run the example with default settings
./grouped_gemm_bias_e_permute_xdl
# Run with verification, data initialization, and timing
./grouped_gemm_bias_e_permute_xdl 1 2 1
Applications
This highly specialized kernel is valuable for optimizing specific patterns in modern neural networks:
- Multi-Head Attention (MHA): The computation for each head in MHA is independent. The entire MHA block can be viewed as a Grouped GEMM where the number of groups
Gis the number of attention heads. If the Q, K, or V projections involve fusions with bias, other elementwise ops, and permutations to prepare the data for the batched GEMM, this kernel could potentially fuse a large part of that logic. - Depthwise Separable Convolutions: The depthwise part of this convolution is a Grouped GEMM with
Gequal to the number of channels. If this is followed by a fused activation function (e.g., a gated activation) and a permutation, this kernel could be a perfect match. - Mixture-of-Experts (MoE) Models: In MoE layers, an input is routed to one of several "expert" sub-networks. If these experts have identical structure, their execution can be formulated as a Grouped GEMM, where
Gis the number of experts. Any fusions within the expert network could be captured by this kernel.
This example showcases the extreme composability of the library, allowing for the creation of highly tailored, high-performance kernels for complex, group-parallel computational graphs.