mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4996 (commit 0a47fbe)
[CK TILE ENGINE] Add grouped_gemm operator to Tile Engine (gfx942/gfx950) (#4996) ## Motivation The grouped_gemm CK Tile kernel exists (e.g., `example/17_grouped_gemm/`) but has no Tile Engine wrapper. Grouped GEMM handles multiple independent GEMM problems with varying M/N/K dimensions in a single kernel launch. This PR adds the Tile Engine infrastructure for automated kernel generation, benchmarking, and profiling of grouped GEMM kernels. Jira: AICK-809 ## Technical Details - Created Tile Engine wrapper under `tile_engine/ops/gemm/grouped_gemm/` following the `gemm_universal` template - Files added: `CMakeLists.txt`, `grouped_gemm_common.hpp`, `grouped_gemm_benchmark.hpp`, `grouped_gemm_profiler.hpp`, `grouped_gemm_benchmark.py`, `grouped_gemm_benchmark_single.cpp`, `grouped_gemm_instance_builder.py`, `configs/` - Supported datatypes: fp16, fp8, bf16, bf8 - Supported layouts: rcr, rrr, ccr, crr - Target GPUs: gfx942, gfx950 - CK Tile kernel: `ck_tile::GroupedGemmKernel` from `include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp` - Instance builder extends `GemmKernelBuilder` base class - Registered in `tile_engine/ops/gemm/CMakeLists.txt` - Updated Jenkinsfile to build and benchmark grouped_gemm targets in CI - Benchmark infrastructure includes JSON output, CSV export, and verification support ## Test Plan - CMake configure succeeds for grouped_gemm targets - Kernel instance builder generates valid kernel headers for all (datatype, layout) combinations - At least one kernel binary compiles and runs per datatype/layout combination - Correctness passes with `--verify 1` on gfx942/gfx950 ## Test Result ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
9f47b8a63d
commit
c85c272c39
@@ -168,6 +168,8 @@ class GemmKernelBuilder:
|
||||
default_pipeline = "compv4"
|
||||
elif self.kernel_name_prefix == "gemm_preshuffle":
|
||||
default_pipeline = "preshufflev2"
|
||||
elif self.kernel_name_prefix == "grouped_gemm":
|
||||
default_pipeline = "compv4"
|
||||
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
@@ -335,7 +337,11 @@ class GemmKernelBuilder:
|
||||
|
||||
kernel_name += f"_{tile_str}"
|
||||
|
||||
if self.kernel_name_prefix in ["gemm_universal", "gemm_multi_d"]:
|
||||
if self.kernel_name_prefix in [
|
||||
"gemm_universal",
|
||||
"gemm_multi_d",
|
||||
"grouped_gemm",
|
||||
]:
|
||||
# Map pipeline names to the correct pipeline implementation
|
||||
pipeline_impl_map = {
|
||||
"mem": "ck_tile::GemmPipelineAgBgCrMem",
|
||||
@@ -410,6 +416,11 @@ class GemmKernelBuilder:
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
"""
|
||||
if self.kernel_name_prefix == "grouped_gemm":
|
||||
instance_code += """#include <vector>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
"""
|
||||
return instance_code
|
||||
|
||||
@@ -425,10 +436,11 @@ class GemmKernelBuilder:
|
||||
# Assign layouts based on self.layout
|
||||
if self.kernel_name_prefix == "gemm_multi_d":
|
||||
a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)
|
||||
elif (
|
||||
self.kernel_name_prefix == "gemm_universal"
|
||||
or self.kernel_name_prefix == "gemm_preshuffle"
|
||||
):
|
||||
elif self.kernel_name_prefix in [
|
||||
"gemm_universal",
|
||||
"gemm_preshuffle",
|
||||
"grouped_gemm",
|
||||
]:
|
||||
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
|
||||
|
||||
instance_code = f"""
|
||||
@@ -502,8 +514,12 @@ struct SelectedKernel {{
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2"] else "false"};"""
|
||||
|
||||
if self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
|
||||
instance_code += f"""
|
||||
if self.kernel_name_prefix in [
|
||||
"gemm_universal",
|
||||
"gemm_preshuffle",
|
||||
"grouped_gemm",
|
||||
]:
|
||||
instance_code += f"""
|
||||
static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"};
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;"""
|
||||
@@ -528,9 +544,13 @@ struct SelectedKernel {{
|
||||
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
|
||||
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;"""
|
||||
|
||||
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
|
||||
elif self.kernel_name_prefix in [
|
||||
"gemm_universal",
|
||||
"gemm_preshuffle",
|
||||
"grouped_gemm",
|
||||
]:
|
||||
instance_code = """
|
||||
|
||||
|
||||
// Tile shape
|
||||
using TileShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<TileM, TileN, TileK>,
|
||||
@@ -604,6 +624,13 @@ struct SelectedKernel {{
|
||||
|
||||
// Launch function
|
||||
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {"""
|
||||
elif self.kernel_name_prefix == "grouped_gemm":
|
||||
instance_code = """
|
||||
|
||||
// Launch function
|
||||
static float launch(const std::vector<ck_tile::GroupedGemmHostArgs<>>& gemm_descs,
|
||||
const ck_tile::stream_config& stream,
|
||||
void* kargs_ptr) {"""
|
||||
|
||||
# Scheduler initialization
|
||||
if self.kernel_name_prefix in ["gemm_preshuffle", "gemm_multi_d"]:
|
||||
@@ -644,12 +671,12 @@ struct SelectedKernel {{
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
|
||||
|
||||
# Scheduler initialization
|
||||
if self.kernel_name_prefix in ["gemm_universal"]:
|
||||
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
|
||||
instance_code += f"""
|
||||
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};"""
|
||||
|
||||
# UniversalGemmProblem
|
||||
if self.kernel_name_prefix in ["gemm_universal"]:
|
||||
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
|
||||
instance_code += """
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
@@ -664,7 +691,7 @@ struct SelectedKernel {{
|
||||
scheduler>;"""
|
||||
|
||||
# GemmPipeline
|
||||
if self.kernel_name_prefix in ["gemm_universal"]:
|
||||
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
|
||||
instance_code += f"""
|
||||
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
|
||||
@@ -711,13 +738,13 @@ struct SelectedKernel {{
|
||||
|
||||
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
|
||||
instance_code += f"""
|
||||
|
||||
|
||||
// Kernel type
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
|
||||
// Kernel arguments
|
||||
auto kargs = GemmKernel::MakeKernelArgs(args);
|
||||
|
||||
|
||||
if (!GemmKernel::IsSupportedArgument(kargs)) {{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
@@ -725,7 +752,7 @@ struct SelectedKernel {{
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
|
||||
|
||||
if(stream.log_level_ > 0) {{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
|
||||
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
@@ -733,13 +760,55 @@ struct SelectedKernel {{
|
||||
<< std::endl;
|
||||
}}"""
|
||||
|
||||
instance_code += f"""
|
||||
instance_code += f"""
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
elif self.kernel_name_prefix == "grouped_gemm":
|
||||
instance_code += f"""
|
||||
|
||||
// Kernel type
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Kernel arguments
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs)) {{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping grouped gemm!");
|
||||
}}
|
||||
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = {"Kernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "dim3(kargs.empty() ? 0 : kargs.back().block_end, 1, 1)"};
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
|
||||
if(stream.log_level_ > 0) {{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
||||
<< std::endl;
|
||||
}}
|
||||
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
kargs.size()));
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
}};
|
||||
@@ -753,14 +822,14 @@ struct SelectedKernel {{
|
||||
"""
|
||||
|
||||
if epilogue == "cshuffle":
|
||||
if self.kernel_name_prefix == "gemm_universal":
|
||||
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
|
||||
instance_code += self.populate_cshuffle_gemm_universal()
|
||||
elif self.kernel_name_prefix == "gemm_multi_d":
|
||||
instance_code += self.populate_cshuffle_gemm_multi_d()
|
||||
elif self.kernel_name_prefix == "gemm_preshuffle":
|
||||
instance_code += self.populate_cshuffle_gemm_preshuffle()
|
||||
else: # default epilogue
|
||||
if self.kernel_name_prefix == "gemm_universal":
|
||||
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
|
||||
instance_code += self.populate_default_gemm_universal()
|
||||
elif self.kernel_name_prefix == "gemm_multi_d":
|
||||
instance_code += self.populate_default_gemm_multi_d()
|
||||
|
||||
Reference in New Issue
Block a user