[rocm-libraries] ROCm/rocm-libraries#4996 (commit 0a47fbe)

[CK TILE ENGINE] Add grouped_gemm operator to Tile Engine
 (gfx942/gfx950) (#4996)

## Motivation

The grouped_gemm CK Tile kernel exists (e.g.,
`example/17_grouped_gemm/`) but has no Tile Engine wrapper. Grouped GEMM
handles multiple independent GEMM problems with varying M/N/K dimensions
in a single kernel launch. This PR adds the Tile Engine infrastructure
for automated kernel generation, benchmarking, and profiling of grouped
GEMM kernels.

Jira: AICK-809

## Technical Details

- Created Tile Engine wrapper under `tile_engine/ops/gemm/grouped_gemm/`
following the `gemm_universal` template
- Files added: `CMakeLists.txt`, `grouped_gemm_common.hpp`,
`grouped_gemm_benchmark.hpp`, `grouped_gemm_profiler.hpp`,
`grouped_gemm_benchmark.py`, `grouped_gemm_benchmark_single.cpp`,
`grouped_gemm_instance_builder.py`, `configs/`
- Supported datatypes: fp16, fp8, bf16, bf8
- Supported layouts: rcr, rrr, ccr, crr
- Target GPUs: gfx942, gfx950
- CK Tile kernel: `ck_tile::GroupedGemmKernel` from
`include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp`
- Instance builder extends `GemmKernelBuilder` base class
- Registered in `tile_engine/ops/gemm/CMakeLists.txt`
- Updated Jenkinsfile to build and benchmark grouped_gemm targets in CI
- Benchmark infrastructure includes JSON output, CSV export, and
verification support

## Test Plan

- CMake configure succeeds for grouped_gemm targets
- Kernel instance builder generates valid kernel headers for all
(datatype, layout) combinations
- At least one kernel binary compiles and runs per datatype/layout
combination
- Correctness passes with `--verify 1` on gfx942/gfx950

## Test Result

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2026-03-10 23:59:26 +00:00
committed by assistant-librarian[bot]
parent 9f47b8a63d
commit c85c272c39
13 changed files with 2582 additions and 24 deletions

View File

@@ -3,4 +3,5 @@
add_subdirectory(gemm_universal)
add_subdirectory(gemm_multi_d)
add_subdirectory(gemm_preshuffle)
add_subdirectory(gemm_preshuffle)
add_subdirectory(grouped_gemm)

View File

@@ -168,6 +168,8 @@ class GemmKernelBuilder:
default_pipeline = "compv4"
elif self.kernel_name_prefix == "gemm_preshuffle":
default_pipeline = "preshufflev2"
elif self.kernel_name_prefix == "grouped_gemm":
default_pipeline = "compv4"
configs = []
for tile_m in tile_m_values:
@@ -335,7 +337,11 @@ class GemmKernelBuilder:
kernel_name += f"_{tile_str}"
if self.kernel_name_prefix in ["gemm_universal", "gemm_multi_d"]:
if self.kernel_name_prefix in [
"gemm_universal",
"gemm_multi_d",
"grouped_gemm",
]:
# Map pipeline names to the correct pipeline implementation
pipeline_impl_map = {
"mem": "ck_tile::GemmPipelineAgBgCrMem",
@@ -410,6 +416,11 @@ class GemmKernelBuilder:
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
"""
if self.kernel_name_prefix == "grouped_gemm":
instance_code += """#include <vector>
#include <hip/hip_runtime.h>
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
"""
return instance_code
@@ -425,10 +436,11 @@ class GemmKernelBuilder:
# Assign layouts based on self.layout
if self.kernel_name_prefix == "gemm_multi_d":
a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)
elif (
self.kernel_name_prefix == "gemm_universal"
or self.kernel_name_prefix == "gemm_preshuffle"
):
elif self.kernel_name_prefix in [
"gemm_universal",
"gemm_preshuffle",
"grouped_gemm",
]:
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
instance_code = f"""
@@ -502,8 +514,12 @@ struct SelectedKernel {{
static constexpr bool TransposeC = false;
static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2"] else "false"};"""
if self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
instance_code += f"""
if self.kernel_name_prefix in [
"gemm_universal",
"gemm_preshuffle",
"grouped_gemm",
]:
instance_code += f"""
static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"};
static constexpr bool UseStructuredSparsity = false;
static constexpr ck_tile::index_t NumWaveGroups = 1;"""
@@ -528,9 +544,13 @@ struct SelectedKernel {{
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;"""
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
elif self.kernel_name_prefix in [
"gemm_universal",
"gemm_preshuffle",
"grouped_gemm",
]:
instance_code = """
// Tile shape
using TileShape = ck_tile::TileGemmShape<
ck_tile::sequence<TileM, TileN, TileK>,
@@ -604,6 +624,13 @@ struct SelectedKernel {{
// Launch function
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {"""
elif self.kernel_name_prefix == "grouped_gemm":
instance_code = """
// Launch function
static float launch(const std::vector<ck_tile::GroupedGemmHostArgs<>>& gemm_descs,
const ck_tile::stream_config& stream,
void* kargs_ptr) {"""
# Scheduler initialization
if self.kernel_name_prefix in ["gemm_preshuffle", "gemm_multi_d"]:
@@ -644,12 +671,12 @@ struct SelectedKernel {{
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
# Scheduler initialization
if self.kernel_name_prefix in ["gemm_universal"]:
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
instance_code += f"""
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};"""
# UniversalGemmProblem
if self.kernel_name_prefix in ["gemm_universal"]:
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
instance_code += """
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
@@ -664,7 +691,7 @@ struct SelectedKernel {{
scheduler>;"""
# GemmPipeline
if self.kernel_name_prefix in ["gemm_universal"]:
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
instance_code += f"""
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
@@ -711,13 +738,13 @@ struct SelectedKernel {{
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
instance_code += f"""
// Kernel type
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Kernel arguments
auto kargs = GemmKernel::MakeKernelArgs(args);
if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
@@ -725,7 +752,7 @@ struct SelectedKernel {{
// Get grid and block sizes
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
@@ -733,13 +760,55 @@ struct SelectedKernel {{
<< std::endl;
}}"""
instance_code += f"""
instance_code += f"""
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}}
}};
"""
elif self.kernel_name_prefix == "grouped_gemm":
instance_code += f"""
// Kernel type
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Kernel arguments
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping grouped gemm!");
}}
// Get grid and block sizes
const dim3 grids = {"Kernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "dim3(kargs.empty() ? 0 : kargs.back().block_end, 1, 1)"};
const dim3 blocks = Kernel::BlockSize();
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
hipMemcpyHostToDevice,
stream.stream_id_));
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
kargs.size()));
return ave_time;
}}
}};
@@ -753,14 +822,14 @@ struct SelectedKernel {{
"""
if epilogue == "cshuffle":
if self.kernel_name_prefix == "gemm_universal":
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
instance_code += self.populate_cshuffle_gemm_universal()
elif self.kernel_name_prefix == "gemm_multi_d":
instance_code += self.populate_cshuffle_gemm_multi_d()
elif self.kernel_name_prefix == "gemm_preshuffle":
instance_code += self.populate_cshuffle_gemm_preshuffle()
else: # default epilogue
if self.kernel_name_prefix == "gemm_universal":
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
instance_code += self.populate_default_gemm_universal()
elif self.kernel_name_prefix == "gemm_multi_d":
instance_code += self.populate_default_gemm_multi_d()

View File

@@ -0,0 +1,309 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
set(GROUPED_GEMM_DATATYPE "fp16;fp8;bf16;bf8" CACHE STRING "List of datatypes for Grouped GEMM (semicolon-separated)")
set(GROUPED_GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for Grouped GEMM (semicolon-separated)")
set(GROUPED_GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
option(ENABLE_CCACHE_GROUPED_GEMM "Enable ccache for Grouped GEMM ops compilation" OFF)
# Store the directory path for use in functions
set(GROUPED_GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
# Function to create individual Grouped GEMM targets
function(create_individual_grouped_gemm_target datatype layout trait tile_config config_json)
# Use the parent scope GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL variable
if(NOT GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping individual Grouped GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
return()
endif()
# Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
# First split by underscore to get three groups
string(REPLACE "_" ";" config_groups ${tile_config})
list(GET config_groups 0 tile_dims) # e.g., 256x256x32
list(GET config_groups 1 warp_dims) # e.g., 4x1x1
list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16
# Parse tile dimensions
string(REPLACE "x" ";" tile_parts ${tile_dims})
list(GET tile_parts 0 tile_m)
list(GET tile_parts 1 tile_n)
list(GET tile_parts 2 tile_k)
# Parse warp dimensions
string(REPLACE "x" ";" warp_parts ${warp_dims})
list(GET warp_parts 0 warp_m)
list(GET warp_parts 1 warp_n)
list(GET warp_parts 2 warp_k)
# Parse warp tile dimensions
string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims})
list(GET warp_tile_parts 0 warp_tile_m)
list(GET warp_tile_parts 1 warp_tile_n)
list(GET warp_tile_parts 2 warp_tile_k)
set(target_name "benchmark_grouped_gemm_${datatype}_${layout}_${trait}_${tile_config}")
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Generate the single instance header for this kernel
set(instance_header "${working_path}/grouped_gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
# Add custom command to generate the header file at build time
add_custom_command(
OUTPUT ${instance_header}
COMMAND ${Python3_EXECUTABLE} ${GROUPED_GEMM_SOURCE_DIR}/grouped_gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${config_json}
--gen_single
--kernel_name "grouped_gemm_${datatype}_${layout}_${trait}_${tile_config}"
--tile_config "${tile_config}"
--trait_combo "${trait}"
--gpu_target "${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}"
DEPENDS ${GROUPED_GEMM_SOURCE_DIR}/grouped_gemm_instance_builder.py ${config_json}
COMMENT "Generating ${instance_header}"
)
# Create the executable
add_executable(${target_name}
EXCLUDE_FROM_ALL
${GROUPED_GEMM_SOURCE_DIR}/grouped_gemm_benchmark_single.cpp
${instance_header}
)
# Set GPU architectures
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL})
# Set compile definitions
target_compile_definitions(${target_name} PRIVATE
GROUPED_GEMM_SINGLE_INSTANCE_HPP="${instance_header}"
)
# Include directories
target_include_directories(${target_name} PRIVATE
${GROUPED_GEMM_SOURCE_DIR}
${working_path}
)
# Compile options
target_compile_options(${target_name} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress
-include ${instance_header}
)
# Add to collection targets
add_dependencies(benchmark_grouped_gemm_all ${target_name})
add_dependencies(benchmark_grouped_gemm_${datatype} ${target_name})
add_dependencies(benchmark_grouped_gemm_${layout} ${target_name})
add_dependencies(benchmark_grouped_gemm_${datatype}_${layout} ${target_name})
# Add to trait-specific targets
string(REPLACE "_" ";" trait_parts ${trait})
list(GET trait_parts 0 pipeline)
list(GET trait_parts 1 epilogue)
list(GET trait_parts 2 scheduler)
add_dependencies(benchmark_grouped_gemm_${pipeline}_pipeline ${target_name})
add_dependencies(benchmark_grouped_gemm_${epilogue}_epilogue ${target_name})
add_dependencies(benchmark_grouped_gemm_${scheduler}_scheduler ${target_name})
endfunction()
# Function to build individual Grouped GEMM targets
function(build_individual_grouped_gemm_targets datatype layout)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Choose config file
# Priority order:
# 1. Environment variable GROUPED_GEMM_CONFIG_FILE
# 2. CMake variable GROUPED_GEMM_CONFIG_FILE
# 3. Default based on layout
# Check environment variable first
if(DEFINED ENV{GROUPED_GEMM_CONFIG_FILE} AND NOT "$ENV{GROUPED_GEMM_CONFIG_FILE}" STREQUAL "")
set(config_filename "$ENV{GROUPED_GEMM_CONFIG_FILE}")
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
message(VERBOSE " Using config from environment variable: ${config_filename}")
elseif(NOT "${GROUPED_GEMM_CONFIG_FILE}" STREQUAL "")
# Use CMake variable if set
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GROUPED_GEMM_CONFIG_FILE}")
message(VERBOSE " Using custom config: ${GROUPED_GEMM_CONFIG_FILE}")
else()
# Use default config for all layouts
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
message(VERBOSE " Using default config for layout ${layout}")
endif()
# Check if config file exists
if(NOT EXISTS ${json_blob})
message(FATAL_ERROR "Config file not found: ${json_blob}")
endif()
# Determine number of workers for parallel generation
if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
else()
# Use processor count but limit to avoid memory issues
cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
math(EXPR num_workers "${num_cores}")
if(num_workers GREATER 8)
set(num_workers 8)
endif()
endif()
# Generate individual kernel files using parallel version
message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
message(VERBOSE " Working path: ${working_path}")
message(VERBOSE " Config file: ${json_blob}")
message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/grouped_gemm_instance_builder.py")
# Create working directory first
file(MAKE_DIRECTORY ${working_path})
message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/grouped_gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--gpu_target ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}
--list_kernels ")
# First, just list the kernels (fast operation)
message(VERBOSE " Listing kernel configurations...")
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/grouped_gemm_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--gpu_target ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}
--list_kernels
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output
ERROR_VARIABLE list_error
)
if(NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
endif()
# Read kernel count
if(EXISTS ${working_path}/grouped_gemm_kernel_count.txt)
file(READ ${working_path}/grouped_gemm_kernel_count.txt kernel_count)
string(STRIP "${kernel_count}" kernel_count)
message(VERBOSE " Found ${kernel_count} kernel configurations")
else()
message(FATAL_ERROR "Kernel count file not found")
endif()
# Read kernel list and create targets
if(EXISTS ${working_path}/grouped_gemm_kernel_list.txt)
file(STRINGS ${working_path}/grouped_gemm_kernel_list.txt kernel_lines)
foreach(line IN LISTS kernel_lines)
# Parse line: kernel_name|tile_config|trait_combo
string(REPLACE "|" ";" parts "${line}")
list(GET parts 0 kernel_name)
list(GET parts 1 tile_config)
list(GET parts 2 trait_combo)
# Create individual target
create_individual_grouped_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
endforeach()
else()
message(FATAL_ERROR "Kernel list file not found")
endif()
endfunction()
# Main build logic - Only individual builds supported
message(VERBOSE "=== Starting Tile Engine Grouped GEMM Configuration ===")
message(VERBOSE "GROUPED_GEMM_DATATYPE: ${GROUPED_GEMM_DATATYPE}")
message(VERBOSE "GROUPED_GEMM_LAYOUT: ${GROUPED_GEMM_LAYOUT}")
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# Filter GPU targets to only gfx942, gfx950
set(GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx942;gfx950")
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
if(target IN_LIST DESIRED_TARGETS)
list(APPEND GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL ${target})
message(VERBOSE " Adding GPU target: ${target}")
endif()
endforeach()
# Skip build if no matching targets found
if(NOT GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping Tile Engine Grouped GEMM build: No supported GPU targets (gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
message(VERBOSE "Building individual Grouped GEMM targets for GPU targets: ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}")
# Enable parallel compilation optimizations
# Set up job pools for better parallel compilation control
set_property(GLOBAL PROPERTY JOB_POOLS
compile_heavy=4 # Limit heavy compilations to prevent OOM
compile_normal=16 # Allow more parallel normal compilations
)
# Enable compiler cache if available and explicitly requested
# Disabled by default due to permission issues in CI environments
if(ENABLE_CCACHE_GROUPED_GEMM)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
message(VERBOSE "Using ccache for faster compilation")
else()
message(WARNING "ccache requested but not found")
endif()
else()
message(VERBOSE "ccache disabled for Grouped GEMM ops (use -DENABLE_CCACHE_GROUPED_GEMM=ON to enable)")
endif()
# Create master collection targets
add_custom_target(benchmark_grouped_gemm_all)
# Create datatype collection targets
foreach(dt IN LISTS GROUPED_GEMM_DATATYPE)
add_custom_target(benchmark_grouped_gemm_${dt})
endforeach()
# Create layout collection targets
foreach(l IN LISTS GROUPED_GEMM_LAYOUT)
add_custom_target(benchmark_grouped_gemm_${l})
endforeach()
# Create combined collection targets
foreach(dt IN LISTS GROUPED_GEMM_DATATYPE)
foreach(l IN LISTS GROUPED_GEMM_LAYOUT)
add_custom_target(benchmark_grouped_gemm_${dt}_${l})
endforeach()
endforeach()
# Create trait-based collection targets
# These are common trait components used across all Grouped GEMM kernels
set(GROUPED_GEMM_PIPELINES "mem;compv3;compv4")
set(GROUPED_GEMM_EPILOGUES "default;cshuffle")
set(GROUPED_GEMM_SCHEDULERS "intrawave;interwave")
foreach(pipeline IN LISTS GROUPED_GEMM_PIPELINES)
add_custom_target(benchmark_grouped_gemm_${pipeline}_pipeline)
endforeach()
foreach(epilogue IN LISTS GROUPED_GEMM_EPILOGUES)
add_custom_target(benchmark_grouped_gemm_${epilogue}_epilogue)
endforeach()
foreach(scheduler IN LISTS GROUPED_GEMM_SCHEDULERS)
add_custom_target(benchmark_grouped_gemm_${scheduler}_scheduler)
endforeach()
# Build individual targets for each datatype/layout combination
foreach(dt IN LISTS GROUPED_GEMM_DATATYPE)
foreach(l IN LISTS GROUPED_GEMM_LAYOUT)
build_individual_grouped_gemm_targets(${dt} ${l})
endforeach()
endforeach()
endif()

View File

@@ -0,0 +1,92 @@
{
"tile_config": {
"tile_m": {
"max": 256,
"min": 64,
"step": 256
},
"tile_n": {
"max": 256,
"min": 64,
"step": 256
},
"tile_k": {
"max": 256,
"min": 64,
"step": 256
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
32
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"cshuffle",
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
false,
true
]
}
},
"k_block_per_cu": 1
}

View File

@@ -0,0 +1,104 @@
{
"tile_config": {
"tile_m": {
"max": 256,
"min": 64,
"step": 64
},
"tile_n": {
"max": 256,
"min": 64,
"step": 64
},
"tile_k": {
"max": 256,
"min": 64,
"step": 64
},
"warp_m": {
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
4,
16,
32
]
},
"warp_tile_n": {
"values": [
16,
32,
64
]
},
"warp_tile_k": {
"values": [
8,
16,
32,
64,
128
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"cshuffle",
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
false,
true
]
}
},
"k_block_per_cu": 1
}

View File

@@ -0,0 +1,87 @@
{
"tile_config": {
"tile_m": {
"values": [
128
]
},
"tile_n": {
"values": [
128
]
},
"tile_k": {
"values": [
64
]
},
"warp_m": {
"values": [
4
]
},
"warp_n": {
"values": [
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"mem"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"default"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
true
]
}
},
"k_block_per_cu": 2
}

View File

@@ -0,0 +1,266 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <string>
#include <fstream>
#include <stdexcept>
#include <iomanip>
#include <vector>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm_common.hpp"
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
// Data types and Layouts are defined by the generated kernel headers
// No hardcoded type definitions here to avoid conflicts
enum class Metric
{
LATENCY = 0,
TFLOPS = 1,
BANDWIDTH = 2
};
inline constexpr auto get_metric_name(Metric m)
{
switch(m)
{
case Metric::LATENCY: return "latency";
case Metric::TFLOPS: return "tflops";
case Metric::BANDWIDTH: return "bandwidth";
default: throw std::invalid_argument("Unsupported metric type");
}
}
struct GroupedGemmProblem
{
int group_count_;
int kbatch_;
std::vector<int> Ms_, Ns_, Ks_;
std::vector<int> stride_As_, stride_Bs_, stride_Cs_;
std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_;
std::string layout_a_, layout_b_, layout_c_;
friend std::ostream& operator<<(std::ostream& os, const GroupedGemmProblem& problem)
{
os << "{\n"
<< " \"group_count\":" << problem.group_count_ << ",\n"
<< " \"kbatch\":" << problem.kbatch_ << ",\n"
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
<< " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n"
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
<< " \"layout_c\":\"" << problem.layout_c_ << "\"\n"
<< "}";
return os;
}
};
struct PerformanceResult
{
double latency_;
double tflops_;
double bandwidth_;
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
{
switch(m)
{
case Metric::LATENCY: return a.latency_ < b.latency_;
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
default: throw std::invalid_argument("Unsupported metric type");
}
}
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
{
os << "{\n"
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
<< ",\n"
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
<< "}";
return os;
}
};
struct KernelInstance
{
std::string name_;
GroupedGemmProblem problem_;
PerformanceResult perf_result_;
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
{
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
}
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
{
os << "{\n"
<< " \"name\": \"" << obj.name_ << "\",\n"
<< " \"problem\": " << obj.problem_ << ",\n"
<< " \"perf_result\": " << obj.perf_result_ << "\n"
<< "}";
return os;
}
};
struct Setting
{
int n_warmup_;
int n_repeat_;
bool is_gpu_timer_;
int verify_;
int init_method_;
bool log_;
std::string csv_filename_;
bool flush_cache_;
int rotating_count_;
bool json_output_;
};
inline std::string get_rocm_version()
{
std::ifstream version_file("/opt/rocm/.info/version");
if(version_file.is_open())
{
std::string version;
std::getline(version_file, version);
return version;
}
return "Unknown";
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
/// @brief Function to compare the results of the device and host computations for a single group
bool compare_single(std::string instanceName,
ck_tile::index_t K,
ck_tile::index_t kbatch,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
const float max_accumulated_value =
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_result,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "For " << instanceName << " Relative error threshold is "
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
return pass;
}
/// @brief Function to compare grouped gemm results across all groups
bool compare_grouped(std::string instanceName,
const GroupedGemmProblem& problem,
std::vector<ck_tile::HostTensor<CDataType>>& c_dev_results,
std::vector<ck_tile::HostTensor<CDataType>>& c_host_results)
{
bool pass = true;
for(int i = 0; i < problem.group_count_; ++i)
{
pass &= compare_single(instanceName + "[" + std::to_string(i) + "]",
problem.Ks_[i],
problem.kbatch_,
c_dev_results[i],
c_host_results[i]);
}
return pass;
}
/// @brief Function to get the kernel output with reference implementation on CPU/GPU for all groups
void gemm_host_reference_grouped(int verify,
const GroupedGemmProblem& problem,
std::vector<ck_tile::HostTensor<ADataType>>& a_tensors,
std::vector<ck_tile::HostTensor<BDataType>>& b_tensors,
std::vector<ck_tile::HostTensor<CDataType>>& c_host_results,
std::vector<std::unique_ptr<ck_tile::DeviceMem>>& a_dev_bufs,
std::vector<std::unique_ptr<ck_tile::DeviceMem>>& b_dev_bufs)
{
const int group_count = problem.group_count_;
if(verify == 1)
{
for(int i = 0; i < group_count; ++i)
{
c_host_results[i].SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_tensors[i], b_tensors[i], c_host_results[i]);
}
}
else if(verify == 2)
{
for(int i = 0; i < group_count; ++i)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
b_dev_bufs[i]->ToDevice(b_tensors[i].data());
}
ck_tile::DeviceMem c_gpu_buf_ref(c_host_results[i].get_element_space_size_in_bytes());
c_host_results[i].SetZero();
c_gpu_buf_ref.SetZero();
ADataType* d_A = static_cast<ADataType*>(a_dev_bufs[i]->GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_dev_bufs[i]->GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A,
d_B,
d_C,
problem.Ms_[i],
problem.Ns_[i],
problem.Ks_[i],
problem.stride_As_[i],
problem.stride_Bs_[i],
problem.stride_Cs_[i]);
c_gpu_buf_ref.FromDevice(c_host_results[i].data());
}
}
}
#pragma clang diagnostic pop

View File

@@ -0,0 +1,703 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import sys
import json
import subprocess
import argparse
import csv
import time
from pathlib import Path
from typing import List, Dict, Tuple, Optional
class GroupedGemmBenchmark:
def __init__(self, build_dir: str, verbose: bool = False):
self.build_dir = Path(build_dir)
self.verbose = verbose
self.results = []
def discover_kernels(self) -> List[Path]:
"""Find all benchmark_grouped_gemm_* executables in the build directory"""
bin_dir = self.build_dir / "bin"
if not bin_dir.exists():
print(f"Error: Binary directory {bin_dir} does not exist")
return []
kernels = list(bin_dir.glob("benchmark_grouped_gemm_*"))
if self.verbose:
print(f"Found {len(kernels)} kernel executables")
for k in kernels:
print(f" - {k.name}")
return kernels
def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]:
"""Extract comprehensive kernel information from filename"""
name = kernel_path.stem
# Initialize with basic info
info = {
"executable": str(kernel_path),
"name": name,
"data_type": "unknown",
"layout": "unknown",
"pipeline": "unknown",
"scheduler": "unknown",
"epilogue": "unknown",
}
# Parse the kernel name pattern:
# benchmark_grouped_gemm_fp16_rcr_compv4_cshuffle_intrawave_False_False_False_True_64x64x64_2x2x1_16x16x32
parts = name.split("_")
if len(parts) >= 4:
# Extract data type (4th part after benchmark_grouped_gemm_)
info["data_type"] = parts[3] if len(parts) > 3 else "unknown"
# Extract layout (5th part)
info["layout"] = parts[4] if len(parts) > 4 else "unknown"
# Extract pipeline (6th part)
info["pipeline"] = parts[5] if len(parts) > 5 else "unknown"
# Extract epilogue (7th part)
info["epilogue"] = parts[6] if len(parts) > 6 else "unknown"
# Extract scheduler (8th part)
info["scheduler"] = parts[7] if len(parts) > 7 else "unknown"
# Extract detailed configuration from the end of the name
config_info = self.parse_detailed_config(name)
info.update(config_info)
# Generate config ID
info["config_id"] = self.generate_config_id(info)
return info
def parse_detailed_config(self, kernel_name: str) -> Dict:
"""Parse detailed configuration from kernel name"""
config = {
"tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0},
"warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0},
"warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0},
"optimization_flags": {
"pad_m": False,
"pad_n": False,
"pad_k": False,
"persistent": False,
},
}
# Split by underscore and look for patterns
parts = kernel_name.split("_")
# Look for boolean flags (sequence of True/False values)
bool_sequence = []
for i, part in enumerate(parts):
if part in ["True", "False"]:
bool_sequence.append(part == "True")
# Continue collecting consecutive boolean values
j = i + 1
while j < len(parts) and parts[j] in ["True", "False"]:
bool_sequence.append(parts[j] == "True")
j += 1
break
# Assign boolean flags if we found them
# Order: pad_m, pad_n, pad_k, persistent (4 flags total)
if len(bool_sequence) >= 4:
config["optimization_flags"]["pad_m"] = bool_sequence[0]
config["optimization_flags"]["pad_n"] = bool_sequence[1]
config["optimization_flags"]["pad_k"] = bool_sequence[2]
config["optimization_flags"]["persistent"] = bool_sequence[3]
# Look for tile size patterns (e.g., 256x256x32_2x2x1_4x64x16)
# The pattern is: tile_sizes_warp_config_warp_tile
dimension_groups = []
for part in parts:
if "x" in part and len(part.split("x")) == 3:
try:
dims = [int(x) for x in part.split("x")]
if all(d > 0 for d in dims):
dimension_groups.append(dims)
except ValueError:
continue
# Assign dimensions based on order and magnitude
if len(dimension_groups) >= 3:
# Sort by magnitude to identify: largest=tile_sizes, smallest=warp_config, middle=warp_tile
sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True)
# Largest dimensions = tile sizes
config["tile_sizes"]["tile_m"] = sorted_groups[0][0]
config["tile_sizes"]["tile_n"] = sorted_groups[0][1]
config["tile_sizes"]["tile_k"] = sorted_groups[0][2]
# Smallest dimensions = warp config
config["warp_config"]["warp_m"] = sorted_groups[2][0]
config["warp_config"]["warp_n"] = sorted_groups[2][1]
config["warp_config"]["warp_k"] = sorted_groups[2][2]
# Middle dimensions = warp tile
config["warp_tile"]["warp_tile_m"] = sorted_groups[1][0]
config["warp_tile"]["warp_tile_n"] = sorted_groups[1][1]
config["warp_tile"]["warp_tile_k"] = sorted_groups[1][2]
elif len(dimension_groups) == 2:
# If only 2 groups, assign based on magnitude
sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True)
# Larger = tile sizes
config["tile_sizes"]["tile_m"] = sorted_groups[0][0]
config["tile_sizes"]["tile_n"] = sorted_groups[0][1]
config["tile_sizes"]["tile_k"] = sorted_groups[0][2]
# Smaller = warp config
config["warp_config"]["warp_m"] = sorted_groups[1][0]
config["warp_config"]["warp_n"] = sorted_groups[1][1]
config["warp_config"]["warp_k"] = sorted_groups[1][2]
elif len(dimension_groups) == 1:
# Only one group - assume it's tile sizes
config["tile_sizes"]["tile_m"] = dimension_groups[0][0]
config["tile_sizes"]["tile_n"] = dimension_groups[0][1]
config["tile_sizes"]["tile_k"] = dimension_groups[0][2]
return config
def generate_config_id(self, info: Dict) -> str:
"""Generate a compact config ID from kernel info"""
# Create a compact identifier
parts = [
info.get("data_type", "unk"),
info.get("layout", "unk"),
info.get("pipeline", "unk"),
info.get("scheduler", "unk"),
]
# Add tile configuration if available
tile_sizes = info.get("tile_sizes", {})
if tile_sizes.get("tile_m", 0) > 0:
tile_str = (
f"{tile_sizes['tile_m']}x{tile_sizes['tile_n']}x{tile_sizes['tile_k']}"
)
parts.append(tile_str)
# Add warp config if available
warp_config = info.get("warp_config", {})
if warp_config.get("warp_m", 0) > 0:
warp_str = f"w{warp_config['warp_m']}x{warp_config['warp_n']}x{warp_config['warp_k']}"
parts.append(warp_str)
# Add warp tile if available
warp_tile = info.get("warp_tile", {})
if warp_tile.get("warp_tile_m", 0) > 0:
warp_tile_str = f"wt{warp_tile['warp_tile_m']}x{warp_tile['warp_tile_n']}x{warp_tile['warp_tile_k']}"
parts.append(warp_tile_str)
return "_".join(parts)
def run_kernel(self, kernel_path: Path, params: Dict[str, str]) -> Optional[Dict]:
"""Run a single kernel with given parameters and save output to individual JSON file"""
# Create results directory
results_dir = self.build_dir / "results"
results_dir.mkdir(exist_ok=True)
# Generate unique JSON filename for this kernel
json_file = results_dir / f"{kernel_path.stem}.json"
cmd = [str(kernel_path)]
# Add parameters
for key, value in params.items():
cmd.append(f"-{key}={value}")
# Add JSON output flag for clean JSON output
cmd.append("-json_output=true")
if self.verbose:
print(f"Running: {' '.join(cmd)}")
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
if result.returncode != 0:
print(f"Error running {kernel_path.name}: {result.stderr}")
return None
# Save raw output to individual JSON file
output = result.stdout.strip()
if output:
with open(json_file, "w") as f:
f.write(output)
# Parse the JSON file
return self.parse_json_file(json_file)
else:
print(f"No output from {kernel_path.name}")
return None
except subprocess.TimeoutExpired:
print(f"Timeout running {kernel_path.name}")
return None
except Exception as e:
print(f"Error running {kernel_path.name}: {e}")
return None
def parse_json_file(self, json_file: Path) -> Optional[Dict]:
"""Parse JSON data from individual kernel output file"""
try:
with open(json_file, "r") as f:
content = f.read().strip()
# Parse the JSON directly since executables produce clean JSON
data = json.loads(content)
# Return the complete JSON data as-is, just add some convenience fields
result = data.copy()
if "perf_result" in data:
perf = data["perf_result"]
# Add convenience fields for backward compatibility
result["time_ms"] = perf.get("latency(ms)", 0)
result["tflops"] = perf.get("tflops(TFlops)", 0)
result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0)
return result
except json.JSONDecodeError as e:
if self.verbose:
print(f"Failed to parse JSON from {json_file}: {e}")
return None
except Exception as e:
if self.verbose:
print(f"Error reading JSON file {json_file}: {e}")
return None
def benchmark_problem_size(
self,
kernels: List[Path],
group_count: int = 8,
m: int = 3840,
n: int = 4096,
k: int = 2048,
kbatch: int = 1,
verify: int = 0,
warmup: int = 50,
repeat: int = 100,
flush_cache: bool = True,
rotating_count: int = 1000,
) -> List[Dict]:
"""Benchmark all kernels for a specific problem size"""
results = []
params = {
"group_count": group_count,
"m": m,
"n": n,
"k": k,
"kbatch": kbatch,
"verify": verify,
"warmup": warmup,
"repeat": repeat,
"flush_cache": str(flush_cache).lower(),
"rotating_count": rotating_count,
}
print(
f"\nBenchmarking group_count={group_count}, M={m}, N={n}, K={k}, kbatch={kbatch}"
)
for kernel_path in kernels:
kernel_info = self.extract_kernel_info(kernel_path)
result = self.run_kernel(kernel_path, params)
if result:
# Create new structured result format
structured_result = {
"name": kernel_info["name"],
"config_id": kernel_info["config_id"],
"problem": result.get("problem", {}),
"perf_result": result.get("perf_result", {}),
"config": {
"data_type": kernel_info["data_type"],
"layout": kernel_info["layout"],
"pipeline": kernel_info["pipeline"],
"scheduler": kernel_info["scheduler"],
"epilogue": kernel_info["epilogue"],
"tile_sizes": kernel_info.get("tile_sizes", {}),
"warp_config": kernel_info.get("warp_config", {}),
"warp_tile": kernel_info.get("warp_tile", {}),
"optimization_flags": kernel_info.get("optimization_flags", {}),
},
"executable": kernel_info["executable"],
# Keep backward compatibility fields
"time_ms": result.get("time_ms", 0),
"tflops": result.get("tflops", 0),
"bandwidth_gb_s": result.get("bandwidth_gb_s", 0),
}
results.append(structured_result)
if self.verbose:
print(
f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms"
)
return results
def find_best_kernel(
self, results: List[Dict], metric: str = "tflops"
) -> Optional[Dict]:
"""Find the best performing kernel based on metric"""
if not results:
return None
if metric == "tflops":
return max(results, key=lambda x: x.get("tflops", 0))
elif metric == "time_ms":
return min(results, key=lambda x: x.get("time_ms", float("inf")))
elif metric == "bandwidth_gb_s":
return max(results, key=lambda x: x.get("bandwidth_gb_s", 0))
else:
raise ValueError(f"Unknown metric: {metric}")
def benchmark_sweep(
self,
problem_sizes: List[Tuple[int, int, int]],
group_counts: List[int] = [8],
kbatch_values: List[int] = [1],
verify: bool = False,
warmup: int = 50,
repeat: int = 100,
flush_cache: bool = True,
rotating_count: int = 1000,
) -> Dict:
"""Run comprehensive benchmark sweep"""
kernels = self.discover_kernels()
if not kernels:
print("No kernels found!")
return {}
all_results = []
best_kernels = {}
for m, n, k in problem_sizes:
for group_count in group_counts:
for kbatch in kbatch_values:
results = self.benchmark_problem_size(
kernels,
group_count=group_count,
m=m,
n=n,
k=k,
kbatch=kbatch,
verify=1 if verify else 0,
warmup=warmup,
repeat=repeat,
flush_cache=flush_cache,
rotating_count=rotating_count,
)
all_results.extend(results)
# Find best kernel for this configuration
best = self.find_best_kernel(results)
if best:
key = f"g{group_count}_m{m}_n{n}_k{k}_kbatch{kbatch}"
best_kernels[key] = best
print(
f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, {best['time_ms']:.2f}ms)"
)
self.results = all_results
return best_kernels
def export_csv(self, filename: str):
"""Export all results to CSV"""
if not self.results:
print("No results to export")
return
# Get all unique keys from results
all_keys = set()
for result in self.results:
all_keys.update(result.keys())
# Sort keys for consistent output
fieldnames = sorted(all_keys)
with open(filename, "w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(self.results)
print(f"Results exported to {filename}")
def export_best_kernels(self, best_kernels: Dict, filename: str):
"""Export best kernel selections to file"""
with open(filename, "w") as f:
f.write("# Best kernel selections for grouped GEMM\n")
f.write(
"# Format: problem_size -> kernel_name (TFLOPS, bandwidth, latency)\n\n"
)
for key, kernel in sorted(best_kernels.items()):
f.write(
f"{key}: {kernel['name']} ({kernel['tflops']:.2f} TFLOPS, {kernel['bandwidth_gb_s']:.2f} GB/s, {kernel['time_ms']:.2f}ms)\n"
)
print(f"Best kernels exported to {filename}")
def export_json(self, filename: str, best_kernels: Dict = None):
"""Export all results and best kernels to JSON with comprehensive metadata"""
from datetime import datetime
# Calculate comprehensive summary statistics for all metrics
successful_results = [r for r in self.results if r.get("tflops", 0) > 0]
tflops_values = [r.get("tflops", 0) for r in successful_results]
bandwidth_values = [r.get("bandwidth_gb_s", 0) for r in successful_results]
latency_values = [
r.get("time_ms", 0) for r in successful_results if r.get("time_ms", 0) > 0
]
# Performance breakdown by kernel type
pipeline_stats = {}
scheduler_stats = {}
data_type_stats = {}
for result in successful_results:
# Get config info from the new structure
config = result.get("config", {})
# Pipeline statistics
pipeline = config.get("pipeline", "unknown")
if pipeline not in pipeline_stats:
pipeline_stats[pipeline] = {
"count": 0,
"avg_tflops": 0,
"best_tflops": 0,
}
pipeline_stats[pipeline]["count"] += 1
pipeline_stats[pipeline]["best_tflops"] = max(
pipeline_stats[pipeline]["best_tflops"], result.get("tflops", 0)
)
# Scheduler statistics
scheduler = config.get("scheduler", "unknown")
if scheduler not in scheduler_stats:
scheduler_stats[scheduler] = {
"count": 0,
"avg_tflops": 0,
"best_tflops": 0,
}
scheduler_stats[scheduler]["count"] += 1
scheduler_stats[scheduler]["best_tflops"] = max(
scheduler_stats[scheduler]["best_tflops"], result.get("tflops", 0)
)
# Data type statistics
data_type = config.get("data_type", "unknown")
if data_type not in data_type_stats:
data_type_stats[data_type] = {
"count": 0,
"avg_tflops": 0,
"best_tflops": 0,
}
data_type_stats[data_type]["count"] += 1
data_type_stats[data_type]["best_tflops"] = max(
data_type_stats[data_type]["best_tflops"], result.get("tflops", 0)
)
# Calculate averages for breakdown stats
for stats_dict, field_name in [
(pipeline_stats, "pipeline"),
(scheduler_stats, "scheduler"),
(data_type_stats, "data_type"),
]:
for key in stats_dict:
relevant_results = [
r
for r in successful_results
if r.get("config", {}).get(field_name, "unknown") == key
]
if relevant_results:
stats_dict[key]["avg_tflops"] = sum(
r.get("tflops", 0) for r in relevant_results
) / len(relevant_results)
output_data = {
"benchmark_metadata": {
"timestamp": datetime.now().isoformat(),
"total_kernels_tested": len(self.results),
"unique_kernels": len(
set(r.get("name", "unknown") for r in self.results)
),
"successful_runs": len(successful_results),
"failed_runs": len(self.results) - len(successful_results),
},
"performance_summary": {
"tflops_stats": {
"best": max(tflops_values, default=0),
"average": sum(tflops_values) / len(tflops_values)
if tflops_values
else 0,
"min": min(tflops_values, default=0),
"median": sorted(tflops_values)[len(tflops_values) // 2]
if tflops_values
else 0,
},
"bandwidth_stats": {
"best_gb_s": max(bandwidth_values, default=0),
"average_gb_s": sum(bandwidth_values) / len(bandwidth_values)
if bandwidth_values
else 0,
"min_gb_s": min(bandwidth_values, default=0),
"median_gb_s": sorted(bandwidth_values)[len(bandwidth_values) // 2]
if bandwidth_values
else 0,
},
"latency_stats": {
"best_ms": min(latency_values, default=0),
"average_ms": sum(latency_values) / len(latency_values)
if latency_values
else 0,
"max_ms": max(latency_values, default=0),
"median_ms": sorted(latency_values)[len(latency_values) // 2]
if latency_values
else 0,
},
"kernel_type_breakdown": {
"by_pipeline": pipeline_stats,
"by_scheduler": scheduler_stats,
"by_data_type": data_type_stats,
},
"total_problem_configurations": len(best_kernels)
if best_kernels
else 0,
},
"kernel_results": self.results,
"best_kernels_by_problem": best_kernels or {},
}
with open(filename, "w") as f:
json.dump(output_data, f, indent=2)
print(f"JSON results exported to {filename}")
print(f" - Total kernels: {len(self.results)}")
print(f" - Successful runs: {len(successful_results)}")
print(f" - Best TFLOPS: {max(tflops_values, default=0):.2f}")
print(f" - Best bandwidth: {max(bandwidth_values, default=0):.2f} GB/s")
print(f" - Best latency: {min(latency_values, default=0):.2f}ms")
def main():
parser = argparse.ArgumentParser(
description="Grouped GEMM Kernel Benchmarking Tool"
)
parser.add_argument(
"build_dir", help="Build directory containing kernel executables"
)
parser.add_argument(
"--problem-sizes",
nargs="+",
default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"],
help="Default problem sizes as M,N,K tuples (used for all groups)",
)
parser.add_argument(
"--group-counts",
nargs="+",
type=int,
default=[8],
help="Group count values to test (default: 8)",
)
parser.add_argument(
"--kbatch",
nargs="+",
type=int,
default=[1],
help="K-batch (SplitK) values to test",
)
parser.add_argument("--verify", action="store_true", help="Enable verification")
parser.add_argument(
"--csv",
default="grouped_gemm_benchmark_results.csv",
help="CSV output filename",
)
parser.add_argument(
"--best",
default="best_grouped_gemm_kernels.txt",
help="Best kernels output filename",
)
parser.add_argument("--verbose", action="store_true", help="Verbose output")
parser.add_argument(
"--warmup",
type=int,
default=50,
help="Number of warmup iterations (default: 50)",
)
parser.add_argument(
"--repeat",
type=int,
default=100,
help="Number of benchmark iterations (default: 100)",
)
parser.add_argument(
"--flush-cache",
action="store_true",
default=True,
help="Enable cache flushing (default: True)",
)
parser.add_argument(
"--rotating-count",
type=int,
default=1000,
help="Number of iterations to rotate cache (default: 1000)",
)
parser.add_argument("--json", help="JSON output filename (optional)")
args = parser.parse_args()
# Parse problem sizes
problem_sizes = []
for size_str in args.problem_sizes:
try:
m, n, k = map(int, size_str.split(","))
problem_sizes.append((m, n, k))
except ValueError:
print(f"Invalid problem size: {size_str}")
return 1
# Create benchmark instance
benchmark = GroupedGemmBenchmark(args.build_dir, verbose=args.verbose)
# Run benchmark sweep
print("Starting Grouped GEMM kernel benchmark sweep...")
start_time = time.time()
best_kernels = benchmark.benchmark_sweep(
problem_sizes=problem_sizes,
group_counts=args.group_counts,
kbatch_values=args.kbatch,
verify=args.verify,
warmup=args.warmup,
repeat=args.repeat,
flush_cache=args.flush_cache,
rotating_count=args.rotating_count,
)
elapsed_time = time.time() - start_time
print(f"\nBenchmark completed in {elapsed_time:.2f} seconds")
# Export results
benchmark.export_csv(args.csv)
benchmark.export_best_kernels(best_kernels, args.best)
# Export JSON if requested
if args.json:
benchmark.export_json(args.json, best_kernels)
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,195 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <functional>
#include <tuple>
#include <exception>
#include <sstream>
#include <vector>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm_profiler.hpp"
#include "grouped_gemm_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are now defined in grouped_gemm_common.hpp
// Create argument parser
inline auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser
.insert("m", "3840", "Default M for all groups (if Ms not specified). Default is 3840.")
.insert("n", "4096", "Default N for all groups (if Ns not specified). Default is 4096.")
.insert("k", "2048", "Default K for all groups (if Ks not specified). Default is 2048.")
.insert("Ms", "", "Comma-separated M dimensions per group.")
.insert("Ns", "", "Comma-separated N dimensions per group.")
.insert("Ks", "", "Comma-separated K dimensions per group.")
.insert("stride_As", "", "Comma-separated stride values for tensor A per group.")
.insert("stride_Bs", "", "Comma-separated stride values for tensor B per group.")
.insert("stride_Cs", "", "Comma-separated stride values for tensor C per group.")
.insert("group_count", "8", "Number of groups. Default is 8.")
.insert("kbatch", "1", "SplitK batch count. Default is 1.")
.insert("verify",
"2",
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, "
"2 for validation on GPU. Default is 2, GPU validation.")
.insert("log",
"false",
"Whether output kernel instance information or not. Possible values are true or "
"false. Default is false")
.insert(
"warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.")
.insert(
"repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
.insert("timer",
"true",
"Whether if the timer is gpu timer or not. Possible values are false or true. "
"Default is true.")
.insert("init",
"0",
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
"for constant(1). Default is 0, random.")
.insert("flush_cache",
"true",
"To flush cache, possible values are true or false. "
"Default is true.")
.insert(
"rotating_count", "1000", "number of iterations to rotate the cache. default is 1000.")
.insert("metric",
"0",
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
"tflops, or 2 for bandwidth. Default is 0, latency.")
.insert("csv_filename",
"",
"The filename of benchmark result. Default is empty (no CSV output).")
.insert("json_output",
"false",
"Whether to output results in JSON format only. Possible values are true or false. "
"Default is "
"false");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
// Use DataTypeTraits to get the actual type names from the generated header
std::string dtype_a = DataTypeTraits<ADataType>::name;
std::string dtype_b = DataTypeTraits<BDataType>::name;
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
std::string dtype_c = DataTypeTraits<CDataType>::name;
// Layout names from the layout types
std::string layout_a = ALayout::name;
std::string layout_b = BLayout::name;
std::string layout_c = CLayout::name;
const int group_count = arg_parser.get_int("group_count");
const int kbatch = arg_parser.get_int("kbatch");
// Parse per-group dimensions
std::vector<int> Ms = arg_parser.get_int_vec("Ms");
std::vector<int> Ns = arg_parser.get_int_vec("Ns");
std::vector<int> Ks = arg_parser.get_int_vec("Ks");
std::vector<int> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<int> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<int> stride_Cs = arg_parser.get_int_vec("stride_Cs");
// If Ms/Ns/Ks not provided or wrong size, use -m/-n/-k defaults for all groups
const auto gc_size = static_cast<std::size_t>(group_count);
if(group_count == 0 || Ms.size() != gc_size || Ns.size() != gc_size || Ks.size() != gc_size)
{
const int default_m = arg_parser.get_int("m");
const int default_n = arg_parser.get_int("n");
const int default_k = arg_parser.get_int("k");
Ms.assign(group_count, default_m);
Ns.assign(group_count, default_n);
Ks.assign(group_count, default_k);
}
// Default stride vectors to 0 independently if missing or wrong size
if(stride_As.size() != gc_size)
stride_As.assign(group_count, 0);
if(stride_Bs.size() != gc_size)
stride_Bs.assign(group_count, 0);
if(stride_Cs.size() != gc_size)
stride_Cs.assign(group_count, 0);
// Create GroupedGemmProblem struct
GroupedGemmProblem problem{group_count,
kbatch,
Ms,
Ns,
Ks,
stride_As,
stride_Bs,
stride_Cs,
dtype_a,
dtype_b,
dtype_acc,
dtype_c,
layout_a,
layout_b,
layout_c};
// Create Setting struct
Setting setting{arg_parser.get_int("warmup"),
arg_parser.get_int("repeat"),
arg_parser.get_bool("timer"),
arg_parser.get_int("verify"),
arg_parser.get_int("init"),
arg_parser.get_bool("log"),
arg_parser.get_str("csv_filename"),
arg_parser.get_bool("flush_cache"),
arg_parser.get_int("rotating_count"),
arg_parser.get_bool("json_output")};
// Get the profiler instance
auto& profiler = GroupedGemmProfiler::instance(setting);
try
{
// Create a lambda that wraps the kernel launch
auto kernel_func = [](const std::vector<ck_tile::GroupedGemmHostArgs<>>& descs,
const ck_tile::stream_config& stream,
void* kargs_ptr) {
return SelectedKernel::launch(descs, stream, kargs_ptr);
};
// Benchmark the kernel
profiler.benchmark(problem, kernel_func);
// Select best instance based on metric
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
}
catch(const std::exception& e)
{
std::cerr << "Benchmark failed: " << e.what() << std::endl;
}
}
int main(int argc, char* argv[])
{
try
{
auto [result, parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
benchmark_single(parser);
return 0;
}
catch(const std::exception& e)
{
std::cerr << "Error: " << e.what() << "\n";
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,100 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
//[TODO] This can be moved to commons
// DataTypeTraits for all supported types
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Structure to hold kernel traits for dispatcher
struct KernelTraits
{
std::string pipeline; // compv3, compv4, mem
std::string scheduler; // intrawave, interwave
std::string epilogue; // cshuffle, default
bool pad_m;
bool pad_n;
bool pad_k;
bool persistent;
// Constructor with defaults
KernelTraits()
: pipeline("compv3"),
scheduler("intrawave"),
epilogue("cshuffle"),
pad_m(false),
pad_n(false),
pad_k(false),
persistent(false)
{
}
};

View File

@@ -0,0 +1,303 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import os
import argparse
import importlib.util
import multiprocessing
import concurrent.futures
def _import_gemm_kernel_builder():
"""Import validation utilities from commons directory."""
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
# Load the module dynamically
spec = importlib.util.spec_from_file_location(
"gemm_instance_builder",
os.path.join(parent_dir, "gemm_instance_builder.py"),
)
gemm_builder_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(gemm_builder_module)
return gemm_builder_module.GemmKernelBuilder
GemmKernelBuilder = _import_gemm_kernel_builder()
class GroupedGemmKernelBuilder(GemmKernelBuilder):
def __init__(
self,
kernel_name_prefix,
working_path,
gpu_target,
datatype,
layout,
config_json=None,
):
super().__init__(
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
)
def _generate_all_individual(self, num_workers=None):
"""Generate individual kernel files for separate compilation with parallel processing"""
if num_workers is None:
num_workers = min(
multiprocessing.cpu_count(), 8
) # Limit to avoid memory issues
tile_configs = self._get_tile_configs()
trait_combos = self._generate_trait_combinations()
# Prepare work items for parallel processing
work_items = []
for tile_config in tile_configs:
for trait_combo in trait_combos:
work_items.append(
(
tile_config,
trait_combo,
self.kernel_name_prefix,
self.working_path,
self.gpu_target,
self.datatype,
self.layout,
self.config_json,
)
)
print(
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
)
print(f" Tile configs: {len(tile_configs)}")
print(f" Trait combinations: {len(trait_combos)}")
print(f" Total kernels: {len(work_items)}")
# Show first few work items for debugging
if work_items:
print(" First work item example:")
tile_config, trait_combo = work_items[0][:2]
print(f" Tile config: {tile_config}")
print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits
# Process work items in parallel
kernel_list = []
completed = 0
with concurrent.futures.ProcessPoolExecutor(
max_workers=num_workers
) as executor:
# Submit all work items
print(f" Submitting {len(work_items)} tasks to executor...")
future_to_item = {
executor.submit(_generate_single_kernel_individual, item): item
for item in work_items
}
print(" All tasks submitted, waiting for completion...")
# Collect results with progress reporting
for future in concurrent.futures.as_completed(future_to_item):
completed += 1
if completed % 100 == 0 or completed == len(work_items):
print(
f" Progress: {completed}/{len(work_items)} kernels generated"
)
try:
result = future.result()
if result:
kernel_list.append(result)
except Exception as exc:
item = future_to_item[future]
print(f"Kernel generation failed for {item}: {exc}")
# Sort kernel list for consistent ordering
kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name
# Generate CMake include file for individual targets
self._generate_cmake_individual_targets(kernel_list)
print(
f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
)
def _generate_single_kernel_individual(work_item):
"""Worker function to generate a single individual kernel file"""
(
tile_config,
trait_combo,
kernel_name_prefix,
working_path,
gpu_target,
datatype,
layout,
config_json,
) = work_item
# Create a temporary builder instance for this worker
builder = GroupedGemmKernelBuilder(
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
)
try:
kernel_name, instance_code = builder._generate_kernel_instance(
tile_config, trait_combo
)
# Create simplified filename without the "grouped_gemm_" prefix
simplified_name = kernel_name
if simplified_name.startswith("grouped_gemm_"):
simplified_name = simplified_name[
len(kernel_name_prefix) + 1 :
] # Remove "grouped_gemm" prefix
# Write individual header file
header_file = working_path / f"grouped_gemm_single_{simplified_name}.hpp"
with open(header_file, "w") as f:
f.write(instance_code)
return (kernel_name, trait_combo, tile_config)
except Exception as e:
print(f"Error generating individual kernel: {e}")
return None
def main():
parser = argparse.ArgumentParser(
description="Grouped GEMM kernel instance builder with parallel support"
)
parser.add_argument("--working_path", required=True, help="Working directory path")
parser.add_argument(
"--gpu_target",
required=True,
help="GPU target architecture",
)
parser.add_argument(
"--datatype",
required=True,
choices=["fp16", "fp8", "bf16", "bf8"],
help="Data type",
)
parser.add_argument(
"--layout",
required=True,
choices=["rcr", "rrr", "ccr", "crr"],
help="Matrix layout",
)
parser.add_argument("--config_json", help="Configuration JSON file")
parser.add_argument(
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
)
parser.add_argument(
"--gen_all_individual",
action="store_true",
help="Generate individual kernel files",
)
parser.add_argument(
"--gen_single", action="store_true", help="Generate a single kernel file"
)
parser.add_argument("--kernel_name", help="Kernel name for single generation")
parser.add_argument(
"--tile_config", help="Tile configuration string for single generation"
)
parser.add_argument(
"--trait_combo", help="Trait combination string for single generation"
)
parser.add_argument(
"--list_kernels",
action="store_true",
help="List kernel configurations without generating files",
)
args = parser.parse_args()
if args.datatype not in ["fp16", "bf16", "fp8", "bf8"]:
parser.error(
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
)
layout_str = args.layout.lower()
if len(layout_str) != 3:
parser.error(
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
)
matrix_a_layout = layout_str[0]
matrix_b_layout = layout_str[1]
matrix_c_layout = layout_str[2]
if matrix_a_layout not in ["r", "c"] or matrix_b_layout not in ["r", "c"]:
parser.error(
f"Invalid matrix_a layout : {matrix_a_layout} or matrix_b layout: {matrix_b_layout} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
)
if matrix_c_layout != "r":
parser.error(
f"Invalid matrix_c layout: {matrix_c_layout} (must be 'r' only as currently we are supporting only row major)"
)
kernel_name_prefix = "grouped_gemm"
builder = GroupedGemmKernelBuilder(
kernel_name_prefix,
args.working_path,
args.gpu_target,
args.datatype,
args.layout,
args.config_json,
)
if args.list_kernels:
builder._list_kernels()
elif args.gen_single:
# Generate a single kernel file input validation
if not args.kernel_name or not args.tile_config or not args.trait_combo:
parser.error(
"--gen_single requires --kernel_name, --tile_config, and --trait_combo"
)
# Parse tile config
tile_parts = args.tile_config.split("_")
tile_dims = tile_parts[0].split("x")
warp_dims = tile_parts[1].split("x")
warp_tile_dims = tile_parts[2].split("x")
tile_config = {
"tile_m": int(tile_dims[0]),
"tile_n": int(tile_dims[1]),
"tile_k": int(tile_dims[2]),
"warp_m": int(warp_dims[0]),
"warp_n": int(warp_dims[1]),
"warp_k": int(warp_dims[2]),
"warp_tile_m": int(warp_tile_dims[0]),
"warp_tile_n": int(warp_tile_dims[1]),
"warp_tile_k": int(warp_tile_dims[2]),
}
# Parse trait combo
trait_parts = args.trait_combo.split("_")
trait_combo = (
trait_parts[0], # pipeline
trait_parts[1], # epilogue
trait_parts[2], # scheduler
trait_parts[3] == "True", # pad_m
trait_parts[4] == "True", # pad_n
trait_parts[5] == "True", # pad_k
trait_parts[6] == "True", # persistent
)
# Generate the kernel
builder._generate_kernel_instance(
tile_config,
trait_combo,
)
elif args.gen_all_individual:
# Generate all individual kernel files
builder._generate_all_individual(args.num_workers)
else:
parser.error(
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,326 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <fstream>
#include <iomanip>
#include <memory>
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "grouped_gemm_benchmark.hpp"
class GroupedGemmProfiler
{
public:
static GroupedGemmProfiler& instance(Setting setting)
{
static GroupedGemmProfiler instance{setting};
return instance;
}
// Overload for single kernel benchmarking
void benchmark(GroupedGemmProblem& problem,
std::function<float(const std::vector<ck_tile::GroupedGemmHostArgs<>>&,
const ck_tile::stream_config&,
void*)> kernel_func)
{
// Create a vector with a single callable that returns both name and time
std::vector<std::function<std::tuple<std::string, float>(
std::vector<ck_tile::GroupedGemmHostArgs<>>&, const ck_tile::stream_config&, void*)>>
callables;
callables.push_back([kernel_func](std::vector<ck_tile::GroupedGemmHostArgs<>>& descs,
const ck_tile::stream_config& stream,
void* kargs_ptr) {
float time = kernel_func(descs, stream, kargs_ptr);
return std::make_tuple(std::string(KERNEL_NAME), time);
});
benchmark(problem, callables);
}
void benchmark(
GroupedGemmProblem& problem,
std::vector<std::function<std::tuple<std::string, float>(
std::vector<ck_tile::GroupedGemmHostArgs<>>&, const ck_tile::stream_config&, void*)>>&
callables)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
const CLayout layout_c = CLayout{};
const int group_count = problem.group_count_;
// Compute default strides for each group
for(int i = 0; i < group_count; ++i)
{
problem.stride_As_[i] = ck_tile::get_default_stride(
problem.Ms_[i], problem.Ks_[i], problem.stride_As_[i], is_row_major(layout_a));
problem.stride_Bs_[i] = ck_tile::get_default_stride(
problem.Ks_[i], problem.Ns_[i], problem.stride_Bs_[i], is_row_major(layout_b));
problem.stride_Cs_[i] = ck_tile::get_default_stride(
problem.Ms_[i], problem.Ns_[i], problem.stride_Cs_[i], is_row_major(layout_c));
}
// Create per-group tensors
std::vector<ck_tile::HostTensor<ADataType>> a_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_dev_results;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
c_dev_results.reserve(group_count);
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_dev_bufs;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_dev_bufs;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_dev_bufs;
a_dev_bufs.reserve(group_count);
b_dev_bufs.reserve(group_count);
c_dev_bufs.reserve(group_count);
std::vector<ck_tile::GroupedGemmHostArgs<>> gemm_descs;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = problem.Ms_[i];
const ck_tile::index_t N = problem.Ns_[i];
const ck_tile::index_t K = problem.Ks_[i];
a_tensors.push_back(ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
M, K, problem.stride_As_[i], is_row_major(layout_a))));
b_tensors.push_back(ck_tile::HostTensor<BDataType>(ck_tile::host_tensor_descriptor(
K, N, problem.stride_Bs_[i], is_row_major(layout_b))));
c_dev_results.push_back(ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
M, N, problem.stride_Cs_[i], is_row_major(layout_c))));
if(setting_.init_method_ == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_tensors[i]);
}
else if(setting_.init_method_ == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_tensors[i]);
ck_tile::FillMonotonicSeq<BDataType>{}(b_tensors[i]);
}
else if(setting_.init_method_ == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_tensors[i]);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_tensors[i]);
}
else
{
a_tensors[i].SetZero();
b_tensors[i].SetZero();
}
a_dev_bufs.push_back(std::make_unique<ck_tile::DeviceMem>(
a_tensors[i].get_element_space_size_in_bytes()));
b_dev_bufs.push_back(std::make_unique<ck_tile::DeviceMem>(
b_tensors[i].get_element_space_size_in_bytes()));
c_dev_bufs.push_back(std::make_unique<ck_tile::DeviceMem>(
c_dev_results[i].get_element_space_size_in_bytes()));
a_dev_bufs[i]->ToDevice(a_tensors[i].data());
b_dev_bufs[i]->ToDevice(b_tensors[i].data());
c_dev_bufs[i]->SetZero();
c_dev_results[i].SetZero();
const void* p_a = a_dev_bufs[i]->GetDeviceBuffer();
const void* p_b = b_dev_bufs[i]->GetDeviceBuffer();
void* p_c = c_dev_bufs[i]->GetDeviceBuffer();
gemm_descs.push_back({p_a,
p_b,
{/*ds_ptr*/},
p_c,
problem.kbatch_,
M,
N,
K,
problem.stride_As_[i],
problem.stride_Bs_[i],
{/*stride_Ds*/},
problem.stride_Cs_[i]});
}
// Allocate workspace for kernel args
ck_tile::DeviceMem workspace(gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>));
// Compute host reference for verification
std::vector<ck_tile::HostTensor<CDataType>> c_host_results;
if(setting_.verify_)
{
c_host_results.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
c_host_results.push_back(ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(problem.Ms_[i],
problem.Ns_[i],
problem.stride_Cs_[i],
is_row_major(layout_c))));
}
gemm_host_reference_grouped(setting_.verify_,
problem,
a_tensors,
b_tensors,
c_host_results,
a_dev_bufs,
b_dev_bufs);
}
for(auto& callable : callables)
{
auto kernel_run_result = callable(gemm_descs,
ck_tile::stream_config{nullptr,
true,
setting_.log_,
setting_.n_warmup_,
setting_.n_repeat_,
setting_.is_gpu_timer_,
setting_.flush_cache_,
setting_.rotating_count_},
workspace.GetDeviceBuffer());
process_result(problem, c_dev_bufs, c_host_results, c_dev_results, kernel_run_result);
}
}
void process_result(const GroupedGemmProblem& problem,
std::vector<std::unique_ptr<ck_tile::DeviceMem>>& c_dev_bufs,
std::vector<ck_tile::HostTensor<CDataType>>& c_host_results,
std::vector<ck_tile::HostTensor<CDataType>>& c_dev_results,
const std::tuple<std::string, float>& kernel_run_result)
{
auto [name, avg_time] = kernel_run_result;
KernelInstance kernel_instance{name, problem, {-1.0f, -1.0f, -1.0f}};
// Compute performance metrics (sum across all groups)
std::size_t flop = 0;
std::size_t num_byte = 0;
for(int i = 0; i < problem.group_count_; ++i)
{
flop += std::size_t(2) * problem.Ms_[i] * problem.Ns_[i] * problem.Ks_[i];
num_byte += sizeof(ADataType) * problem.Ms_[i] * problem.Ks_[i] +
sizeof(BDataType) * problem.Ks_[i] * problem.Ns_[i] +
sizeof(CDataType) * problem.Ms_[i] * problem.Ns_[i];
}
// update
kernel_instance.perf_result_.latency_ = avg_time;
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
if(setting_.log_ > 0 && !setting_.json_output_)
{
std::cout << kernel_instance << std::endl;
}
// Copy results back from device and verify per-group
for(int i = 0; i < problem.group_count_; ++i)
{
c_dev_bufs[i]->FromDevice(c_dev_results[i].data());
}
bool verified_correct =
!setting_.verify_ || compare_grouped(name, problem, c_dev_results, c_host_results);
if(verified_correct)
{
kernel_instances_.emplace_back(kernel_instance);
}
else
{
std::cout << "Verification failed, skip kernel: " << name << std::endl;
}
// Clear device tensors
for(int i = 0; i < problem.group_count_; ++i)
{
c_dev_bufs[i]->SetZero();
c_dev_results[i].SetZero();
}
}
KernelInstance select_best_instance(Metric metric)
{
if(kernel_instances_.empty())
throw std::runtime_error("Empty instances");
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
kernel_instances_.end(),
[metric](const auto& a, const auto& b) {
return PerformanceResult::compare(
b.perf_result_, a.perf_result_, metric);
});
if(setting_.json_output_)
{
// Output clean JSON only
std::cout << kernel_instance << std::endl;
}
else
{
std::cout << "**********************************" << std::endl;
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
<< "Current kernel performance is: " << kernel_instance << std::endl;
std::cout << "**********************************" << std::endl;
}
if(!setting_.csv_filename_.empty())
{
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
if(!file.is_open())
{
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
}
else
{
if(file.tellp() == 0)
{
file << "rocm_version,device_name," << "group_count,kbatch,"
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
<< "name," << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
}
const auto& problem = kernel_instance.problem_;
const auto& name = kernel_instance.name_;
const auto& perf = kernel_instance.perf_result_;
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
<< problem.group_count_ << "," << problem.kbatch_ << "," << problem.dtype_a_
<< "," << problem.dtype_b_ << "," << problem.dtype_acc_ << ","
<< problem.dtype_c_ << "," << problem.layout_a_ << "," << problem.layout_b_
<< "," << problem.layout_c_ << "," << name << "," << std::fixed
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
<< "\n";
if(!file)
{
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
}
}
}
return kernel_instance;
}
GroupedGemmProfiler(const GroupedGemmProfiler&) = delete;
GroupedGemmProfiler& operator=(const GroupedGemmProfiler&) = delete;
private:
~GroupedGemmProfiler() { kernel_instances_.clear(); }
GroupedGemmProfiler(Setting setting) : setting_(setting) {}
Setting setting_;
std::vector<KernelInstance> kernel_instances_;
};