Grouped Gemm

Grouped General Matrix Multiplication (Grouped GEMM) is a technique used in GPU computing and high-performance computing to batch together multiple independent GEMM operations (matrix multiplications) into a single kernel launch in order to improve performance and efficiency. This folder contains Grouped GEMM examples that use the ck_tile tile-programming implementation.

Quick Tour for New Users

The Grouped GEMM operators are versions of GEMM that run multiple GEMM operations within a single kernel call. Each GEMM operation performs a matrix multiplication. Unlike regular batched GEMM operations where both matrices must be of the same size and have the same configuration, Grouped GEMM operations can take matrices with different sizes and configurations, making them more flexible for diverse workloads.
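Conceptually, a grouped GEMM is just a set of independent matrix multiplications, each with its own M, N, and K. The following standalone CPU sketch (not the ck_tile implementation, just a naive row-major reference of the kind a CPU validation step uses) illustrates this flexibility:

```cpp
#include <cassert>
#include <vector>

// One independent GEMM problem: C (M x N) = A (M x K) * B (K x N), row-major.
struct GemmProblem {
    int M, N, K;
    std::vector<float> A, B, C;
};

// Naive reference: every group may have different sizes, unlike batched GEMM,
// where all matrices in the batch share one configuration.
void grouped_gemm_reference(std::vector<GemmProblem>& groups) {
    for (auto& g : groups) {
        for (int m = 0; m < g.M; ++m)
            for (int n = 0; n < g.N; ++n) {
                float acc = 0.f;
                for (int k = 0; k < g.K; ++k)
                    acc += g.A[m * g.K + k] * g.B[k * g.N + n];
                g.C[m * g.N + n] = acc;
            }
    }
}
```

The GPU kernel fuses all of these multiplications into a single launch instead of looping over them on the host.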

Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.

Key Arguments

The example takes several arguments including group_count, repeat, and warmup:

  • group_count: the number of GEMM operations in the group
  • repeat: the number of times to repeat the kernel for benchmarking
  • warmup: the number of warmup iterations run before the kernel run-time measurement
// Example
const int group_count = arg_parser.get_int("group_count");
const int repeat      = arg_parser.get_int("repeat");
const int warmup      = arg_parser.get_int("warmup");

In the next step, the input parameters Ms, Ns, and Ks, as well as the corresponding stride_As, stride_Bs, and stride_Cs, are either provided on the command line or generated by default. Since one or more input data sets are expected for A and B, each parameter is stored in a std::vector. The size of each vector is defined by group_count.

// Example
std::vector<ck_tile::index_t> Ms        = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns        = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks        = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");

Where:

  • Ms contains the M dimension of each GEMM.
  • Ns contains the N dimension of each GEMM.
  • Ks contains the K dimension of each GEMM.
  • stride_As contains the stride value for each matrix A.
  • stride_Bs contains the stride value for each matrix B.
  • stride_Cs contains the stride value for each matrix C.

HostTensor and Device Memory Buffers (for CPU and GPU)

Each parameter Ms, Ns, Ks, stride_As, stride_Bs and stride_Cs contains values for more than one matrix, meaning different matrix sizes and strides can be used for different grouped GEMM computations. The next step is to properly load the input values. For each input matrix, A and B, and for each output matrix, C, you need to create both HostTensor and DeviceMemory, where:

  • HostTensor represents the matrix data on the host (CPU). It stores the data before it is transferred to the device for computation.
  • DeviceMemory represents the matrix data on the device (GPU). This will store the data on the GPU for computation during the Grouped GEMM operation.

HostTensor Buffers (for CPU)

In the first step, create HostTensor for A, B, C. HostTensor allocates memory on the host (CPU) to store the matrices, initializing the memory with the appropriate dimensions and values to store the data. Below is an example code showing how to create HostTensors for those tensors:

// Example
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;

Where:

  • a_m_k_tensors is the vector of HostTensor objects for matrix A (with dimensions M × K). Each tensor stores the data for a single GEMM operation.
  • b_k_n_tensors is the vector of HostTensor objects for matrix B (with dimensions K × N).
  • c_m_n_tensors is the vector of HostTensor objects for matrix C (the output matrix with dimensions M × N).

The std::vector container is used for this purpose throughout. As mentioned above, the number of HostTensors is equal to group_count.

Device Memory Buffers (for GPU)

Now it's time to allocate memory on the device (GPU) and transfer the data from HostTensor to DeviceMemory for the actual computation.

// Example
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;

Where:

  • a_m_k_dev_buf is the buffer used for storing matrix A on the GPU.
  • b_k_n_dev_buf is the buffer used for storing matrix B on the GPU.
  • c_m_n_dev_buf is the buffer used for storing the result matrix C on the GPU.

Prepare data

In the next step, the input tensors are populated. A pseudorandom number generator, an existing distribution (e.g., FillUniformDistribution), or user data can be used to populate the tensors. Descriptors also need to be created for each input tensor.
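A uniform fill of a host buffer can be sketched with the standard library alone; the stand-in below mimics what a FillUniformDistribution-style functor does (fill_uniform is a hypothetical helper, not a ck_tile API):

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Fill a host buffer with uniformly distributed values in [lo, hi),
// mimicking what a FillUniformDistribution-style functor does.
std::vector<float> fill_uniform(std::size_t n, float lo, float hi, unsigned seed = 0)
{
    std::mt19937 gen(seed); // fixed seed keeps the generated data reproducible
    std::uniform_real_distribution<float> dist(lo, hi);
    std::vector<float> buf(n);
    for (auto& v : buf)
        v = dist(gen);
    return buf;
}
```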

Use get_default_stride to get the strides for A, B, and C. get_default_stride is a template function that calculates the default stride for a 2D array based on whether it is row-major or column-major. The template parameter determines whether the storage order is row-major (true) or column-major (false). The function takes four parameters: row, col, stride, and bool_constant<is_row_major>. If the stride is explicitly provided (stride != 0), it is returned as-is. If the stride is not provided (stride == 0), the function computes the default: for row-major order (is_row_major == true), the stride is set to the number of columns (col); for column-major order (is_row_major == false), it is set to the number of rows (row). This function is useful when working with dynamically allocated 2D arrays where the user may not specify the stride explicitly, since it ensures a natural default stride based on the chosen storage order.

// Example, API
template <bool is_row_major>
auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant<is_row_major>) {
  if(stride == 0)
  {
      // Default: the leading dimension is col for row-major, row for column-major.
      return is_row_major ? col : row;
  }
  return stride;
}

Where:

  • is_row_major is a bool template parameter that determines whether the storage order is row-major (true) or column-major (false).
  • row is the number of rows in the matrix.
  • col is the number of columns in the matrix.
  • stride is the current stride (the distance between consecutive elements in memory).
  • bool_constant<is_row_major> is a tag type that helps in differentiating behavior at compile-time.
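The same rules can be written as a self-contained, compilable sketch, using std::bool_constant in place of ck_tile's bool_constant tag type:

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// An explicit stride wins; otherwise the default is col for row-major
// storage and row for column-major storage.
template <bool is_row_major>
std::size_t get_default_stride(std::size_t row,
                               std::size_t col,
                               std::size_t stride,
                               std::bool_constant<is_row_major>)
{
    if (stride != 0)
        return stride; // caller provided an explicit stride
    return is_row_major ? col : row;
}
```

For example, a 3 × 4 row-major matrix with no explicit stride gets a stride of 4 (its column count), while the column-major case gets 3.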

Next, host descriptors for each of the input tensors A and B, and for the output tensor C, are created using the f_host_tensor_descriptor function. This function takes four parameters (row, col, stride, and layout) and returns a HostTensorDescriptor based on the specified layout.

// Example for tensor A
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));

After creating the HostTensors, create a DeviceMem object for each tensor A, B, and C, and then transfer the data to the device. The get_element_space_size_in_bytes() function returns the buffer size in bytes. Use ToDevice() to transfer data from the host to the device; the data that was previously generated (a_m_k_tensors[i].data()) is passed as a parameter to ToDevice().

The final step before running the GEMM operation is to retrieve the pointers to the buffers of A, B, and C stored on the device using ->GetDeviceBuffer() and pack them into a shared container. For example: gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}), where gemm_descs is a std::vector<grouped_gemm_kargs>. The container should include values such as:

struct GroupedGemmHostArgs
{
    const void* a_ptr;
    const void* b_ptr;
    void* c_ptr;
    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
};
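To make the packing step concrete, here is a host-only sketch built around the struct above. Assumptions are flagged in the comments: index_t is taken to be a 32-bit index, make_descs is a hypothetical helper, and nullptr stands in for the device pointers that ->GetDeviceBuffer() returns in the real example. The strides follow the row-major A/C, column-major B defaults listed at the end of this README:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using index_t = std::int32_t; // assumption: a 32-bit index standing in for ck_tile::index_t

struct GroupedGemmHostArgs
{
    const void* a_ptr;
    const void* b_ptr;
    void* c_ptr;
    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
};

// Build one descriptor per group; in the real example the three pointers come
// from ->GetDeviceBuffer() on the corresponding DeviceMem objects.
std::vector<GroupedGemmHostArgs> make_descs(int group_count,
                                            const std::vector<index_t>& Ms,
                                            const std::vector<index_t>& Ns,
                                            const std::vector<index_t>& Ks)
{
    std::vector<GroupedGemmHostArgs> descs;
    for (int i = 0; i < group_count; ++i) {
        // Default strides: row-major A and C use K and N; column-major B uses K.
        descs.push_back({nullptr, nullptr, nullptr,
                         Ms[i], Ns[i], Ks[i], Ks[i], Ks[i], Ns[i]});
    }
    return descs;
}
```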

The data prepared in this way can be passed to the invoke_gemm function. This is a templated function that takes four template parameters: ALayout, BLayout, CLayout, and Persistent:

// Example, API
template <typename ALayout, typename BLayout, typename CLayout, bool Persistent>
float invoke_gemm(int n_warmup,
                  int n_repeat,
                  int group_count,
                  const std::vector<grouped_gemm_kargs>& args)

invoke_gemm returns the run time in milliseconds. Inside it, the workspace memory required for the computation is allocated. Workspace memory on the GPU refers to temporary buffers allocated while certain operations run; here the extra space holds the GEMM descriptions. The following structure can be used to allocate the workspace:

// Example
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
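The warmup/repeat pattern behind the returned run time can be sketched with std::chrono; time_kernel_ms is a hypothetical stand-in that times a placeholder workload rather than a real kernel launch (the actual example times GPU streams through ck_tile's stream_config):

```cpp
#include <cassert>
#include <chrono>
#include <functional>

// Run `work` n_warmup times untimed, then time n_repeat runs and return the
// average duration in milliseconds (the same quantity invoke_gemm reports).
double time_kernel_ms(int n_warmup, int n_repeat, const std::function<void()>& work)
{
    for (int i = 0; i < n_warmup; ++i)
        work(); // warmup iterations are excluded from the measurement

    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < n_repeat; ++i)
        work();
    auto stop = std::chrono::steady_clock::now();

    std::chrono::duration<double, std::milli> total = stop - start;
    return total.count() / n_repeat;
}
```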

Advanced Features: Preshuffle and Persistence

The grouped GEMM examples include two advanced optimization features:

Weight Preshuffle

Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.

  • Implementation: Available in grouped_gemm_preshuffle.cpp
  • Configuration: Uses GemmConfigPreshuffleDecode template configuration
  • Constraints: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
  • Benefits: Improved memory efficiency and reduced data movement

Persistence Mode

Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.

  • Template Parameter: Controlled by the Persistent boolean template parameter in invoke_gemm
  • Usage: invoke_gemm<ALayout, BLayout, CLayout, true> enables persistence
  • Benefits: Reduced kernel launch overhead, better resource utilization for small matrix sizes

Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.

Finally, the arguments are passed to grouped_gemm and the kernel is launched.

// API
template <typename ALayout, typename BLayout, typename CLayout>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
                   const ck_tile::stream_config& s,
                   void* kargs_ptr)

All the necessary parameters are set, the tiling is computed, the GEMM pipeline and epilogue are prepared, and the GroupedGemmKernel is launched.

Build

# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh  ../ <arch>
# The basic grouped GEMM example
make tile_example_grouped_gemm -j
# The preshuffle example
make tile_example_grouped_gemm_preshuffle -j
# The quant grouped gemm fp8 example
make tile_example_quant_grouped_gemm -j

This results in an executable, build/bin/tile_example_grouped_gemm (the other targets produce similarly named executables).

Example

args:
 -Ms          M dimensions - (Default: empty).
 -Ns          N dimensions - (Default: empty).
 -Ks          K dimensions - (Default: empty).
 -stride_As   Tensor A strides - (Default: empty).
 -stride_Bs   Tensor B strides - (Default: empty).
 -stride_Cs   Tensor C strides - (Default: empty).
 -a_layout    A tensor data layout - (Default: Row).
 -b_layout    B tensor data layout - (Default: Col).
 -c_layout    C tensor data layout - (Default: Row).
 -prec        data type. fp16/fp8 - (Default: fp16).
 -validate    0. No validation, 1. Validation on CPU. (Default: 1).
 -warmup      Number of iterations before benchmarking the kernel. (Default: 10).
 -repeat      Number of iterations to benchmark the kernel. (Default: 100).
 -group_count Group count. (Default: 16).
 -kbatch      kbatch for SplitK (Default: 1).
 -json        0: No Json, 1: Dump Results in Json format (Default: 0).
 -jsonfile    json file name to dump results (Default: grouped_gemm.json).

If any of Ms, Ns, Ks, stride_As, stride_Bs, or stride_Cs are missing or their sizes don't match group_count, the example generates defaults per group index i (0-based):

M[i] = 256 + 256 * i
N[i] = 256 + 512 * i
K[i] = 512 + 384 * i

stride_A[i] = K[i]
stride_B[i] = K[i]
stride_C[i] = N[i]
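These defaults are straightforward to reproduce; the following sketch (default_sizes is a hypothetical helper, not part of the example's API) generates them for a given group_count:

```cpp
#include <cassert>
#include <vector>

struct ProblemSizes {
    std::vector<int> M, N, K, stride_A, stride_B, stride_C;
};

// Generate the per-group defaults listed above for group indices 0..group_count-1.
ProblemSizes default_sizes(int group_count)
{
    ProblemSizes p;
    for (int i = 0; i < group_count; ++i) {
        p.M.push_back(256 + 256 * i);
        p.N.push_back(256 + 512 * i);
        p.K.push_back(512 + 384 * i);
        p.stride_A.push_back(p.K[i]); // row-major A: leading dimension K
        p.stride_B.push_back(p.K[i]); // column-major B: leading dimension K
        p.stride_C.push_back(p.N[i]); // row-major C: leading dimension N
    }
    return p;
}
```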