composable_kernel/example/15_grouped_gemm
assistant-librarian[bot] 9c8d3a39ac 173 implement device grouped gemm fixed nk for rdna4 (#4299)
## Proposed changes

This PR adds an RDNA4 implementation of the device_grouped_gemm_fixed_nk
instance library using WMMA.

The implementation is based on the existing
DeviceGroupedGemm_Xdl_Fixed_NK design and reuses the same high-level
structure, but replaces the XDL kernel with a WMMA-based one. It uses
the GridwiseGemm_wmma_cshuffle_v3 kernel.
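The PR text does not define "fixed NK". We read it as meaning that every group shares the same N and K and only M varies from group to group, as in mixture-of-experts layers where each expert receives a different number of tokens. The minimal host-side check below is a sketch under that assumption. The struct `GroupShape` and the function `is_fixed_nk` are hypothetical and are not part of the CK API.

```cpp
#include <vector>

// Hypothetical per-group problem size (not a CK type).
struct GroupShape {
    int M; // rows of A_g and C_g; may differ per group
    int N; // columns of B_g and C_g
    int K; // reduction dimension
};

// Assumption: "fixed NK" means all groups share N and K while M varies.
// Returns true when the problem set satisfies that constraint.
bool is_fixed_nk(const std::vector<GroupShape>& groups) {
    if (groups.empty()) return true;
    const int n0 = groups.front().N;
    const int k0 = groups.front().K;
    for (const GroupShape& g : groups) {
        if (g.N != n0 || g.K != k0) return false; // a group breaks the shared N/K
    }
    return true;
}
```

Under this reading, a set such as {128x256x64, 32x256x64} qualifies, while changing one group's N to 512 does not.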

At this stage, the focus is functional correctness and compatibility,
not performance tuning.

## Technical Details

- Device struct for grouped gemm fixed NK
- Example code for the WMMA version
- Unit tests for both the new WMMA implementation and the reference XDL code
(previously missing)
- A generic CK profiler interface used to call the unit tests.

## Checklist

Please put an x into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [x] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [x] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [x] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run  on all changed files
- [x] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered



---
🔁 Imported from
[ROCm/composable_kernel#3668](https://github.com/ROCm/composable_kernel/pull/3668)
🧑‍💻 Originally authored by @bidlekm

---------

Co-authored-by: Marton Bidlek <marton.bidlek@streamhpc.com>
Co-authored-by: Erwin Terpstra <erwin.terpstra@streamhpc.com>
Co-authored-by: bidlekm <bidlekmarton@gmail.com>
Co-authored-by: assistant-librarian[bot] <assistant-librarian[bot]@users.noreply.github.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
2026-02-19 09:13:05 +01:00

Grouped GEMM

Theory

This example demonstrates grouped GEMM: performing multiple independent GEMM operations (with potentially different shapes) in a single kernel launch. Grouped GEMM is used in transformer models (e.g., multi-head attention), mixture-of-experts, and other architectures requiring heterogeneous batched matrix multiplications.

Mathematical Formulation: For G groups, each with its own A_g, B_g, C_g:


$$C_g = A_g \times B_g \quad \text{for} \quad g = 1, 2, \ldots, G$$
  • A_g: [M_g, K_g] input matrix for group g
  • B_g: [K_g, N_g] weight matrix for group g
  • C_g: [M_g, N_g] output matrix for group g
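The formulation above says that each group is an ordinary, independent GEMM. A naive host-side reference therefore just loops over the groups, and this is the kind of result a GPU grouped GEMM can be verified against. The sketch assumes row-major fp32 matrices, with each group carrying its own sizes and row strides. `GemmProblem` and `reference_grouped_gemm` are hypothetical names, not CK code.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical description of one group's problem (row-major A, B, C).
struct GemmProblem {
    std::size_t M, N, K;       // C_g is M x N, reduction over K
    std::size_t lda, ldb, ldc; // row strides of A_g, B_g, C_g
    const float* A;            // M x K
    const float* B;            // K x N
    float* C;                  // M x N
};

// Naive reference: C_g = A_g * B_g for every group g, each computed
// independently with its own shape and strides.
void reference_grouped_gemm(const std::vector<GemmProblem>& groups) {
    for (const GemmProblem& p : groups) {
        for (std::size_t m = 0; m < p.M; ++m) {
            for (std::size_t n = 0; n < p.N; ++n) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < p.K; ++k) {
                    acc += p.A[m * p.lda + k] * p.B[k * p.ldb + n];
                }
                p.C[m * p.ldc + n] = acc;
            }
        }
    }
}
```

For example, a 2x2 group with A = [1 2; 3 4] and B = [5 6; 7 8] yields C = [19 22; 43 50]. A 1x3 by 3x1 group passed in the same call yields [32].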

Algorithmic Background:

  • Each group can have different matrix sizes and strides.
  • The kernel launches a single grid covering the output tiles of all groups; each workgroup (thread block) is assigned to an output tile of exactly one group.
  • Useful for variable-length sequences, multi-head attention, and expert routing.
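One common way to give every group its share of a single grid works like this. Count the C tiles of each group, take a prefix sum so that each group owns a contiguous range of workgroup IDs, and then map a flat workgroup ID back to (group, tile row, tile column) with a binary search. The sketch below illustrates that idea only. It is not claimed to be CK's exact code, and `num_tiles`, `build_block_offsets`, `locate` and the tile sizes `MPerBlock`/`NPerBlock` are hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Number of MPerBlock x NPerBlock tiles needed to cover an M x N output.
std::size_t num_tiles(std::size_t M, std::size_t N,
                      std::size_t MPerBlock, std::size_t NPerBlock) {
    return ((M + MPerBlock - 1) / MPerBlock) * ((N + NPerBlock - 1) / NPerBlock);
}

// offsets[g] = first workgroup ID owned by group g; offsets.back() = grid size.
std::vector<std::size_t> build_block_offsets(const std::vector<std::size_t>& Ms,
                                             const std::vector<std::size_t>& Ns,
                                             std::size_t MPerBlock,
                                             std::size_t NPerBlock) {
    std::vector<std::size_t> offsets(Ms.size() + 1, 0);
    for (std::size_t g = 0; g < Ms.size(); ++g)
        offsets[g + 1] = offsets[g] + num_tiles(Ms[g], Ns[g], MPerBlock, NPerBlock);
    return offsets;
}

struct TileCoord {
    std::size_t group, tile_m, tile_n;
};

// Maps a flat workgroup ID (must be < offsets.back()) to its group and tile.
TileCoord locate(std::size_t block_id, const std::vector<std::size_t>& offsets,
                 const std::vector<std::size_t>& Ns, std::size_t NPerBlock) {
    // First offset strictly greater than block_id; the owner is the group before it.
    auto it = std::upper_bound(offsets.begin(), offsets.end(), block_id);
    std::size_t g = static_cast<std::size_t>(it - offsets.begin()) - 1;
    std::size_t local = block_id - offsets[g];              // tile index inside group g
    std::size_t tiles_n = (Ns[g] + NPerBlock - 1) / NPerBlock; // tiles per tile row
    return {g, local / tiles_n, local % tiles_n};
}
```

With MPerBlock = NPerBlock = 2, groups of size 4x2 and 3x5 need 2 and 6 tiles, which gives a grid of 8 workgroups. Workgroup 7 then maps to group 1, tile (1, 2). Because `upper_bound` skips empty ranges, groups with zero tiles simply receive no workgroups.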

How to Run

Prerequisites

Please follow the instructions in the main Build Guide section as a prerequisite to building and running this example.

Build and run

cd composable_kernel/example/15_grouped_gemm
mkdir build && cd build
cmake -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc ..
make -j

Run example_grouped_gemm_xdl

#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
./bin/example_grouped_gemm_xdl_fp16 0 1 5

Source Code Structure

Directory Layout

example/15_grouped_gemm/
├── grouped_gemm_xdl.cpp         # Main example: sets up, runs, and verifies grouped GEMM
include/ck/tensor_operation/gpu/device/
│   └── device_grouped_gemm_xdl.hpp       # Device-level grouped GEMM API
include/ck/tensor_operation/gpu/grid/
│   └── gridwise_grouped_gemm_xdl.hpp     # Grid-level grouped GEMM kernel

Key Classes and Functions

  • DeviceGroupedGemmXdl (in device_grouped_gemm_xdl.hpp):
    Device API for grouped GEMM.
  • gridwise_grouped_gemm_xdl (in gridwise_grouped_gemm_xdl.hpp):
    Implements the tiled/blocking grouped GEMM kernel.

This example demonstrates how Composable Kernel executes many independent GEMMs of different shapes in a single kernel launch, a pattern common in workloads such as mixture-of-experts layers and variable-length sequence batches.