[rocm-libraries] ROCm/rocm-libraries#4996 (commit 0a47fbe)

[CK TILE ENGINE] Add grouped_gemm operator to Tile Engine
 (gfx942/gfx950) (#4996)

## Motivation

The grouped_gemm CK Tile kernel exists (e.g.,
`example/17_grouped_gemm/`) but has no Tile Engine wrapper. Grouped GEMM
handles multiple independent GEMM problems with varying M/N/K dimensions
in a single kernel launch. This PR adds the Tile Engine infrastructure
for automated kernel generation, benchmarking, and profiling of grouped
GEMM kernels.

Jira: AICK-809

## Technical Details

- Created Tile Engine wrapper under `tile_engine/ops/gemm/grouped_gemm/`
following the `gemm_universal` template
- Files added: `CMakeLists.txt`, `grouped_gemm_common.hpp`,
`grouped_gemm_benchmark.hpp`, `grouped_gemm_profiler.hpp`,
`grouped_gemm_benchmark.py`, `grouped_gemm_benchmark_single.cpp`,
`grouped_gemm_instance_builder.py`, `configs/`
- Supported datatypes: fp16, fp8, bf16, bf8
- Supported layouts: rcr, rrr, ccr, crr
- Target GPUs: gfx942, gfx950
- CK Tile kernel: `ck_tile::GroupedGemmKernel` from
`include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp`
- Instance builder extends `GemmKernelBuilder` base class
- Registered in `tile_engine/ops/gemm/CMakeLists.txt`
- Updated Jenkinsfile to build and benchmark grouped_gemm targets in CI
- Benchmark infrastructure includes JSON output, CSV export, and
verification support

## Test Plan

- CMake configure succeeds for grouped_gemm targets
- Kernel instance builder generates valid kernel headers for all
(datatype, layout) combinations
- At least one kernel binary compiles and runs per datatype/layout
combination
- Correctness passes with `--verify 1` on gfx942/gfx950

## Test Result

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2026-03-10 23:59:26 +00:00
committed by assistant-librarian[bot]
parent 9f47b8a63d
commit c85c272c39
13 changed files with 2582 additions and 24 deletions

View File

@@ -168,6 +168,8 @@ class GemmKernelBuilder:
default_pipeline = "compv4"
elif self.kernel_name_prefix == "gemm_preshuffle":
default_pipeline = "preshufflev2"
elif self.kernel_name_prefix == "grouped_gemm":
default_pipeline = "compv4"
configs = []
for tile_m in tile_m_values:
@@ -335,7 +337,11 @@ class GemmKernelBuilder:
kernel_name += f"_{tile_str}"
if self.kernel_name_prefix in ["gemm_universal", "gemm_multi_d"]:
if self.kernel_name_prefix in [
"gemm_universal",
"gemm_multi_d",
"grouped_gemm",
]:
# Map pipeline names to the correct pipeline implementation
pipeline_impl_map = {
"mem": "ck_tile::GemmPipelineAgBgCrMem",
@@ -410,6 +416,11 @@ class GemmKernelBuilder:
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
"""
if self.kernel_name_prefix == "grouped_gemm":
instance_code += """#include <vector>
#include <hip/hip_runtime.h>
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
"""
return instance_code
@@ -425,10 +436,11 @@ class GemmKernelBuilder:
# Assign layouts based on self.layout
if self.kernel_name_prefix == "gemm_multi_d":
a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)
elif (
self.kernel_name_prefix == "gemm_universal"
or self.kernel_name_prefix == "gemm_preshuffle"
):
elif self.kernel_name_prefix in [
"gemm_universal",
"gemm_preshuffle",
"grouped_gemm",
]:
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
instance_code = f"""
@@ -502,8 +514,12 @@ struct SelectedKernel {{
static constexpr bool TransposeC = false;
static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2"] else "false"};"""
if self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
instance_code += f"""
if self.kernel_name_prefix in [
"gemm_universal",
"gemm_preshuffle",
"grouped_gemm",
]:
instance_code += f"""
static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"};
static constexpr bool UseStructuredSparsity = false;
static constexpr ck_tile::index_t NumWaveGroups = 1;"""
@@ -528,9 +544,13 @@ struct SelectedKernel {{
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;"""
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
elif self.kernel_name_prefix in [
"gemm_universal",
"gemm_preshuffle",
"grouped_gemm",
]:
instance_code = """
// Tile shape
using TileShape = ck_tile::TileGemmShape<
ck_tile::sequence<TileM, TileN, TileK>,
@@ -604,6 +624,13 @@ struct SelectedKernel {{
// Launch function
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {"""
elif self.kernel_name_prefix == "grouped_gemm":
instance_code = """
// Launch function
static float launch(const std::vector<ck_tile::GroupedGemmHostArgs<>>& gemm_descs,
const ck_tile::stream_config& stream,
void* kargs_ptr) {"""
# Scheduler initialization
if self.kernel_name_prefix in ["gemm_preshuffle", "gemm_multi_d"]:
@@ -644,12 +671,12 @@ struct SelectedKernel {{
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
# Scheduler initialization
if self.kernel_name_prefix in ["gemm_universal"]:
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
instance_code += f"""
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};"""
# UniversalGemmProblem
if self.kernel_name_prefix in ["gemm_universal"]:
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
instance_code += """
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
@@ -664,7 +691,7 @@ struct SelectedKernel {{
scheduler>;"""
# GemmPipeline
if self.kernel_name_prefix in ["gemm_universal"]:
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
instance_code += f"""
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
@@ -711,13 +738,13 @@ struct SelectedKernel {{
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
instance_code += f"""
// Kernel type
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Kernel arguments
auto kargs = GemmKernel::MakeKernelArgs(args);
if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
@@ -725,7 +752,7 @@ struct SelectedKernel {{
// Get grid and block sizes
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
@@ -733,13 +760,55 @@ struct SelectedKernel {{
<< std::endl;
}}"""
instance_code += f"""
instance_code += f"""
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}}
}};
"""
elif self.kernel_name_prefix == "grouped_gemm":
instance_code += f"""
// Kernel type
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Kernel arguments
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping grouped gemm!");
}}
// Get grid and block sizes
const dim3 grids = {"Kernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "dim3(kargs.empty() ? 0 : kargs.back().block_end, 1, 1)"};
const dim3 blocks = Kernel::BlockSize();
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
hipMemcpyHostToDevice,
stream.stream_id_));
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
kargs.size()));
return ave_time;
}}
}};
@@ -753,14 +822,14 @@ struct SelectedKernel {{
"""
if epilogue == "cshuffle":
if self.kernel_name_prefix == "gemm_universal":
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
instance_code += self.populate_cshuffle_gemm_universal()
elif self.kernel_name_prefix == "gemm_multi_d":
instance_code += self.populate_cshuffle_gemm_multi_d()
elif self.kernel_name_prefix == "gemm_preshuffle":
instance_code += self.populate_cshuffle_gemm_preshuffle()
else: # default epilogue
if self.kernel_name_prefix == "gemm_universal":
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
instance_code += self.populate_default_gemm_universal()
elif self.kernel_name_prefix == "gemm_multi_d":
instance_code += self.populate_default_gemm_multi_d()