Batched GEMM with Bias, Elementwise Operation, and Permutation

This example demonstrates a Batched GEMM where each individual GEMM operation is fused with a bias addition, a second elementwise operation, and a final permutation of the output. This kernel is designed to accelerate layers that have a batch-parallel structure, such as the dense layers in a Transformer's feed-forward network, when they are part of a larger fused computational graph.

Mathematical Formulation

This operation performs B independent fused GEMM operations in parallel, where B is the batch count. For each batch item b from 0 to B-1:

  1. GEMM Stage: A standard matrix multiplication. C_{temp1[b]} = A_{[b]} \times B_{[b]}

  2. Bias Addition Stage: A bias vector D_{[b]} is broadcast and added. C_{temp2[b]} = C_{temp1[b]} + D_{[b]}

  3. Elementwise Stage: A second elementwise operation is performed with tensor E_{[b]}. C_{temp3[b]} = C_{temp2[b]} \odot E_{[b]}

  4. Permutation Stage: The final result for the batch item is permuted. F_{[b]} = \text{permute}(C_{temp3[b]})

All four stages for all B batch items are executed within a single kernel launch. The intermediate results are kept in registers and never written to global memory.

Distinction from Grouped Version:

  • In this Batched version, all B problems are uniform. They share the same dimensions (M, N, K), layouts, and permutations. The input/output tensors are accessed with a constant batch stride.
  • In the Grouped version (28_grouped_gemm_bias_e_permute), each of the G problems can have different dimensions, layouts, and strides, offering more flexibility.

Algorithmic Strategy: Batch-Parallel GEMM with Fused Epilogue

The implementation combines the scheduling strategy of Batched GEMM with a multi-stage fused epilogue.

  1. Batch Scheduling: The B independent problems are distributed across the GPU's thread blocks. The grid is sized so that each thread block computes one output tile of one of the B fused operations, with the batch index recovered from the block index.

  2. Fused GEMM Execution: Once a thread block is assigned a batch item b, it executes a complete fused GEMM for that item's specific data. This involves:

    • Calculating the base memory addresses for A_{[b]}, B_{[b]}, D_{[b]}, E_{[b]}, and F_{[b]} using the batch index and the constant batch stride.
    • Executing a standard tiled GEMM for A_{[b]} \times B_{[b]}, accumulating the result in registers.
    • Executing the fused epilogue:
      • Load the bias D_{[b]} and add it.
      • Load the elementwise tensor E_{[b]} and apply the operation.
      • Calculate the permuted destination coordinates and write the final result to F_{[b]}.

This approach is highly efficient when the total parallelism (B batch items times the number of output tiles per item) is large enough to saturate the GPU.

Build and Run

Prerequisites

Ensure the Composable Kernel library is built and installed.

cd /path/to/composable_kernel/build
make -j install

Build the Example

cd /path/to/composable_kernel/example/29_batched_gemm_bias_e_permute
mkdir build && cd build

cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..

make -j

Run the Example

# Run the example with default settings
./batched_gemm_bias_e_permute_xdl

# Run with verification (arg1=1), decimal-value initialization (arg2=2), and kernel timing (arg3=1)
./batched_gemm_bias_e_permute_xdl 1 2 1

Applications

This kernel is ideal for optimizing the feed-forward network (FFN) block in a Transformer, especially when layout transformations are needed between layers.

A typical Transformer FFN block is: FFN(X) = Linear_2(ReLU(Linear_1(X)))

  • Linear_1 is a GEMM.
  • ReLU is an elementwise activation.
  • Linear_2 is another GEMM.

Sometimes, for performance reasons (e.g., to align with a subsequent layer's expected input layout), the output of the FFN needs to be permuted. This kernel could fuse the Linear_2 GEMM with its bias, a subsequent elementwise operation (if any), and the final permutation, all while operating on a batch of input sequences. This avoids multiple kernel launches and saves significant memory bandwidth, leading to faster model execution.