Erwin Terpstra 46f1d740f0 Add grouped gemm instances for RDNA4 (#3237)
* wip: grouped_gemm implementation based on wmma kernel + example for fp16

* chore: clean up grouped_gem_wmma_splitk_fp16 example

* chore: add cmake options to fully disable XDL or WMMA kernels

* feat: add tests for grouped gemma wmma instances for f16 and bf16 (all layouts)

* chore: add grouped gemm wmma bf16 example

* refactor: reuse more code between instance factory functions

* chore: turn test failure if not all batch sizes are supported into a warning

* chore: made failing of test on unsupported instances conditional to not break old tests

* chore: add log message to failure case where AK1/BK1/KBatch is too high for K value

* fix: issue with new overloads of GridwiseGemm_wmma_cshuffle_v3::Run()

* fix: stray comma after parameter list

* fix: compilation issues on RDNA3 and tests failing due to unsupported problems still being ran

* chore: update copyright in header comments

* nit: minor feebdack

* refactor: unified XDL / wma tests

* fix: properly disable FP8 instances when ONLY targeting gfx11

* refactor: add v3 suffix to grouped_gemm device struct name

* fix: small typos in example code

* fix: fully exclude xdl/wmma instances when using the corresponding cmake flags

* chore: remove unused destructor and added pipeline support checks to remove unnecessary paths

* fix: make sure to not add instance library to group if library was skipped

* fix: make sure xdl grouped gemm doesnt fail the new test

* fix: explicitly exclude test if no xdl/wmma support, as pattern matching fails in this case

* fix: examples not working since dependent types and functions were moved to ck namespace in develop

* fix: tests failing when compiling for just gfx11 due to trying to run unsupported instances

* chore: replace/add copyright headers with new format
2025-12-01 15:32:10 -08:00

Grouped GEMM

Theory

This example demonstrates grouped GEMM: performing multiple independent GEMM operations (with potentially different shapes) in a single kernel launch. Grouped GEMM is used in transformer models (e.g., multi-head attention), mixture-of-experts, and other architectures requiring heterogeneous batched matrix multiplications.

Mathematical Formulation: For G groups, each with its own A_g, B_g, C_g:


$$C_g = A_g \times B_g, \qquad g = 1, 2, \ldots, G$$
  • A_g: [M_g, K_g] input matrix for group g
  • B_g: [K_g, N_g] weight matrix for group g
  • C_g: [M_g, N_g] output matrix for group g
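As a concrete reading of the formulation above, here is a minimal host-side reference in C++, the kind of loop one might use to check a grouped GEMM kernel's output. It assumes FP32 row-major matrices with packed strides. `GroupProblem` and `reference_grouped_gemm` are hypothetical names and are not part of Composable Kernel. The function also returns the total operation count, the sum over g of 2 * M_g * N_g * K_g, which is the usual numerator when converting a measured kernel time into TFLOPS.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical per-group problem description (illustration only, not the CK API).
struct GroupProblem {
    std::size_t M, N, K;
    const float* A; // row-major [M, K]
    const float* B; // row-major [K, N]
    float* C;       // row-major [M, N]
};

// Host reference: computes C_g = A_g * B_g independently for every group g.
// Returns the total floating-point operation count, sum of 2*M*N*K over groups.
std::size_t reference_grouped_gemm(const std::vector<GroupProblem>& groups) {
    std::size_t flops = 0;
    for (const GroupProblem& p : groups) {
        // Every group must supply its own three buffers.
        assert(p.A != nullptr && p.B != nullptr && p.C != nullptr);
        for (std::size_t m = 0; m < p.M; ++m) {
            for (std::size_t n = 0; n < p.N; ++n) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < p.K; ++k) {
                    acc += p.A[m * p.K + k] * p.B[k * p.N + n];
                }
                p.C[m * p.N + n] = acc;
            }
        }
        flops += 2 * p.M * p.N * p.K;
    }
    return flops;
}
```

Because each group carries its own M, N, K and pointers, nothing forces the groups to share a shape; that is the key difference from a plain batched GEMM.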

Algorithmic Background:

  • Each group can have different matrix sizes and strides.
  • A single grid covers all groups: its workgroups are partitioned among the groups, and each workgroup computes an output tile of exactly one group's C_g.
  • Useful for variable-length sequences, multi-head attention, and expert routing.
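One way a single launch can serve groups of different sizes is sketched below: count the output tiles each group needs, prefix-sum those counts into per-group workgroup offsets, and let each workgroup find its group with a binary search. This is a host-side C++ sketch under stated assumptions (fixed, hypothetical 128x128 output tiles in row-major tile order); `build_block_offsets` and `locate_tile` are illustrative names and are not taken from CK's source.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical output-tile sizes, chosen for illustration only.
constexpr std::size_t kTileM = 128;
constexpr std::size_t kTileN = 128;

struct GroupShape { std::size_t M, N; };

// Where a flat workgroup id lands: which group, and which output tile within it.
struct TileCoord { std::size_t group, tile_m, tile_n; };

std::size_t ceil_div(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

// offsets[g] is the first workgroup id owned by group g;
// the final entry is the total number of workgroups in the grid.
std::vector<std::size_t> build_block_offsets(const std::vector<GroupShape>& groups) {
    std::vector<std::size_t> offsets{0};
    for (const GroupShape& s : groups) {
        std::size_t tiles = ceil_div(s.M, kTileM) * ceil_div(s.N, kTileN);
        offsets.push_back(offsets.back() + tiles);
    }
    return offsets;
}

// Maps a flat workgroup id to (group, tile_m, tile_n).
TileCoord locate_tile(const std::vector<GroupShape>& groups,
                      const std::vector<std::size_t>& offsets,
                      std::size_t block_id) {
    assert(block_id < offsets.back()); // id must fall inside the grid
    // upper_bound returns the first offset greater than block_id;
    // the owning group is the entry just before it (this skips empty groups).
    auto it = std::upper_bound(offsets.begin(), offsets.end(), block_id);
    std::size_t g = static_cast<std::size_t>(it - offsets.begin()) - 1;
    std::size_t local = block_id - offsets[g];        // tile index inside group g
    std::size_t tiles_n = ceil_div(groups[g].N, kTileN);
    return {g, local / tiles_n, local % tiles_n};
}
```

With groups of 256x256 and 100x300, the offsets are {0, 4, 7}, so the grid has 7 workgroups and workgroup 6 computes tile (0, 2) of the second group.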

How to Run

Prerequisites

Please follow the instructions in the main Build Guide section as a prerequisite to building and running this example.

Build and run

cd composable_kernel/example/15_grouped_gemm
mkdir build && cd build
cmake -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc ..
make -j

Run example_grouped_gemm_xdl_fp16

#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
./bin/example_grouped_gemm_xdl_fp16 0 1 5

Source Code Structure

Directory Layout

example/15_grouped_gemm/
└── grouped_gemm_xdl.cpp               # Main example: sets up, runs, and verifies grouped GEMM
include/ck/tensor_operation/gpu/device/
└── device_grouped_gemm_xdl.hpp        # Device-level grouped GEMM API
include/ck/tensor_operation/gpu/grid/
└── gridwise_grouped_gemm_xdl.hpp      # Grid-level grouped GEMM kernel

Key Classes and Functions

  • DeviceGroupedGemmXdl (in device_grouped_gemm_xdl.hpp):
    Device API for grouped GEMM.
  • gridwise_grouped_gemm_xdl (in gridwise_grouped_gemm_xdl.hpp):
    Implements the tiled/blocking grouped GEMM kernel.

This example shows how Composable Kernel handles heterogeneous batched matrix multiplication, where each group may have a different shape, within a single kernel launch.