composable_kernel/example/29_batched_gemm_bias_e_permute
John Shumway ad57f6ef0b [CK_BUILDER] Put global CK functions in the CK namespace (#3232)
* Wrap ck host utilities in CK namespace.

The CK and CK-Tile source code bases are incompatible because CK is not properly using namespaces everywhere. In particular, we need to put hip_check_error in the ck namespace.

Move all functions in include/ck_/host_utility that were in global namespace into the ck namespace.

There may be additional namespace problems like this, and it's possible we'll have namespace clashes. But it is good design to properly guard our two code bases (CK and CKTile) so that they can both coexist. Moreover, establishing this compatibility is essential if we are going to allow the builder to instantiate kernels from either template library.

* Add using declarations to test code.

After moving some of the utilities into the ck namespace, most examples and a few tests had to be updated to recognize the new namespace declarations. We add using declarations to individual compute units for functions that were previously in the global namespace.

* Add using declarations to client examples.
2025-11-19 11:23:02 +01:00

Batched GEMM with Bias, Elementwise Operation, and Permutation

This example demonstrates a Batched GEMM where each individual GEMM operation is fused with a bias addition, a second elementwise operation, and a final permutation of the output. This kernel is designed to accelerate layers that have a batch-parallel structure, such as the dense layers in a Transformer's feed-forward network, when they are part of a larger fused computational graph.

Mathematical Formulation

This operation performs B independent fused GEMM operations in parallel, where B is the batch count (not to be confused with the input matrices B_{[b]}). For each batch item b from 0 to B-1:

  1. GEMM Stage: A standard matrix multiplication. C_{temp1[b]} = A_{[b]} \times B_{[b]}

  2. Bias Addition Stage: A bias vector D_{[b]} is broadcast and added. C_{temp2[b]} = C_{temp1[b]} + D_{[b]}

  3. Elementwise Stage: A second elementwise operation, written \odot, combines the result with tensor E_{[b]}. C_{temp3[b]} = C_{temp2[b]} \odot E_{[b]}

  4. Permutation Stage: The final result for the batch item is permuted. F_{[b]} = \text{permute}(C_{temp3[b]})
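Written per output element, the four stages collapse into a single expression. In this sketch \pi is a placeholder for whatever index permutation the output uses, and D_{[b]} is assumed to be a length-N vector broadcast across rows; neither detail is specified above.

```latex
F_{[b]}\big[\pi(m, n)\big] = \Big( \sum_{k=0}^{K-1} A_{[b]}[m, k]\, B_{[b]}[k, n] + D_{[b]}[n] \Big) \odot E_{[b]}[m, n],
\quad 0 \le m < M,\ 0 \le n < N
```

The sum is C_{temp1[b]}, adding D_{[b]}[n] gives C_{temp2[b]}, applying \odot E_{[b]} gives C_{temp3[b]}, and the store through \pi is the permutation stage.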

All four stages for all B batch items are executed within a single kernel launch. The intermediate results are kept in registers and never written to global memory.
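A plain CPU reference makes the four stages concrete and is handy for checking a kernel's output. The function below is a hypothetical sketch, not the example's own verification code. It assumes row-major A_{[b]} (M x K), B_{[b]} (K x N) and E_{[b]} (M x N), a length-N bias D_{[b]} broadcast across rows, multiplication as the elementwise operation, and a transpose (F_{[b]} stored as N x M) standing in for the permutation. Like the fused kernel, it carries each output element from accumulation to permuted store without materializing C_{temp1}, C_{temp2} or C_{temp3}.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the batched fused operation (not a CK API).
// Assumptions: row-major A[b] (M x K), B[b] (K x N), E[b] (M x N);
// D[b] is a length-N bias broadcast across rows; the elementwise op is
// multiplication; the "permutation" writes F[b] transposed as (N x M).
std::vector<float> reference_batched_fused(const std::vector<float>& A,
                                           const std::vector<float>& B,
                                           const std::vector<float>& D,
                                           const std::vector<float>& E,
                                           std::size_t batch, std::size_t M,
                                           std::size_t N, std::size_t K)
{
    assert(A.size() == batch * M * K && B.size() == batch * K * N);
    assert(D.size() == batch * N && E.size() == batch * M * N);
    std::vector<float> F(batch * N * M);
    for (std::size_t b = 0; b < batch; ++b) {
        // Constant batch strides: every item has the same shape.
        const float* a = A.data() + b * M * K;
        const float* bm = B.data() + b * K * N;
        const float* d = D.data() + b * N;
        const float* e = E.data() + b * M * N;
        float* f = F.data() + b * N * M;
        for (std::size_t m = 0; m < M; ++m) {
            for (std::size_t n = 0; n < N; ++n) {
                float acc = 0.0f;                  // stage 1: GEMM accumulation
                for (std::size_t k = 0; k < K; ++k)
                    acc += a[m * K + k] * bm[k * N + n];
                acc += d[n];                       // stage 2: broadcast bias
                acc *= e[m * N + n];               // stage 3: elementwise op
                f[n * M + m] = acc;                // stage 4: permuted (transposed) store
            }
        }
    }
    return F;
}
```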

Distinction from Grouped Version:

  • In this Batched version, all B problems are uniform. They share the same dimensions (M, N, K), layouts, and permutations. The input/output tensors are accessed with a constant batch stride.
  • In the Grouped version (28_grouped_gemm_bias_e_permute), each of the G problems can have different dimensions, layouts, and strides, offering more flexibility.
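The practical difference shows up in addressing. The sketch below uses hypothetical descriptor types (not CK's argument structures): a batched problem needs one shape plus one stride per tensor, and item b is located with a single multiply, whereas a grouped problem needs a per-group table because sizes differ between groups. For brevity only M varies between groups here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical descriptors, for illustration only (not CK types).
// Batched: one shape and one stride per tensor serve every batch item.
struct BatchedProblem {
    std::size_t M, N, K;
    std::size_t batch_stride_a, batch_stride_b; // elements between items
};

// Grouped: each group carries its own shape and its own base offsets.
struct GroupedProblem {
    std::size_t M, N, K;
    std::size_t offset_a, offset_b;             // element offsets into packed buffers
};

// Batched addressing: a single multiply locates item b's A and B.
std::size_t batched_offset_a(const BatchedProblem& p, std::size_t b) { return b * p.batch_stride_a; }
std::size_t batched_offset_b(const BatchedProblem& p, std::size_t b) { return b * p.batch_stride_b; }

// Grouped addressing: offsets are accumulated into a lookup table because
// each group's tensors can have a different size.
std::vector<GroupedProblem> pack_groups(const std::vector<std::size_t>& Ms,
                                        std::size_t N, std::size_t K) {
    std::vector<GroupedProblem> groups;
    std::size_t off_a = 0, off_b = 0;
    for (std::size_t M : Ms) {
        groups.push_back({M, N, K, off_a, off_b});
        off_a += M * K; // A_g is M_g x K
        off_b += K * N; // B_g is K x N
    }
    return groups;
}
```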

Algorithmic Strategy: Batch-Parallel GEMM with Fused Epilogue

The implementation combines the scheduling strategy of Batched GEMM with the multi-stage fused epilogue.

  1. Batch Scheduling: The B independent problems are distributed across the GPU's thread blocks. The grid is split into B equal slices, one per batch item, and each slice is further divided into output tiles. A thread block derives its batch index b and its output tile from its block ID, so every thread block computes one output tile of one of the B fused operations.

  2. Fused GEMM Execution: Once a thread block has determined its batch item b and its tile, it executes the fused computation for that tile of the item's data. This involves:

    • Calculating the base memory addresses for A_{[b]}, B_{[b]}, D_{[b]}, E_{[b]}, and F_{[b]} from the batch index and the constant batch stride.
    • Executing a tiled GEMM over the K dimension for its tile of A_{[b]} \times B_{[b]}, accumulating the result in registers.
    • Executing the fused epilogue:
      • Load the matching slice of the bias D_{[b]} and add it.
      • Load the matching slice of the elementwise tensor E_{[b]} and apply the operation.
      • Calculate the permuted destination coordinates and write the final result to F_{[b]}.
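The scheduling and per-tile epilogue can be emulated on the CPU, with each loop iteration playing the role of one thread block. This is a hypothetical sketch, not CK code: `decode_block`, `run_grid`, `MPerBlock` and `NPerBlock` are illustrative names, and it reuses the same assumptions as a simple reference (row-major inputs, length-N bias, multiplication as the elementwise operation, transposed store as the permutation).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU emulation of the batch-parallel grid (hypothetical sketch, not CK code).
struct BlockWork { std::size_t b, tile_m, tile_n; };

// Split a 1-D block ID into (batch item, output tile): the grid holds
// `batch` equal slices of tiles_m * tiles_n blocks each.
BlockWork decode_block(std::size_t block_id, std::size_t tiles_m, std::size_t tiles_n) {
    const std::size_t blocks_per_batch = tiles_m * tiles_n;
    const std::size_t b = block_id / blocks_per_batch;
    const std::size_t local = block_id % blocks_per_batch;
    return {b, local / tiles_n, local % tiles_n};
}

std::vector<float> run_grid(const std::vector<float>& A, const std::vector<float>& B,
                            const std::vector<float>& D, const std::vector<float>& E,
                            std::size_t batch, std::size_t M, std::size_t N, std::size_t K,
                            std::size_t MPerBlock, std::size_t NPerBlock) {
    const std::size_t tiles_m = (M + MPerBlock - 1) / MPerBlock;
    const std::size_t tiles_n = (N + NPerBlock - 1) / NPerBlock;
    std::vector<float> F(batch * N * M);
    // One iteration corresponds to one "thread block".
    for (std::size_t id = 0; id < batch * tiles_m * tiles_n; ++id) {
        const BlockWork w = decode_block(id, tiles_m, tiles_n);
        // Base addresses from the batch index and the constant batch strides.
        const float* a = A.data() + w.b * M * K;
        const float* bm = B.data() + w.b * K * N;
        const float* d = D.data() + w.b * N;
        const float* e = E.data() + w.b * M * N;
        float* f = F.data() + w.b * N * M;
        for (std::size_t m = w.tile_m * MPerBlock; m < M && m < (w.tile_m + 1) * MPerBlock; ++m)
            for (std::size_t n = w.tile_n * NPerBlock; n < N && n < (w.tile_n + 1) * NPerBlock; ++n) {
                float acc = 0.0f;                          // tile GEMM (held in registers on a GPU)
                for (std::size_t k = 0; k < K; ++k) acc += a[m * K + k] * bm[k * N + n];
                acc = (acc + d[n]) * e[m * N + n];         // fused bias + elementwise epilogue
                f[n * M + m] = acc;                        // permuted (transposed) store
            }
    }
    return F;
}
```

Because each block touches a disjoint tile of one item, the result does not depend on the tile sizes chosen, which is what lets a real kernel pick tile sizes purely for performance.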

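Because the number of thread blocks is B times the number of output tiles per item, this approach works best when that product is large enough to keep all of the GPU's compute units busy; a large B can compensate for small per-item matrices.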
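The occupancy argument is simple arithmetic. In this sketch the tile sizes, compute-unit count and blocks per compute unit are all hypothetical parameters rather than values taken from the example.

```cpp
#include <cassert>
#include <cstddef>

// Thread blocks launched for the batched problem: one per output tile per
// batch item. MPerBlock / NPerBlock are hypothetical tile sizes.
std::size_t grid_size(std::size_t batch, std::size_t M, std::size_t N,
                      std::size_t MPerBlock, std::size_t NPerBlock) {
    const std::size_t tiles_m = (M + MPerBlock - 1) / MPerBlock; // ceil(M / MPerBlock)
    const std::size_t tiles_n = (N + NPerBlock - 1) / NPerBlock; // ceil(N / NPerBlock)
    return batch * tiles_m * tiles_n;
}

// Rough "waves" estimate: rounds of blocks needed if each compute unit runs
// `blocks_per_cu` blocks concurrently (an assumed value).
double waves(std::size_t blocks, std::size_t num_cus, std::size_t blocks_per_cu) {
    return static_cast<double>(blocks) / static_cast<double>(num_cus * blocks_per_cu);
}
```

With 256 x 128 tiles, a single 256 x 256 item launches only 2 blocks, far too few for a GPU with dozens of compute units, while a batch of 64 such items launches 128.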

Source Code Organization

Build and Run

Prerequisites

Ensure the Composable Kernel library is built and installed.

cd /path/to/composable_kernel/build
make -j install

Build the Example

cd /path/to/composable_kernel/example/29_batched_gemm_bias_e_permute
mkdir build && cd build

cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..

make -j

Run the Example

# Run the example with default settings
./batched_gemm_bias_e_permute_xdl

# Run with verification, data initialization, and timing
./batched_gemm_bias_e_permute_xdl 1 2 1

Applications

This kernel is well suited to optimizing the feed-forward network (FFN) block in a Transformer, especially when layout transformations are needed between layers.

A typical Transformer FFN block is: FFN(X) = Linear_2(ReLU(Linear_1(X)))

  • Linear_1 is a GEMM.
  • ReLU is an elementwise activation.
  • Linear_2 is another GEMM.
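For reference, the unfused FFN for one sequence looks like the sketch below. It is a minimal CPU illustration, assuming row-major activations and weights and that each Linear layer includes a bias; the function names are illustrative. The final `linear` call (Linear_2 with its bias) is the part a fused batched kernel would absorb.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major dense layer: Y (rows x out) = X (rows x in) * W (in x out) + bias.
std::vector<float> linear(const std::vector<float>& X, const std::vector<float>& W,
                          const std::vector<float>& bias,
                          std::size_t rows, std::size_t in, std::size_t out) {
    std::vector<float> Y(rows * out);
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t o = 0; o < out; ++o) {
            float acc = bias[o];
            for (std::size_t i = 0; i < in; ++i) acc += X[r * in + i] * W[i * out + o];
            Y[r * out + o] = acc;
        }
    return Y;
}

// FFN(X) = Linear_2(ReLU(Linear_1(X))) for one sequence of `rows` tokens.
std::vector<float> ffn(const std::vector<float>& X,
                       const std::vector<float>& W1, const std::vector<float>& b1,
                       const std::vector<float>& W2, const std::vector<float>& b2,
                       std::size_t rows, std::size_t d_model, std::size_t d_ff) {
    std::vector<float> H = linear(X, W1, b1, rows, d_model, d_ff); // Linear_1: a GEMM
    for (float& h : H) h = std::max(h, 0.0f);                      // ReLU: elementwise
    return linear(H, W2, b2, rows, d_ff, d_model);                 // Linear_2: GEMM + bias
}
```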

Sometimes, for performance reasons (e.g., to align with a subsequent layer's expected input layout), the output of the FFN needs to be permuted. This kernel could fuse the Linear_2 GEMM with its bias, a subsequent elementwise operation (if any), and the final permutation, all while operating on a batch of input sequences. This avoids multiple kernel launches and avoids writing intermediate results to global memory only to read them back, which saves memory bandwidth and speeds up model execution.
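One way to map Linear_2 onto the batched problem is one batch item per sequence. The sketch below is hypothetical: the struct and its field names are illustrative, not the example's argument names, and the zero batch strides for the shared weight W2 and bias b2 assume the kernel accepts a stride of 0 to reuse the same tensor for every item.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical problem description for a batched fused kernel
// (field names are illustrative, not CK's argument names).
struct BatchedFusedShape {
    std::size_t batch, M, N, K;
    std::size_t stride_a, stride_b, stride_d, stride_e, stride_f; // batch strides in elements
};

// Map Linear_2 of a Transformer FFN onto the batched problem, assuming one
// batch item per sequence and activations stored contiguously per sequence.
BatchedFusedShape map_linear2(std::size_t num_sequences, std::size_t seq_len,
                              std::size_t d_model, std::size_t d_ff) {
    BatchedFusedShape s{};
    s.batch = num_sequences;
    s.M = seq_len;                   // rows: tokens in one sequence
    s.N = d_model;                   // columns: output features
    s.K = d_ff;                      // reduction: hidden width of the FFN
    s.stride_a = seq_len * d_ff;     // ReLU(Linear_1(X)) of the next sequence
    s.stride_b = 0;                  // W2 is shared by all sequences (assumed zero stride)
    s.stride_d = 0;                  // so is the bias b2
    s.stride_e = seq_len * d_model;  // per-sequence elementwise operand
    s.stride_f = seq_len * d_model;  // per-sequence permuted output
    return s;
}
```

Because W2 is shared, the same work could also be expressed as one GEMM with M = num_sequences x seq_len; the batched view is useful when the output permutation is defined per sequence.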