mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4996 (commit 0a47fbe)
[CK TILE ENGINE] Add grouped_gemm operator to Tile Engine (gfx942/gfx950) (#4996) ## Motivation The grouped_gemm CK Tile kernel exists (e.g., `example/17_grouped_gemm/`) but has no Tile Engine wrapper. Grouped GEMM handles multiple independent GEMM problems with varying M/N/K dimensions in a single kernel launch. This PR adds the Tile Engine infrastructure for automated kernel generation, benchmarking, and profiling of grouped GEMM kernels. Jira: AICK-809 ## Technical Details - Created Tile Engine wrapper under `tile_engine/ops/gemm/grouped_gemm/` following the `gemm_universal` template - Files added: `CMakeLists.txt`, `grouped_gemm_common.hpp`, `grouped_gemm_benchmark.hpp`, `grouped_gemm_profiler.hpp`, `grouped_gemm_benchmark.py`, `grouped_gemm_benchmark_single.cpp`, `grouped_gemm_instance_builder.py`, `configs/` - Supported datatypes: fp16, fp8, bf16, bf8 - Supported layouts: rcr, rrr, ccr, crr - Target GPUs: gfx942, gfx950 - CK Tile kernel: `ck_tile::GroupedGemmKernel` from `include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp` - Instance builder extends `GemmKernelBuilder` base class - Registered in `tile_engine/ops/gemm/CMakeLists.txt` - Updated Jenkinsfile to build and benchmark grouped_gemm targets in CI - Benchmark infrastructure includes JSON output, CSV export, and verification support ## Test Plan - CMake configure succeeds for grouped_gemm targets - Kernel instance builder generates valid kernel headers for all (datatype, layout) combinations - At least one kernel binary compiles and runs per datatype/layout combination - Correctness passes with `--verify 1` on gfx942/gfx950 ## Test Result ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
9f47b8a63d
commit
c85c272c39
@@ -3,4 +3,5 @@
|
||||
|
||||
add_subdirectory(gemm_universal)
|
||||
add_subdirectory(gemm_multi_d)
|
||||
add_subdirectory(gemm_preshuffle)
|
||||
add_subdirectory(gemm_preshuffle)
|
||||
add_subdirectory(grouped_gemm)
|
||||
@@ -168,6 +168,8 @@ class GemmKernelBuilder:
|
||||
default_pipeline = "compv4"
|
||||
elif self.kernel_name_prefix == "gemm_preshuffle":
|
||||
default_pipeline = "preshufflev2"
|
||||
elif self.kernel_name_prefix == "grouped_gemm":
|
||||
default_pipeline = "compv4"
|
||||
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
@@ -335,7 +337,11 @@ class GemmKernelBuilder:
|
||||
|
||||
kernel_name += f"_{tile_str}"
|
||||
|
||||
if self.kernel_name_prefix in ["gemm_universal", "gemm_multi_d"]:
|
||||
if self.kernel_name_prefix in [
|
||||
"gemm_universal",
|
||||
"gemm_multi_d",
|
||||
"grouped_gemm",
|
||||
]:
|
||||
# Map pipeline names to the correct pipeline implementation
|
||||
pipeline_impl_map = {
|
||||
"mem": "ck_tile::GemmPipelineAgBgCrMem",
|
||||
@@ -410,6 +416,11 @@ class GemmKernelBuilder:
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
"""
|
||||
if self.kernel_name_prefix == "grouped_gemm":
|
||||
instance_code += """#include <vector>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
"""
|
||||
return instance_code
|
||||
|
||||
@@ -425,10 +436,11 @@ class GemmKernelBuilder:
|
||||
# Assign layouts based on self.layout
|
||||
if self.kernel_name_prefix == "gemm_multi_d":
|
||||
a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)
|
||||
elif (
|
||||
self.kernel_name_prefix == "gemm_universal"
|
||||
or self.kernel_name_prefix == "gemm_preshuffle"
|
||||
):
|
||||
elif self.kernel_name_prefix in [
|
||||
"gemm_universal",
|
||||
"gemm_preshuffle",
|
||||
"grouped_gemm",
|
||||
]:
|
||||
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
|
||||
|
||||
instance_code = f"""
|
||||
@@ -502,8 +514,12 @@ struct SelectedKernel {{
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2"] else "false"};"""
|
||||
|
||||
if self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
|
||||
instance_code += f"""
|
||||
if self.kernel_name_prefix in [
|
||||
"gemm_universal",
|
||||
"gemm_preshuffle",
|
||||
"grouped_gemm",
|
||||
]:
|
||||
instance_code += f"""
|
||||
static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"};
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;"""
|
||||
@@ -528,9 +544,13 @@ struct SelectedKernel {{
|
||||
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
|
||||
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;"""
|
||||
|
||||
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
|
||||
elif self.kernel_name_prefix in [
|
||||
"gemm_universal",
|
||||
"gemm_preshuffle",
|
||||
"grouped_gemm",
|
||||
]:
|
||||
instance_code = """
|
||||
|
||||
|
||||
// Tile shape
|
||||
using TileShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<TileM, TileN, TileK>,
|
||||
@@ -604,6 +624,13 @@ struct SelectedKernel {{
|
||||
|
||||
// Launch function
|
||||
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {"""
|
||||
elif self.kernel_name_prefix == "grouped_gemm":
|
||||
instance_code = """
|
||||
|
||||
// Launch function
|
||||
static float launch(const std::vector<ck_tile::GroupedGemmHostArgs<>>& gemm_descs,
|
||||
const ck_tile::stream_config& stream,
|
||||
void* kargs_ptr) {"""
|
||||
|
||||
# Scheduler initialization
|
||||
if self.kernel_name_prefix in ["gemm_preshuffle", "gemm_multi_d"]:
|
||||
@@ -644,12 +671,12 @@ struct SelectedKernel {{
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
|
||||
|
||||
# Scheduler initialization
|
||||
if self.kernel_name_prefix in ["gemm_universal"]:
|
||||
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
|
||||
instance_code += f"""
|
||||
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};"""
|
||||
|
||||
# UniversalGemmProblem
|
||||
if self.kernel_name_prefix in ["gemm_universal"]:
|
||||
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
|
||||
instance_code += """
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
@@ -664,7 +691,7 @@ struct SelectedKernel {{
|
||||
scheduler>;"""
|
||||
|
||||
# GemmPipeline
|
||||
if self.kernel_name_prefix in ["gemm_universal"]:
|
||||
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
|
||||
instance_code += f"""
|
||||
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
|
||||
@@ -711,13 +738,13 @@ struct SelectedKernel {{
|
||||
|
||||
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
|
||||
instance_code += f"""
|
||||
|
||||
|
||||
// Kernel type
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
|
||||
// Kernel arguments
|
||||
auto kargs = GemmKernel::MakeKernelArgs(args);
|
||||
|
||||
|
||||
if (!GemmKernel::IsSupportedArgument(kargs)) {{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
@@ -725,7 +752,7 @@ struct SelectedKernel {{
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
|
||||
|
||||
if(stream.log_level_ > 0) {{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
|
||||
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
@@ -733,13 +760,55 @@ struct SelectedKernel {{
|
||||
<< std::endl;
|
||||
}}"""
|
||||
|
||||
instance_code += f"""
|
||||
instance_code += f"""
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
elif self.kernel_name_prefix == "grouped_gemm":
|
||||
instance_code += f"""
|
||||
|
||||
// Kernel type
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Kernel arguments
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs)) {{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping grouped gemm!");
|
||||
}}
|
||||
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = {"Kernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "dim3(kargs.empty() ? 0 : kargs.back().block_end, 1, 1)"};
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
|
||||
if(stream.log_level_ > 0) {{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
||||
<< std::endl;
|
||||
}}
|
||||
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
kargs.size()));
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
}};
|
||||
@@ -753,14 +822,14 @@ struct SelectedKernel {{
|
||||
"""
|
||||
|
||||
if epilogue == "cshuffle":
|
||||
if self.kernel_name_prefix == "gemm_universal":
|
||||
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
|
||||
instance_code += self.populate_cshuffle_gemm_universal()
|
||||
elif self.kernel_name_prefix == "gemm_multi_d":
|
||||
instance_code += self.populate_cshuffle_gemm_multi_d()
|
||||
elif self.kernel_name_prefix == "gemm_preshuffle":
|
||||
instance_code += self.populate_cshuffle_gemm_preshuffle()
|
||||
else: # default epilogue
|
||||
if self.kernel_name_prefix == "gemm_universal":
|
||||
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm"]:
|
||||
instance_code += self.populate_default_gemm_universal()
|
||||
elif self.kernel_name_prefix == "gemm_multi_d":
|
||||
instance_code += self.populate_default_gemm_multi_d()
|
||||
|
||||
309
tile_engine/ops/gemm/grouped_gemm/CMakeLists.txt
Normal file
309
tile_engine/ops/gemm/grouped_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,309 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# User-tunable cache entries for the Grouped GEMM tile-engine build.
# Each list is semicolon-separated and can be overridden with -D<NAME>=... .
set(GROUPED_GEMM_DATATYPE
    "fp16;fp8;bf16;bf8"
    CACHE STRING "List of datatypes for Grouped GEMM (semicolon-separated)")
set(GROUPED_GEMM_LAYOUT
    "rcr;rrr;crr;ccr"
    CACHE STRING "List of layout for Grouped GEMM (semicolon-separated)")
set(GROUPED_GEMM_CONFIG_FILE
    ""
    CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
option(ENABLE_CCACHE_GROUPED_GEMM "Enable ccache for Grouped GEMM ops compilation" OFF)

# Remember this directory so the functions below can locate the builder script
# and the benchmark source.
set(GROUPED_GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# Create the benchmark executable for one generated Grouped GEMM kernel instance.
#
# Arguments:
#   datatype    - datatype token; used in names and forwarded as --datatype
#   layout      - layout token; used in names and forwarded as --layout
#   trait       - underscore-separated trait combo; fields 0, 1 and 2 are read as
#                 pipeline, epilogue and scheduler
#   tile_config - three underscore-separated groups of "AxBxC" dimensions,
#                 e.g. 256x256x32_4x1x1_16x16x16
#   config_json - JSON config path forwarded to the instance builder
#
# Reads GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL and GROUPED_GEMM_SOURCE_DIR from the
# calling scope. The benchmark_grouped_gemm_all, _<datatype>, _<layout> and
# _<datatype>_<layout> collection targets must already exist.
function(create_individual_grouped_gemm_target datatype layout trait tile_config config_json)
    # Use the parent scope GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL variable
    if(NOT GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL)
        message(WARNING "Skipping individual Grouped GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
        return()
    endif()

    # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
    # First split by underscore to get three groups. Expansions are quoted so an
    # empty value is reported by list(GET) instead of a malformed string() call.
    string(REPLACE "_" ";" config_groups "${tile_config}")
    list(GET config_groups 0 tile_dims)      # e.g., 256x256x32
    list(GET config_groups 1 warp_dims)      # e.g., 4x1x1
    list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16

    # The individual dimensions below are not referenced again in this function;
    # extracting them makes a malformed tile_config fail at configure time.

    # Parse tile dimensions
    string(REPLACE "x" ";" tile_parts "${tile_dims}")
    list(GET tile_parts 0 tile_m)
    list(GET tile_parts 1 tile_n)
    list(GET tile_parts 2 tile_k)

    # Parse warp dimensions
    string(REPLACE "x" ";" warp_parts "${warp_dims}")
    list(GET warp_parts 0 warp_m)
    list(GET warp_parts 1 warp_n)
    list(GET warp_parts 2 warp_k)

    # Parse warp tile dimensions
    string(REPLACE "x" ";" warp_tile_parts "${warp_tile_dims}")
    list(GET warp_tile_parts 0 warp_tile_m)
    list(GET warp_tile_parts 1 warp_tile_n)
    list(GET warp_tile_parts 2 warp_tile_k)

    set(target_name "benchmark_grouped_gemm_${datatype}_${layout}_${trait}_${tile_config}")
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")

    # Generate the single instance header for this kernel
    set(instance_header "${working_path}/grouped_gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")

    # Add custom command to generate the header file at build time
    add_custom_command(
        OUTPUT ${instance_header}
        COMMAND ${Python3_EXECUTABLE} ${GROUPED_GEMM_SOURCE_DIR}/grouped_gemm_instance_builder.py
                --working_path ${working_path}
                --datatype ${datatype}
                --layout ${layout}
                --config_json ${config_json}
                --gen_single
                --kernel_name "grouped_gemm_${datatype}_${layout}_${trait}_${tile_config}"
                --tile_config "${tile_config}"
                --trait_combo "${trait}"
                --gpu_target "${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}"
        DEPENDS ${GROUPED_GEMM_SOURCE_DIR}/grouped_gemm_instance_builder.py ${config_json}
        COMMENT "Generating ${instance_header}"
    )

    # Create the executable
    add_executable(${target_name}
        EXCLUDE_FROM_ALL
        ${GROUPED_GEMM_SOURCE_DIR}/grouped_gemm_benchmark_single.cpp
        ${instance_header}
    )

    # Set GPU architectures
    set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL})

    # Set compile definitions
    target_compile_definitions(${target_name} PRIVATE
        GROUPED_GEMM_SINGLE_INSTANCE_HPP="${instance_header}"
    )

    # Include directories
    target_include_directories(${target_name} PRIVATE
        ${GROUPED_GEMM_SOURCE_DIR}
        ${working_path}
    )

    # Compile options
    target_compile_options(${target_name} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
        -include ${instance_header}
    )

    # Add to collection targets
    add_dependencies(benchmark_grouped_gemm_all ${target_name})
    add_dependencies(benchmark_grouped_gemm_${datatype} ${target_name})
    add_dependencies(benchmark_grouped_gemm_${layout} ${target_name})
    add_dependencies(benchmark_grouped_gemm_${datatype}_${layout} ${target_name})

    # Add to trait-specific targets
    string(REPLACE "_" ";" trait_parts "${trait}")
    list(GET trait_parts 0 pipeline)
    list(GET trait_parts 1 epilogue)
    list(GET trait_parts 2 scheduler)

    # The trait collection targets come from fixed lists, so a config may name a
    # trait value that has no collection target. Skip those rather than aborting
    # the whole configure step; the executable is still reachable through the
    # datatype/layout collections above.
    foreach(trait_target IN ITEMS
            benchmark_grouped_gemm_${pipeline}_pipeline
            benchmark_grouped_gemm_${epilogue}_epilogue
            benchmark_grouped_gemm_${scheduler}_scheduler)
        if(TARGET ${trait_target})
            add_dependencies(${trait_target} ${target_name})
        else()
            message(VERBOSE " No collection target ${trait_target}; ${target_name} not added to it")
        endif()
    endforeach()
endfunction()
|
||||
|
||||
# Enumerate the kernel instances for one (datatype, layout) pair and create a
# benchmark target for each of them.
#
# Runs grouped_gemm_instance_builder.py with --list_kernels at configure time,
# then reads grouped_gemm_kernel_count.txt and grouped_gemm_kernel_list.txt from
# the working directory. Each list line has the form
# kernel_name|tile_config|trait_combo.
#
# Arguments:
#   datatype - datatype token forwarded to the builder
#   layout   - layout token forwarded to the builder
function(build_individual_grouped_gemm_targets datatype layout)
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
    set(builder_script "${CMAKE_CURRENT_LIST_DIR}/grouped_gemm_instance_builder.py")

    # Choose config file
    # Priority order:
    # 1. Environment variable GROUPED_GEMM_CONFIG_FILE
    # 2. CMake variable GROUPED_GEMM_CONFIG_FILE
    # 3. Default config

    # Check environment variable first
    if(DEFINED ENV{GROUPED_GEMM_CONFIG_FILE} AND NOT "$ENV{GROUPED_GEMM_CONFIG_FILE}" STREQUAL "")
        set(config_filename "$ENV{GROUPED_GEMM_CONFIG_FILE}")
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
        message(VERBOSE " Using config from environment variable: ${config_filename}")
    elseif(NOT "${GROUPED_GEMM_CONFIG_FILE}" STREQUAL "")
        # Use CMake variable if set
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GROUPED_GEMM_CONFIG_FILE}")
        message(VERBOSE " Using custom config: ${GROUPED_GEMM_CONFIG_FILE}")
    else()
        # Use default config for all layouts
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
        message(VERBOSE " Using default config for layout ${layout}")
    endif()

    # Check if config file exists
    if(NOT EXISTS "${json_blob}")
        message(FATAL_ERROR "Config file not found: ${json_blob}")
    endif()

    # The set of targets is derived from these two files at configure time, so
    # re-run configuration when either changes; otherwise the kernel list and
    # the generated targets go stale.
    set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
        "${json_blob}"
        "${builder_script}")

    # Determine number of workers for parallel generation
    # (only reported in the log below; it is not passed to the builder here)
    if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
        set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
    else()
        # Use processor count but limit to avoid memory issues
        cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
        math(EXPR num_workers "${num_cores}")
        if(num_workers GREATER 8)
            set(num_workers 8)
        endif()
    endif()

    # Generate individual kernel files using parallel version
    message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
    message(VERBOSE " Working path: ${working_path}")
    message(VERBOSE " Config file: ${json_blob}")
    message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
    message(VERBOSE " Script path: ${builder_script}")

    # Create working directory first
    file(MAKE_DIRECTORY "${working_path}")

    message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${builder_script}
        --working_path ${working_path}
        --datatype ${datatype}
        --layout ${layout}
        --config_json ${json_blob}
        --gpu_target ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}
        --list_kernels ")

    # First, just list the kernels (fast operation)
    message(VERBOSE " Listing kernel configurations...")
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u ${builder_script}
                --working_path ${working_path}
                --datatype ${datatype}
                --layout ${layout}
                --config_json ${json_blob}
                --gpu_target ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}
                --list_kernels
        WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
        RESULT_VARIABLE ret
        OUTPUT_VARIABLE list_output
        ERROR_VARIABLE list_error
    )

    if(NOT ret EQUAL 0)
        message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
    endif()

    # Read kernel count
    if(EXISTS "${working_path}/grouped_gemm_kernel_count.txt")
        file(READ "${working_path}/grouped_gemm_kernel_count.txt" kernel_count)
        string(STRIP "${kernel_count}" kernel_count)
        message(VERBOSE " Found ${kernel_count} kernel configurations")
    else()
        message(FATAL_ERROR "Kernel count file not found")
    endif()

    # Read kernel list and create targets
    if(EXISTS "${working_path}/grouped_gemm_kernel_list.txt")
        file(STRINGS "${working_path}/grouped_gemm_kernel_list.txt" kernel_lines)
        foreach(line IN LISTS kernel_lines)
            # Parse line: kernel_name|tile_config|trait_combo
            string(REPLACE "|" ";" parts "${line}")

            # Report a truncated entry explicitly instead of letting list(GET)
            # fail with an out-of-range index.
            list(LENGTH parts part_count)
            if(part_count LESS 3)
                message(FATAL_ERROR "Malformed kernel list entry (expected kernel_name|tile_config|trait_combo): ${line}")
            endif()

            list(GET parts 0 kernel_name)
            list(GET parts 1 tile_config)
            list(GET parts 2 trait_combo)

            # Create individual target
            create_individual_grouped_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
        endforeach()
    else()
        message(FATAL_ERROR "Kernel list file not found")
    endif()
endfunction()
|
||||
|
||||
# Top-level driver: only per-kernel ("individual") builds are supported.
message(VERBOSE "=== Starting Tile Engine Grouped GEMM Configuration ===")
message(VERBOSE "GROUPED_GEMM_DATATYPE: ${GROUPED_GEMM_DATATYPE}")
message(VERBOSE "GROUPED_GEMM_LAYOUT: ${GROUPED_GEMM_LAYOUT}")
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")

# Keep only the architectures this operator is built for (gfx942, gfx950).
set(GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx942;gfx950")

foreach(arch IN LISTS SUPPORTED_GPU_TARGETS)
    if(arch IN_LIST DESIRED_TARGETS)
        list(APPEND GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL ${arch})
        message(VERBOSE " Adding GPU target: ${arch}")
    endif()
endforeach()

if(GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL)
    message(VERBOSE "Building individual Grouped GEMM targets for GPU targets: ${GROUPED_GEMM_GPU_TARGETS_INDIVIDUAL}")

    # Job pools that bound how many compilations may run concurrently.
    set_property(GLOBAL PROPERTY JOB_POOLS
        compile_heavy=4   # Limit heavy compilations to prevent OOM
        compile_normal=16 # Allow more parallel normal compilations
    )

    # ccache is opt-in; it is off by default due to permission issues in CI
    # environments.
    if(NOT ENABLE_CCACHE_GROUPED_GEMM)
        message(VERBOSE "ccache disabled for Grouped GEMM ops (use -DENABLE_CCACHE_GROUPED_GEMM=ON to enable)")
    else()
        find_program(CCACHE_PROGRAM ccache)
        if(NOT CCACHE_PROGRAM)
            message(WARNING "ccache requested but not found")
        else()
            set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
            message(VERBOSE "Using ccache for faster compilation")
        endif()
    endif()

    # Master collection target.
    add_custom_target(benchmark_grouped_gemm_all)

    # Per-datatype and per-(datatype, layout) collection targets.
    foreach(dtype_name IN LISTS GROUPED_GEMM_DATATYPE)
        add_custom_target(benchmark_grouped_gemm_${dtype_name})
        foreach(layout_name IN LISTS GROUPED_GEMM_LAYOUT)
            add_custom_target(benchmark_grouped_gemm_${dtype_name}_${layout_name})
        endforeach()
    endforeach()

    # Per-layout collection targets.
    foreach(layout_name IN LISTS GROUPED_GEMM_LAYOUT)
        add_custom_target(benchmark_grouped_gemm_${layout_name})
    endforeach()

    # Trait-based collection targets: the trait components shared by all
    # Grouped GEMM kernels.
    set(GROUPED_GEMM_PIPELINES "mem;compv3;compv4")
    set(GROUPED_GEMM_EPILOGUES "default;cshuffle")
    set(GROUPED_GEMM_SCHEDULERS "intrawave;interwave")

    foreach(pipeline_name IN LISTS GROUPED_GEMM_PIPELINES)
        add_custom_target(benchmark_grouped_gemm_${pipeline_name}_pipeline)
    endforeach()
    foreach(epilogue_name IN LISTS GROUPED_GEMM_EPILOGUES)
        add_custom_target(benchmark_grouped_gemm_${epilogue_name}_epilogue)
    endforeach()
    foreach(scheduler_name IN LISTS GROUPED_GEMM_SCHEDULERS)
        add_custom_target(benchmark_grouped_gemm_${scheduler_name}_scheduler)
    endforeach()

    # Finally create the per-kernel targets for every datatype/layout pair.
    foreach(dtype_name IN LISTS GROUPED_GEMM_DATATYPE)
        foreach(layout_name IN LISTS GROUPED_GEMM_LAYOUT)
            build_individual_grouped_gemm_targets(${dtype_name} ${layout_name})
        endforeach()
    endforeach()
else()
    # Nothing to build when none of the supported architectures was requested.
    message(WARNING "Skipping Tile Engine Grouped GEMM build: No supported GPU targets (gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
endif()
|
||||
@@ -0,0 +1,92 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 256
|
||||
},
|
||||
"tile_n": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 256
|
||||
},
|
||||
"tile_k": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 256
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle",
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 1
|
||||
}
|
||||
104
tile_engine/ops/gemm/grouped_gemm/configs/default_config.json
Normal file
104
tile_engine/ops/gemm/grouped_gemm/configs/default_config.json
Normal file
@@ -0,0 +1,104 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_n": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_k": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
4,
|
||||
16,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16,
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle",
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 1
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
true
|
||||
]
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 2
|
||||
}
|
||||
266
tile_engine/ops/gemm/grouped_gemm/grouped_gemm_benchmark.hpp
Normal file
266
tile_engine/ops/gemm/grouped_gemm/grouped_gemm_benchmark.hpp
Normal file
@@ -0,0 +1,266 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm_common.hpp"
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
|
||||
// Data types and Layouts are defined by the generated kernel headers
|
||||
// No hardcoded type definitions here to avoid conflicts
|
||||
|
||||
/// Criterion by which benchmark results are ranked.
/// Underlying values are 0, 1 and 2 in declaration order.
enum class Metric
{
    LATENCY,  // 0: smaller latency ranks first
    TFLOPS,   // 1: larger throughput ranks first
    BANDWIDTH // 2: larger bandwidth ranks first
};
|
||||
|
||||
inline constexpr auto get_metric_name(Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return "latency";
|
||||
case Metric::TFLOPS: return "tflops";
|
||||
case Metric::BANDWIDTH: return "bandwidth";
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
/// Description of one grouped-GEMM benchmark problem.
///
/// Streaming an instance writes a JSON-style object with group_count, kbatch,
/// the four dtype strings and the three layout strings. The size and stride
/// vectors are not part of that output.
struct GroupedGemmProblem
{
    // Scalars are zero-initialized so a default-constructed problem never
    // exposes indeterminate values (e.g. when streamed before being filled in).
    int group_count_ = 0;
    int kbatch_      = 0;
    // presumably one entry per group -- confirm against the code that fills them
    std::vector<int> Ms_, Ns_, Ks_;
    std::vector<int> stride_As_, stride_Bs_, stride_Cs_;

    std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_;
    std::string layout_a_, layout_b_, layout_c_;

    /// Writes @p problem to @p os as a JSON-style object and returns @p os.
    /// String fields are written verbatim (no escaping is applied).
    friend std::ostream& operator<<(std::ostream& os, const GroupedGemmProblem& problem)
    {
        os << "{\n"
           << " \"group_count\":" << problem.group_count_ << ",\n"
           << " \"kbatch\":" << problem.kbatch_ << ",\n"
           << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
           << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
           << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
           << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n"
           << " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
           << " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
           << " \"layout_c\":\"" << problem.layout_c_ << "\"\n"
           << "}";
        return os;
    }
};
|
||||
|
||||
struct PerformanceResult
|
||||
{
|
||||
double latency_;
|
||||
double tflops_;
|
||||
double bandwidth_;
|
||||
|
||||
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return a.latency_ < b.latency_;
|
||||
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
|
||||
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
|
||||
<< ",\n"
|
||||
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
|
||||
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct KernelInstance
|
||||
{
|
||||
std::string name_;
|
||||
GroupedGemmProblem problem_;
|
||||
PerformanceResult perf_result_;
|
||||
|
||||
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
|
||||
{
|
||||
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"name\": \"" << obj.name_ << "\",\n"
|
||||
<< " \"problem\": " << obj.problem_ << ",\n"
|
||||
<< " \"perf_result\": " << obj.perf_result_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
/// Benchmark run settings.
/// Every scalar member is value-initialised so a default-constructed Setting is fully
/// defined; aggregate initialisation with explicit values keeps working (C++14 and later).
// NOTE(review): member meanings below are inferred from the names — confirm against the
// code that fills this struct.
struct Setting
{
    int n_warmup_{0};          // presumably warm-up iterations before timing
    int n_repeat_{0};          // presumably timed iterations
    bool is_gpu_timer_{false}; // presumably selects the GPU timer
    int verify_{0};            // presumably the verification mode
    int init_method_{0};       // presumably the tensor initialisation method
    bool log_{false};
    std::string csv_filename_;
    bool flush_cache_{false};
    int rotating_count_{0};
    bool json_output_{false};
};
|
||||
|
||||
/// @brief Reads the first line of /opt/rocm/.info/version.
/// @return That line, or "Unknown" when the file cannot be opened, cannot be read,
///         or its first line is empty.
inline std::string get_rocm_version()
{
    std::ifstream version_file("/opt/rocm/.info/version");
    std::string version;
    // Checking the getline result and emptiness keeps callers from receiving "" when the
    // file exists but holds no usable first line.
    if(version_file.is_open() && std::getline(version_file, version) && !version.empty())
    {
        return version;
    }
    return "Unknown";
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations for a single group
|
||||
bool compare_single(std::string instanceName,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "For " << instanceName << " Relative error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
/// @brief Function to compare grouped gemm results across all groups
|
||||
bool compare_grouped(std::string instanceName,
|
||||
const GroupedGemmProblem& problem,
|
||||
std::vector<ck_tile::HostTensor<CDataType>>& c_dev_results,
|
||||
std::vector<ck_tile::HostTensor<CDataType>>& c_host_results)
|
||||
{
|
||||
bool pass = true;
|
||||
for(int i = 0; i < problem.group_count_; ++i)
|
||||
{
|
||||
pass &= compare_single(instanceName + "[" + std::to_string(i) + "]",
|
||||
problem.Ks_[i],
|
||||
problem.kbatch_,
|
||||
c_dev_results[i],
|
||||
c_host_results[i]);
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
/// @brief Function to get the kernel output with reference implementation on CPU/GPU for all groups
|
||||
void gemm_host_reference_grouped(int verify,
|
||||
const GroupedGemmProblem& problem,
|
||||
std::vector<ck_tile::HostTensor<ADataType>>& a_tensors,
|
||||
std::vector<ck_tile::HostTensor<BDataType>>& b_tensors,
|
||||
std::vector<ck_tile::HostTensor<CDataType>>& c_host_results,
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>>& a_dev_bufs,
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>>& b_dev_bufs)
|
||||
{
|
||||
const int group_count = problem.group_count_;
|
||||
|
||||
if(verify == 1)
|
||||
{
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
c_host_results[i].SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_tensors[i], b_tensors[i], c_host_results[i]);
|
||||
}
|
||||
}
|
||||
else if(verify == 2)
|
||||
{
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
b_dev_bufs[i]->ToDevice(b_tensors[i].data());
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem c_gpu_buf_ref(c_host_results[i].get_element_space_size_in_bytes());
|
||||
c_host_results[i].SetZero();
|
||||
c_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_dev_bufs[i]->GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_dev_bufs[i]->GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A,
|
||||
d_B,
|
||||
d_C,
|
||||
problem.Ms_[i],
|
||||
problem.Ns_[i],
|
||||
problem.Ks_[i],
|
||||
problem.stride_As_[i],
|
||||
problem.stride_Bs_[i],
|
||||
problem.stride_Cs_[i]);
|
||||
|
||||
c_gpu_buf_ref.FromDevice(c_host_results[i].data());
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma clang diagnostic pop
|
||||
703
tile_engine/ops/gemm/grouped_gemm/grouped_gemm_benchmark.py
Normal file
703
tile_engine/ops/gemm/grouped_gemm/grouped_gemm_benchmark.py
Normal file
@@ -0,0 +1,703 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import sys
|
||||
import json
|
||||
import subprocess
|
||||
import argparse
|
||||
import csv
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
|
||||
|
||||
class GroupedGemmBenchmark:
    """Runs the ``benchmark_grouped_gemm_*`` executables found in ``<build_dir>/bin``.

    Each executable is launched once per problem configuration; the JSON it prints is
    stored under ``<build_dir>/results`` and collected in ``self.results`` for ranking
    and export (CSV, text summary, JSON).
    """

    def __init__(self, build_dir: str, verbose: bool = False):
        self.build_dir = Path(build_dir)
        self.verbose = verbose
        # Flat list of structured result dicts, replaced by every benchmark_sweep() call.
        self.results = []

    def discover_kernels(self) -> List[Path]:
        """Find all benchmark_grouped_gemm_* executables in the build directory.

        Returns the matching files in sorted order, or an empty list when
        ``<build_dir>/bin`` does not exist.
        """
        bin_dir = self.build_dir / "bin"
        if not bin_dir.exists():
            print(f"Error: Binary directory {bin_dir} does not exist")
            return []

        # glob() order is filesystem dependent; sorting makes sweep order (and therefore
        # tie-breaking between equally fast kernels) reproducible. Directories that
        # happen to match the pattern cannot be executed, so they are skipped.
        kernels = sorted(
            path for path in bin_dir.glob("benchmark_grouped_gemm_*") if path.is_file()
        )
        if self.verbose:
            print(f"Found {len(kernels)} kernel executables")
            for k in kernels:
                print(f" - {k.name}")
        return kernels

    def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]:
        """Extract comprehensive kernel information from the executable's filename.

        Expected name pattern:
        benchmark_grouped_gemm_fp16_rcr_compv4_cshuffle_intrawave_False_False_False_True_64x64x64_2x2x1_16x16x32
        Fields that cannot be located stay "unknown".
        """
        name = kernel_path.stem

        info = {
            "executable": str(kernel_path),
            "name": name,
            "data_type": "unknown",
            "layout": "unknown",
            "pipeline": "unknown",
            "scheduler": "unknown",
            "epilogue": "unknown",
        }

        parts = name.split("_")
        if len(parts) >= 4:
            # Tokens 0-2 are "benchmark", "grouped", "gemm"; the named fields follow in
            # this order. Missing trailing tokens keep their "unknown" default.
            ordered_fields = ("data_type", "layout", "pipeline", "epilogue", "scheduler")
            for position, field in enumerate(ordered_fields, start=3):
                if position < len(parts):
                    info[field] = parts[position]

        # Tile/warp dimensions and boolean flags encoded later in the name.
        info.update(self.parse_detailed_config(name))
        info["config_id"] = self.generate_config_id(info)

        return info

    def parse_detailed_config(self, kernel_name: str) -> Dict:
        """Parse tile sizes, warp layout, warp tile and boolean flags from a kernel name.

        Values that cannot be found stay at 0 / False.
        """
        config = {
            "tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0},
            "warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0},
            "warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0},
            "optimization_flags": {
                "pad_m": False,
                "pad_n": False,
                "pad_k": False,
                "persistent": False,
            },
        }

        parts = kernel_name.split("_")

        # Collect the first run of consecutive True/False tokens.
        bool_sequence = []
        for part in parts:
            if part in ("True", "False"):
                bool_sequence.append(part == "True")
            elif bool_sequence:
                break

        # Order: pad_m, pad_n, pad_k, persistent (4 flags total)
        if len(bool_sequence) >= 4:
            flags = config["optimization_flags"]
            for flag_name, value in zip(
                ("pad_m", "pad_n", "pad_k", "persistent"), bool_sequence
            ):
                flags[flag_name] = value

        # Collect every AxBxC token made of three positive integers
        # (e.g. 256x256x32_2x2x1_4x64x16).
        dimension_groups = []
        for part in parts:
            pieces = part.split("x")
            if len(pieces) != 3:
                continue
            try:
                dims = [int(piece) for piece in pieces]
            except ValueError:
                continue
            if all(d > 0 for d in dims):
                dimension_groups.append(dims)

        # Groups are told apart by magnitude: largest = tile sizes, smallest = warp
        # config, middle = warp tile. The sort is stable, so ties keep name order.
        ranked = sorted(dimension_groups, key=max, reverse=True)
        tile = warp = warp_tile = None
        if len(ranked) >= 3:
            tile, warp_tile, warp = ranked[0], ranked[1], ranked[2]
        elif len(ranked) == 2:
            tile, warp = ranked
        elif len(ranked) == 1:
            tile = ranked[0]

        assignments = (
            ("tile_sizes", ("tile_m", "tile_n", "tile_k"), tile),
            ("warp_config", ("warp_m", "warp_n", "warp_k"), warp),
            ("warp_tile", ("warp_tile_m", "warp_tile_n", "warp_tile_k"), warp_tile),
        )
        for section, keys, dims in assignments:
            if dims is not None:
                for key, value in zip(keys, dims):
                    config[section][key] = value

        return config

    def generate_config_id(self, info: Dict) -> str:
        """Generate a compact config ID from kernel info.

        The ID joins data type, layout, pipeline and scheduler with underscores and
        appends the tile, warp ("w...") and warp-tile ("wt...") dimensions when their
        M component is non-zero.
        """
        parts = [
            info.get("data_type", "unk"),
            info.get("layout", "unk"),
            info.get("pipeline", "unk"),
            info.get("scheduler", "unk"),
        ]

        tile_sizes = info.get("tile_sizes", {})
        if tile_sizes.get("tile_m", 0) > 0:
            parts.append(
                f"{tile_sizes['tile_m']}x{tile_sizes['tile_n']}x{tile_sizes['tile_k']}"
            )

        warp_config = info.get("warp_config", {})
        if warp_config.get("warp_m", 0) > 0:
            parts.append(
                f"w{warp_config['warp_m']}x{warp_config['warp_n']}x{warp_config['warp_k']}"
            )

        warp_tile = info.get("warp_tile", {})
        if warp_tile.get("warp_tile_m", 0) > 0:
            parts.append(
                f"wt{warp_tile['warp_tile_m']}x{warp_tile['warp_tile_n']}x{warp_tile['warp_tile_k']}"
            )

        return "_".join(parts)

    def run_kernel(self, kernel_path: Path, params: Dict[str, object]) -> Optional[Dict]:
        """Run a single kernel with given parameters and save output to individual JSON file.

        Returns the parsed result dict, or None when the process fails, times out
        (120 s), prints nothing, or prints output that cannot be parsed.
        """
        results_dir = self.build_dir / "results"
        results_dir.mkdir(exist_ok=True)

        # One file per executable; a later run of the same executable overwrites it.
        json_file = results_dir / f"{kernel_path.stem}.json"

        cmd = [str(kernel_path)]
        cmd.extend(f"-{key}={value}" for key, value in params.items())
        # Ask the executable for JSON output.
        cmd.append("-json_output=true")

        if self.verbose:
            print(f"Running: {' '.join(cmd)}")

        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)

            if result.returncode != 0:
                print(f"Error running {kernel_path.name}: {result.stderr}")
                return None

            output = result.stdout.strip()
            if not output:
                print(f"No output from {kernel_path.name}")
                return None

            # Keep the raw output on disk, then parse it from there.
            with open(json_file, "w") as f:
                f.write(output)
            return self.parse_json_file(json_file)

        except subprocess.TimeoutExpired:
            print(f"Timeout running {kernel_path.name}")
            return None
        except Exception as e:
            print(f"Error running {kernel_path.name}: {e}")
            return None

    @staticmethod
    def _decode_json_payload(content: str):
        """Decode ``content`` as JSON, tolerating extra text around the JSON object.

        Strict parsing is tried first. If it fails, decoding is retried from every line
        that starts with "{", so a result object preceded or followed by other printed
        lines is still found. Raises the original JSONDecodeError when nothing decodes.
        """
        try:
            return json.loads(content)
        except json.JSONDecodeError as strict_error:
            decoder = json.JSONDecoder()
            offset = 0
            for line in content.splitlines(keepends=True):
                # Only line-initial braces are tried so a nested object is never
                # mistaken for the top-level result.
                if line.startswith("{"):
                    try:
                        payload, _ = decoder.raw_decode(content, offset)
                        return payload
                    except json.JSONDecodeError:
                        pass
                offset += len(line)
            raise strict_error

    def parse_json_file(self, json_file: Path) -> Optional[Dict]:
        """Parse JSON data from individual kernel output file.

        Returns the decoded data plus the convenience keys ``time_ms``, ``tflops`` and
        ``bandwidth_gb_s`` (taken from ``perf_result`` when present), or None when the
        file cannot be read or decoded.
        """
        try:
            with open(json_file, "r") as f:
                content = f.read().strip()

            data = self._decode_json_payload(content)

            # Return the complete JSON data as-is, just add some convenience fields
            result = data.copy()
            if "perf_result" in data:
                perf = data["perf_result"]
                result["time_ms"] = perf.get("latency(ms)", 0)
                result["tflops"] = perf.get("tflops(TFlops)", 0)
                result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0)

            return result

        except json.JSONDecodeError as e:
            if self.verbose:
                print(f"Failed to parse JSON from {json_file}: {e}")
            return None
        except Exception as e:
            if self.verbose:
                print(f"Error reading JSON file {json_file}: {e}")
            return None

    def benchmark_problem_size(
        self,
        kernels: List[Path],
        group_count: int = 8,
        m: int = 3840,
        n: int = 4096,
        k: int = 2048,
        kbatch: int = 1,
        verify: int = 0,
        warmup: int = 50,
        repeat: int = 100,
        flush_cache: bool = True,
        rotating_count: int = 1000,
    ) -> List[Dict]:
        """Benchmark all kernels for a specific problem size.

        Returns one structured result dict per kernel that produced parseable output;
        kernels that fail are left out.
        """
        results = []

        params = {
            "group_count": group_count,
            "m": m,
            "n": n,
            "k": k,
            "kbatch": kbatch,
            "verify": verify,
            "warmup": warmup,
            "repeat": repeat,
            "flush_cache": str(flush_cache).lower(),
            "rotating_count": rotating_count,
        }

        print(
            f"\nBenchmarking group_count={group_count}, M={m}, N={n}, K={k}, kbatch={kbatch}"
        )

        for kernel_path in kernels:
            kernel_info = self.extract_kernel_info(kernel_path)
            result = self.run_kernel(kernel_path, params)
            if not result:
                continue

            structured_result = {
                "name": kernel_info["name"],
                "config_id": kernel_info["config_id"],
                "problem": result.get("problem", {}),
                "perf_result": result.get("perf_result", {}),
                "config": {
                    "data_type": kernel_info["data_type"],
                    "layout": kernel_info["layout"],
                    "pipeline": kernel_info["pipeline"],
                    "scheduler": kernel_info["scheduler"],
                    "epilogue": kernel_info["epilogue"],
                    "tile_sizes": kernel_info.get("tile_sizes", {}),
                    "warp_config": kernel_info.get("warp_config", {}),
                    "warp_tile": kernel_info.get("warp_tile", {}),
                    "optimization_flags": kernel_info.get("optimization_flags", {}),
                },
                "executable": kernel_info["executable"],
                # Keep backward compatibility fields
                "time_ms": result.get("time_ms", 0),
                "tflops": result.get("tflops", 0),
                "bandwidth_gb_s": result.get("bandwidth_gb_s", 0),
            }
            results.append(structured_result)

            if self.verbose:
                print(
                    f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms"
                )

        return results

    def find_best_kernel(
        self, results: List[Dict], metric: str = "tflops"
    ) -> Optional[Dict]:
        """Find the best performing kernel based on metric.

        ``metric`` is "tflops" or "bandwidth_gb_s" (higher is better) or "time_ms"
        (lower is better). Returns None for an empty list and raises ValueError for
        any other metric name.
        """
        if not results:
            return None

        if metric == "tflops":
            return max(results, key=lambda x: x.get("tflops", 0))
        elif metric == "time_ms":
            return min(results, key=lambda x: x.get("time_ms", float("inf")))
        elif metric == "bandwidth_gb_s":
            return max(results, key=lambda x: x.get("bandwidth_gb_s", 0))
        else:
            raise ValueError(f"Unknown metric: {metric}")

    def benchmark_sweep(
        self,
        problem_sizes: List[Tuple[int, int, int]],
        group_counts: Optional[List[int]] = None,
        kbatch_values: Optional[List[int]] = None,
        verify: bool = False,
        warmup: int = 50,
        repeat: int = 100,
        flush_cache: bool = True,
        rotating_count: int = 1000,
    ) -> Dict:
        """Run comprehensive benchmark sweep.

        Every discovered kernel is run for each (M, N, K) x group count x kbatch
        combination. ``group_counts`` defaults to [8] and ``kbatch_values`` to [1].
        All results are stored in ``self.results``; the return value maps a
        configuration key to the kernel with the highest TFLOPS for it. Returns {}
        when no kernels are found.
        """
        # None instead of list literals as defaults, so no list object is shared
        # between calls.
        if group_counts is None:
            group_counts = [8]
        if kbatch_values is None:
            kbatch_values = [1]

        kernels = self.discover_kernels()
        if not kernels:
            print("No kernels found!")
            return {}

        all_results = []
        best_kernels = {}

        for m, n, k in problem_sizes:
            for group_count in group_counts:
                for kbatch in kbatch_values:
                    results = self.benchmark_problem_size(
                        kernels,
                        group_count=group_count,
                        m=m,
                        n=n,
                        k=k,
                        kbatch=kbatch,
                        verify=1 if verify else 0,
                        warmup=warmup,
                        repeat=repeat,
                        flush_cache=flush_cache,
                        rotating_count=rotating_count,
                    )

                    all_results.extend(results)

                    best = self.find_best_kernel(results)
                    if best:
                        key = f"g{group_count}_m{m}_n{n}_k{k}_kbatch{kbatch}"
                        best_kernels[key] = best
                        print(
                            f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, {best['time_ms']:.2f}ms)"
                        )

        self.results = all_results
        return best_kernels

    def export_csv(self, filename: str):
        """Export all results to CSV.

        The header is the sorted union of all result keys. Nested dict values are
        written as their string representation.
        """
        if not self.results:
            print("No results to export")
            return

        all_keys = set()
        for result in self.results:
            all_keys.update(result.keys())
        fieldnames = sorted(all_keys)

        with open(filename, "w", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(self.results)

        print(f"Results exported to {filename}")

    def export_best_kernels(self, best_kernels: Dict, filename: str):
        """Export best kernel selections to a plain-text file, one line per configuration."""
        with open(filename, "w") as f:
            f.write("# Best kernel selections for grouped GEMM\n")
            f.write(
                "# Format: problem_size -> kernel_name (TFLOPS, bandwidth, latency)\n\n"
            )

            for key, kernel in sorted(best_kernels.items()):
                f.write(
                    f"{key}: {kernel['name']} ({kernel['tflops']:.2f} TFLOPS, {kernel['bandwidth_gb_s']:.2f} GB/s, {kernel['time_ms']:.2f}ms)\n"
                )

        print(f"Best kernels exported to {filename}")

    def export_json(self, filename: str, best_kernels: Optional[Dict] = None):
        """Export all results and best kernels to JSON with comprehensive metadata.

        A result counts as successful when its ``tflops`` is greater than 0. Summary
        statistics (best, average, min/max, median) cover successful results only;
        medians are true medians (mean of the two middle values for even counts).
        """
        from datetime import datetime
        from statistics import median

        successful_results = [r for r in self.results if r.get("tflops", 0) > 0]

        tflops_values = [r.get("tflops", 0) for r in successful_results]
        bandwidth_values = [r.get("bandwidth_gb_s", 0) for r in successful_results]
        latency_values = [
            r.get("time_ms", 0) for r in successful_results if r.get("time_ms", 0) > 0
        ]

        def average(values):
            return sum(values) / len(values) if values else 0

        def middle(values):
            return median(values) if values else 0

        def breakdown(field_name):
            """Count / average / best TFLOPS per distinct value of one config field."""
            grouped = {}
            for r in successful_results:
                category = r.get("config", {}).get(field_name, "unknown")
                grouped.setdefault(category, []).append(r.get("tflops", 0))
            return {
                category: {
                    "count": len(values),
                    "avg_tflops": sum(values) / len(values),
                    "best_tflops": max(values),
                }
                for category, values in grouped.items()
            }

        output_data = {
            "benchmark_metadata": {
                "timestamp": datetime.now().isoformat(),
                "total_kernels_tested": len(self.results),
                "unique_kernels": len(
                    set(r.get("name", "unknown") for r in self.results)
                ),
                "successful_runs": len(successful_results),
                "failed_runs": len(self.results) - len(successful_results),
            },
            "performance_summary": {
                "tflops_stats": {
                    "best": max(tflops_values, default=0),
                    "average": average(tflops_values),
                    "min": min(tflops_values, default=0),
                    "median": middle(tflops_values),
                },
                "bandwidth_stats": {
                    "best_gb_s": max(bandwidth_values, default=0),
                    "average_gb_s": average(bandwidth_values),
                    "min_gb_s": min(bandwidth_values, default=0),
                    "median_gb_s": middle(bandwidth_values),
                },
                "latency_stats": {
                    "best_ms": min(latency_values, default=0),
                    "average_ms": average(latency_values),
                    "max_ms": max(latency_values, default=0),
                    "median_ms": middle(latency_values),
                },
                "kernel_type_breakdown": {
                    "by_pipeline": breakdown("pipeline"),
                    "by_scheduler": breakdown("scheduler"),
                    "by_data_type": breakdown("data_type"),
                },
                "total_problem_configurations": len(best_kernels)
                if best_kernels
                else 0,
            },
            "kernel_results": self.results,
            "best_kernels_by_problem": best_kernels or {},
        }

        with open(filename, "w") as f:
            json.dump(output_data, f, indent=2)

        print(f"JSON results exported to {filename}")
        print(f" - Total kernels: {len(self.results)}")
        print(f" - Successful runs: {len(successful_results)}")
        print(f" - Best TFLOPS: {max(tflops_values, default=0):.2f}")
        print(f" - Best bandwidth: {max(bandwidth_values, default=0):.2f} GB/s")
        print(f" - Best latency: {min(latency_values, default=0):.2f}ms")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: run a benchmark sweep and export the results.

    Returns 0 on completion and 1 when a --problem-sizes entry is not a valid
    "M,N,K" triple of integers.
    """
    parser = argparse.ArgumentParser(
        description="Grouped GEMM Kernel Benchmarking Tool"
    )
    parser.add_argument(
        "build_dir", help="Build directory containing kernel executables"
    )
    parser.add_argument(
        "--problem-sizes",
        nargs="+",
        default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"],
        help="Default problem sizes as M,N,K tuples (used for all groups)",
    )
    parser.add_argument(
        "--group-counts",
        nargs="+",
        type=int,
        default=[8],
        help="Group count values to test (default: 8)",
    )
    parser.add_argument(
        "--kbatch",
        nargs="+",
        type=int,
        default=[1],
        help="K-batch (SplitK) values to test",
    )
    parser.add_argument("--verify", action="store_true", help="Enable verification")
    parser.add_argument(
        "--csv",
        default="grouped_gemm_benchmark_results.csv",
        help="CSV output filename",
    )
    parser.add_argument(
        "--best",
        default="best_grouped_gemm_kernels.txt",
        help="Best kernels output filename",
    )
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument(
        "--warmup",
        type=int,
        default=50,
        help="Number of warmup iterations (default: 50)",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=100,
        help="Number of benchmark iterations (default: 100)",
    )
    parser.add_argument(
        "--flush-cache",
        dest="flush_cache",
        action="store_true",
        default=True,
        help="Enable cache flushing (default: True)",
    )
    # A store_true flag that already defaults to True can never be switched off, so a
    # companion flag writing False to the same destination is provided.
    parser.add_argument(
        "--no-flush-cache",
        dest="flush_cache",
        action="store_false",
        help="Disable cache flushing",
    )
    parser.add_argument(
        "--rotating-count",
        type=int,
        default=1000,
        help="Number of iterations to rotate cache (default: 1000)",
    )
    parser.add_argument("--json", help="JSON output filename (optional)")

    args = parser.parse_args()

    # Each entry must be exactly three comma-separated integers; a wrong count or a
    # non-integer both raise ValueError.
    problem_sizes = []
    for size_str in args.problem_sizes:
        try:
            m, n, k = map(int, size_str.split(","))
            problem_sizes.append((m, n, k))
        except ValueError:
            print(f"Invalid problem size: {size_str}")
            return 1

    benchmark = GroupedGemmBenchmark(args.build_dir, verbose=args.verbose)

    print("Starting Grouped GEMM kernel benchmark sweep...")
    start_time = time.time()

    best_kernels = benchmark.benchmark_sweep(
        problem_sizes=problem_sizes,
        group_counts=args.group_counts,
        kbatch_values=args.kbatch,
        verify=args.verify,
        warmup=args.warmup,
        repeat=args.repeat,
        flush_cache=args.flush_cache,
        rotating_count=args.rotating_count,
    )

    elapsed_time = time.time() - start_time
    print(f"\nBenchmark completed in {elapsed_time:.2f} seconds")

    benchmark.export_csv(args.csv)
    benchmark.export_best_kernels(best_kernels, args.best)

    # JSON export is optional.
    if args.json:
        benchmark.export_json(args.json, best_kernels)

    return 0
|
||||
|
||||
|
||||
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
|
||||
@@ -0,0 +1,195 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <exception>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm_profiler.hpp"
|
||||
#include "grouped_gemm_common.hpp"
|
||||
|
||||
// The kernel header is included via the compile command line with -include flag
|
||||
// It defines SelectedKernel struct and KERNEL_NAME
|
||||
// DataTypeTraits are now defined in grouped_gemm_common.hpp
|
||||
|
||||
// Create argument parser
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser
|
||||
.insert("m", "3840", "Default M for all groups (if Ms not specified). Default is 3840.")
|
||||
.insert("n", "4096", "Default N for all groups (if Ns not specified). Default is 4096.")
|
||||
.insert("k", "2048", "Default K for all groups (if Ks not specified). Default is 2048.")
|
||||
.insert("Ms", "", "Comma-separated M dimensions per group.")
|
||||
.insert("Ns", "", "Comma-separated N dimensions per group.")
|
||||
.insert("Ks", "", "Comma-separated K dimensions per group.")
|
||||
.insert("stride_As", "", "Comma-separated stride values for tensor A per group.")
|
||||
.insert("stride_Bs", "", "Comma-separated stride values for tensor B per group.")
|
||||
.insert("stride_Cs", "", "Comma-separated stride values for tensor C per group.")
|
||||
.insert("group_count", "8", "Number of groups. Default is 8.")
|
||||
.insert("kbatch", "1", "SplitK batch count. Default is 1.")
|
||||
.insert("verify",
|
||||
"2",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, "
|
||||
"2 for validation on GPU. Default is 2, GPU validation.")
|
||||
.insert("log",
|
||||
"false",
|
||||
"Whether output kernel instance information or not. Possible values are true or "
|
||||
"false. Default is false")
|
||||
.insert(
|
||||
"warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.")
|
||||
.insert(
|
||||
"repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
|
||||
.insert("timer",
|
||||
"true",
|
||||
"Whether if the timer is gpu timer or not. Possible values are false or true. "
|
||||
"Default is true.")
|
||||
.insert("init",
|
||||
"0",
|
||||
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
|
||||
"for constant(1). Default is 0, random.")
|
||||
.insert("flush_cache",
|
||||
"true",
|
||||
"To flush cache, possible values are true or false. "
|
||||
"Default is true.")
|
||||
.insert(
|
||||
"rotating_count", "1000", "number of iterations to rotate the cache. default is 1000.")
|
||||
.insert("metric",
|
||||
"0",
|
||||
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
|
||||
"tflops, or 2 for bandwidth. Default is 0, latency.")
|
||||
.insert("csv_filename",
|
||||
"",
|
||||
"The filename of benchmark result. Default is empty (no CSV output).")
|
||||
.insert("json_output",
|
||||
"false",
|
||||
"Whether to output results in JSON format only. Possible values are true or false. "
|
||||
"Default is "
|
||||
"false");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
/// Benchmarks the single kernel selected at compile time (SelectedKernel) on the
/// grouped GEMM problem described by the parsed command line.
///
/// Per-group sizes come from Ms/Ns/Ks when all three lists have exactly
/// group_count entries; otherwise every group uses the -m/-n/-k values.
/// Stride lists that do not have group_count entries are filled with 0.
///
/// @param arg_parser Parser already populated by create_args().
/// @throws std::exception Any failure raised while benchmarking or selecting the
///         best instance is logged to stderr and then re-thrown, so the caller
///         can turn it into a failing exit status.
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
    // Use DataTypeTraits to get the actual type names from the generated header
    std::string dtype_a   = DataTypeTraits<ADataType>::name;
    std::string dtype_b   = DataTypeTraits<BDataType>::name;
    std::string dtype_acc = DataTypeTraits<AccDataType>::name;
    std::string dtype_c   = DataTypeTraits<CDataType>::name;

    // Layout names from the layout types
    std::string layout_a = ALayout::name;
    std::string layout_b = BLayout::name;
    std::string layout_c = CLayout::name;

    const int group_count = arg_parser.get_int("group_count");
    const int kbatch      = arg_parser.get_int("kbatch");

    // Parse per-group dimensions
    std::vector<int> Ms        = arg_parser.get_int_vec("Ms");
    std::vector<int> Ns        = arg_parser.get_int_vec("Ns");
    std::vector<int> Ks        = arg_parser.get_int_vec("Ks");
    std::vector<int> stride_As = arg_parser.get_int_vec("stride_As");
    std::vector<int> stride_Bs = arg_parser.get_int_vec("stride_Bs");
    std::vector<int> stride_Cs = arg_parser.get_int_vec("stride_Cs");

    // If Ms/Ns/Ks not provided or wrong size, use -m/-n/-k defaults for all groups
    const auto gc_size = static_cast<std::size_t>(group_count);

    if(group_count == 0 || Ms.size() != gc_size || Ns.size() != gc_size || Ks.size() != gc_size)
    {
        const int default_m = arg_parser.get_int("m");
        const int default_n = arg_parser.get_int("n");
        const int default_k = arg_parser.get_int("k");

        Ms.assign(group_count, default_m);
        Ns.assign(group_count, default_n);
        Ks.assign(group_count, default_k);
    }

    // Default stride vectors to 0 independently if missing or wrong size
    if(stride_As.size() != gc_size)
        stride_As.assign(group_count, 0);
    if(stride_Bs.size() != gc_size)
        stride_Bs.assign(group_count, 0);
    if(stride_Cs.size() != gc_size)
        stride_Cs.assign(group_count, 0);

    // Create GroupedGemmProblem struct
    GroupedGemmProblem problem{group_count,
                               kbatch,
                               Ms,
                               Ns,
                               Ks,
                               stride_As,
                               stride_Bs,
                               stride_Cs,
                               dtype_a,
                               dtype_b,
                               dtype_acc,
                               dtype_c,
                               layout_a,
                               layout_b,
                               layout_c};

    // Create Setting struct
    Setting setting{arg_parser.get_int("warmup"),
                    arg_parser.get_int("repeat"),
                    arg_parser.get_bool("timer"),
                    arg_parser.get_int("verify"),
                    arg_parser.get_int("init"),
                    arg_parser.get_bool("log"),
                    arg_parser.get_str("csv_filename"),
                    arg_parser.get_bool("flush_cache"),
                    arg_parser.get_int("rotating_count"),
                    arg_parser.get_bool("json_output")};

    // Get the profiler instance
    auto& profiler = GroupedGemmProfiler::instance(setting);

    try
    {
        // Create a lambda that wraps the kernel launch
        auto kernel_func = [](const std::vector<ck_tile::GroupedGemmHostArgs<>>& descs,
                              const ck_tile::stream_config& stream,
                              void* kargs_ptr) {
            return SelectedKernel::launch(descs, stream, kargs_ptr);
        };

        // Benchmark the kernel
        profiler.benchmark(problem, kernel_func);

        // Select best instance based on metric
        profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
    }
    catch(const std::exception& e)
    {
        std::cerr << "Benchmark failed: " << e.what() << std::endl;
        // Propagate the failure: swallowing it here made the process exit with
        // status 0 even though no valid result was produced.
        throw;
    }
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
benchmark_single(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
100
tile_engine/ops/gemm/grouped_gemm/grouped_gemm_common.hpp
Normal file
100
tile_engine/ops/gemm/grouped_gemm/grouped_gemm_common.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
//[TODO] This can be moved to commons
|
||||
// DataTypeTraits for all supported types
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
// Helper function to determine if a layout is row-major
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
/// Kernel traits consumed by the dispatcher. A default-constructed object
/// describes a compv3 / intrawave / cshuffle kernel with no padding and no
/// persistent scheduling.
struct KernelTraits
{
    std::string pipeline{"compv3"};     // compv3, compv4, mem
    std::string scheduler{"intrawave"}; // intrawave, interwave
    std::string epilogue{"cshuffle"};   // cshuffle, default
    bool pad_m{false};
    bool pad_n{false};
    bool pad_k{false};
    bool persistent{false};

    // Defaults live in the member initializers above. The constructor stays
    // user-provided so the type remains a non-aggregate, exactly as before.
    KernelTraits() {}
};
|
||||
@@ -0,0 +1,303 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import importlib.util
|
||||
import multiprocessing
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
def _import_gemm_kernel_builder():
|
||||
"""Import validation utilities from commons directory."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"gemm_instance_builder",
|
||||
os.path.join(parent_dir, "gemm_instance_builder.py"),
|
||||
)
|
||||
gemm_builder_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(gemm_builder_module)
|
||||
|
||||
return gemm_builder_module.GemmKernelBuilder
|
||||
|
||||
|
||||
# Resolve the shared base class once, at import time, so it can be used as a base class below.
GemmKernelBuilder = _import_gemm_kernel_builder()
|
||||
|
||||
|
||||
class GroupedGemmKernelBuilder(GemmKernelBuilder):
    """Kernel instance builder for grouped GEMM, built on ``GemmKernelBuilder``."""

    def __init__(
        self,
        kernel_name_prefix,
        working_path,
        gpu_target,
        datatype,
        layout,
        config_json=None,
    ):
        """Forward every argument unchanged to ``GemmKernelBuilder.__init__``."""
        super().__init__(
            kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
        )

    def _generate_all_individual(self, num_workers=None):
        """Generate one header per (tile config, trait combination) pair in parallel.

        Args:
            num_workers: Size of the process pool. When ``None`` the CPU count
                is used, capped at 8.

        Work items that fail, or for which the worker returns a falsy result,
        are left out of the list passed to
        ``_generate_cmake_individual_targets``.
        """
        if num_workers is None:
            # Limit to avoid memory issues
            num_workers = min(multiprocessing.cpu_count(), 8)

        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()

        # Each work item carries everything a worker process needs to rebuild
        # an equivalent builder on its side.
        work_items = [
            (
                tile_cfg,
                traits,
                self.kernel_name_prefix,
                self.working_path,
                self.gpu_target,
                self.datatype,
                self.layout,
                self.config_json,
            )
            for tile_cfg in tile_configs
            for traits in trait_combos
        ]
        total = len(work_items)

        print(
            f"Generating {total} individual kernel files using {num_workers} workers..."
        )
        print(f" Tile configs: {len(tile_configs)}")
        print(f" Trait combinations: {len(trait_combos)}")
        print(f" Total kernels: {total}")

        # Show the first work item for debugging
        if work_items:
            first_tile, first_traits = work_items[0][:2]
            print(" First work item example:")
            print(f" Tile config: {first_tile}")
            print(f" Trait combo: {first_traits[:3]}")  # Show first 3 traits

        generated = []

        with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as pool:
            print(f" Submitting {total} tasks to executor...")
            pending = {
                pool.submit(_generate_single_kernel_individual, entry): entry
                for entry in work_items
            }
            print(" All tasks submitted, waiting for completion...")

            # Progress is reported every 100 finished tasks and on the last one.
            for done_count, finished in enumerate(
                concurrent.futures.as_completed(pending), start=1
            ):
                if done_count % 100 == 0 or done_count == total:
                    print(f" Progress: {done_count}/{total} kernels generated")
                try:
                    outcome = finished.result()
                    if outcome:
                        generated.append(outcome)
                except Exception as exc:
                    item = pending[finished]
                    print(f"Kernel generation failed for {item}: {exc}")

        # Sort by kernel name for consistent ordering
        generated.sort(key=lambda entry: entry[0])

        # Generate CMake include file for individual targets
        self._generate_cmake_individual_targets(generated)

        print(
            f"Generated {len(generated)} individual kernel files in {self.working_path}"
        )
|
||||
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
|
||||
"""Worker function to generate a single individual kernel file"""
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json,
|
||||
) = work_item
|
||||
|
||||
# Create a temporary builder instance for this worker
|
||||
builder = GroupedGemmKernelBuilder(
|
||||
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
|
||||
)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo
|
||||
)
|
||||
|
||||
# Create simplified filename without the "grouped_gemm_" prefix
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith("grouped_gemm_"):
|
||||
simplified_name = simplified_name[
|
||||
len(kernel_name_prefix) + 1 :
|
||||
] # Remove "grouped_gemm" prefix
|
||||
|
||||
# Write individual header file
|
||||
header_file = working_path / f"grouped_gemm_single_{simplified_name}.hpp"
|
||||
with open(header_file, "w") as f:
|
||||
f.write(instance_code)
|
||||
|
||||
return (kernel_name, trait_combo, tile_config)
|
||||
except Exception as e:
|
||||
print(f"Error generating individual kernel: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _parse_tile_config(text):
    """Parse a ``MxNxK_MxNxK_MxNxK`` string into the tile configuration dict.

    The three ``_``-separated groups are the block tile, the warp layout and
    the warp tile. Raises ``IndexError`` or ``ValueError`` on malformed input.
    """
    groups = text.split("_")
    tile_dims = groups[0].split("x")
    warp_dims = groups[1].split("x")
    warp_tile_dims = groups[2].split("x")

    return {
        "tile_m": int(tile_dims[0]),
        "tile_n": int(tile_dims[1]),
        "tile_k": int(tile_dims[2]),
        "warp_m": int(warp_dims[0]),
        "warp_n": int(warp_dims[1]),
        "warp_k": int(warp_dims[2]),
        "warp_tile_m": int(warp_tile_dims[0]),
        "warp_tile_n": int(warp_tile_dims[1]),
        "warp_tile_k": int(warp_tile_dims[2]),
    }


def _parse_trait_combo(text):
    """Parse ``pipeline_epilogue_scheduler_padM_padN_padK_persistent`` into a tuple.

    The four flags are ``True`` only when spelled exactly ``"True"``.
    Raises ``IndexError`` when fewer than seven fields are present.
    """
    fields = text.split("_")
    return (
        fields[0],  # pipeline
        fields[1],  # epilogue
        fields[2],  # scheduler
        fields[3] == "True",  # pad_m
        fields[4] == "True",  # pad_n
        fields[5] == "True",  # pad_k
        fields[6] == "True",  # persistent
    )


def main():
    """Command-line entry point of the grouped GEMM kernel instance builder.

    Exactly one mode is executed, checked in this order: ``--list_kernels``,
    ``--gen_single``, ``--gen_all_individual``. Invalid or incomplete
    arguments are reported through ``parser.error`` (which exits the process).
    """
    parser = argparse.ArgumentParser(
        description="Grouped GEMM kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--gpu_target",
        required=True,
        help="GPU target architecture",
    )
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16", "fp8", "bf16", "bf8"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcr", "rrr", "ccr", "crr"],
        help="Matrix layout",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_all_individual",
        action="store_true",
        help="Generate individual kernel files",
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )

    args = parser.parse_args()

    # The checks below repeat what the argparse `choices` already enforce; they
    # are kept as a second line of defence in case the choices are extended.
    if args.datatype not in ["fp16", "bf16", "fp8", "bf8"]:
        parser.error(
            f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
        )

    layout_str = args.layout.lower()
    if len(layout_str) != 3:
        parser.error(
            f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
        )

    matrix_a_layout = layout_str[0]
    matrix_b_layout = layout_str[1]
    matrix_c_layout = layout_str[2]

    if matrix_a_layout not in ["r", "c"] or matrix_b_layout not in ["r", "c"]:
        parser.error(
            f"Invalid matrix_a layout : {matrix_a_layout} or matrix_b layout: {matrix_b_layout} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
        )

    if matrix_c_layout != "r":
        parser.error(
            f"Invalid matrix_c layout: {matrix_c_layout} (must be 'r' only as currently we are supporting only row major)"
        )

    kernel_name_prefix = "grouped_gemm"
    builder = GroupedGemmKernelBuilder(
        kernel_name_prefix,
        args.working_path,
        args.gpu_target,
        args.datatype,
        args.layout,
        args.config_json,
    )

    if args.list_kernels:
        builder._list_kernels()
    elif args.gen_single:
        # Generate a single kernel file input validation
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )

        # Malformed strings used to surface as a raw IndexError/ValueError
        # traceback; report them as ordinary usage errors instead.
        try:
            tile_config = _parse_tile_config(args.tile_config)
        except (IndexError, ValueError):
            parser.error(
                f"Invalid --tile_config '{args.tile_config}' (expected three '_'-separated groups of integer MxNxK dimensions)"
            )

        try:
            trait_combo = _parse_trait_combo(args.trait_combo)
        except IndexError:
            parser.error(
                f"Invalid --trait_combo '{args.trait_combo}' (expected pipeline_epilogue_scheduler_padM_padN_padK_persistent)"
            )

        # Generate the kernel
        # NOTE(review): the value returned here is discarded and --kernel_name is
        # only checked for presence. _generate_single_kernel_individual writes
        # the returned code to a header itself; confirm whether
        # _generate_kernel_instance persists the file on its own in this mode.
        builder._generate_kernel_instance(
            tile_config,
            trait_combo,
        )
    elif args.gen_all_individual:
        # Generate all individual kernel files
        builder._generate_all_individual(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
        )
|
||||
|
||||
|
||||
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
326
tile_engine/ops/gemm/grouped_gemm/grouped_gemm_profiler.hpp
Normal file
326
tile_engine/ops/gemm/grouped_gemm/grouped_gemm_profiler.hpp
Normal file
@@ -0,0 +1,326 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "grouped_gemm_benchmark.hpp"
|
||||
|
||||
class GroupedGemmProfiler
|
||||
{
|
||||
public:
|
||||
static GroupedGemmProfiler& instance(Setting setting)
|
||||
{
|
||||
static GroupedGemmProfiler instance{setting};
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Overload for single kernel benchmarking
|
||||
void benchmark(GroupedGemmProblem& problem,
|
||||
std::function<float(const std::vector<ck_tile::GroupedGemmHostArgs<>>&,
|
||||
const ck_tile::stream_config&,
|
||||
void*)> kernel_func)
|
||||
{
|
||||
// Create a vector with a single callable that returns both name and time
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
std::vector<ck_tile::GroupedGemmHostArgs<>>&, const ck_tile::stream_config&, void*)>>
|
||||
callables;
|
||||
|
||||
callables.push_back([kernel_func](std::vector<ck_tile::GroupedGemmHostArgs<>>& descs,
|
||||
const ck_tile::stream_config& stream,
|
||||
void* kargs_ptr) {
|
||||
float time = kernel_func(descs, stream, kargs_ptr);
|
||||
return std::make_tuple(std::string(KERNEL_NAME), time);
|
||||
});
|
||||
|
||||
benchmark(problem, callables);
|
||||
}
|
||||
|
||||
void benchmark(
|
||||
GroupedGemmProblem& problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
std::vector<ck_tile::GroupedGemmHostArgs<>>&, const ck_tile::stream_config&, void*)>>&
|
||||
callables)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
const CLayout layout_c = CLayout{};
|
||||
|
||||
const int group_count = problem.group_count_;
|
||||
|
||||
// Compute default strides for each group
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
problem.stride_As_[i] = ck_tile::get_default_stride(
|
||||
problem.Ms_[i], problem.Ks_[i], problem.stride_As_[i], is_row_major(layout_a));
|
||||
problem.stride_Bs_[i] = ck_tile::get_default_stride(
|
||||
problem.Ks_[i], problem.Ns_[i], problem.stride_Bs_[i], is_row_major(layout_b));
|
||||
problem.stride_Cs_[i] = ck_tile::get_default_stride(
|
||||
problem.Ms_[i], problem.Ns_[i], problem.stride_Cs_[i], is_row_major(layout_c));
|
||||
}
|
||||
|
||||
// Create per-group tensors
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_dev_results;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
c_dev_results.reserve(group_count);
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_dev_bufs;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_dev_bufs;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_dev_bufs;
|
||||
|
||||
a_dev_bufs.reserve(group_count);
|
||||
b_dev_bufs.reserve(group_count);
|
||||
c_dev_bufs.reserve(group_count);
|
||||
|
||||
std::vector<ck_tile::GroupedGemmHostArgs<>> gemm_descs;
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
const ck_tile::index_t M = problem.Ms_[i];
|
||||
const ck_tile::index_t N = problem.Ns_[i];
|
||||
const ck_tile::index_t K = problem.Ks_[i];
|
||||
|
||||
a_tensors.push_back(ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
|
||||
M, K, problem.stride_As_[i], is_row_major(layout_a))));
|
||||
b_tensors.push_back(ck_tile::HostTensor<BDataType>(ck_tile::host_tensor_descriptor(
|
||||
K, N, problem.stride_Bs_[i], is_row_major(layout_b))));
|
||||
c_dev_results.push_back(ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
|
||||
M, N, problem.stride_Cs_[i], is_row_major(layout_c))));
|
||||
|
||||
if(setting_.init_method_ == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_tensors[i]);
|
||||
}
|
||||
else if(setting_.init_method_ == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_tensors[i]);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_tensors[i]);
|
||||
}
|
||||
else if(setting_.init_method_ == 2)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_tensors[i]);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_tensors[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_tensors[i].SetZero();
|
||||
b_tensors[i].SetZero();
|
||||
}
|
||||
|
||||
a_dev_bufs.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_dev_bufs.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_tensors[i].get_element_space_size_in_bytes()));
|
||||
c_dev_bufs.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_dev_results[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_dev_bufs[i]->ToDevice(a_tensors[i].data());
|
||||
b_dev_bufs[i]->ToDevice(b_tensors[i].data());
|
||||
c_dev_bufs[i]->SetZero();
|
||||
c_dev_results[i].SetZero();
|
||||
|
||||
const void* p_a = a_dev_bufs[i]->GetDeviceBuffer();
|
||||
const void* p_b = b_dev_bufs[i]->GetDeviceBuffer();
|
||||
void* p_c = c_dev_bufs[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back({p_a,
|
||||
p_b,
|
||||
{/*ds_ptr*/},
|
||||
p_c,
|
||||
problem.kbatch_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
problem.stride_As_[i],
|
||||
problem.stride_Bs_[i],
|
||||
{/*stride_Ds*/},
|
||||
problem.stride_Cs_[i]});
|
||||
}
|
||||
|
||||
// Allocate workspace for kernel args
|
||||
ck_tile::DeviceMem workspace(gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>));
|
||||
|
||||
// Compute host reference for verification
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_host_results;
|
||||
if(setting_.verify_)
|
||||
{
|
||||
c_host_results.reserve(group_count);
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
c_host_results.push_back(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(problem.Ms_[i],
|
||||
problem.Ns_[i],
|
||||
problem.stride_Cs_[i],
|
||||
is_row_major(layout_c))));
|
||||
}
|
||||
gemm_host_reference_grouped(setting_.verify_,
|
||||
problem,
|
||||
a_tensors,
|
||||
b_tensors,
|
||||
c_host_results,
|
||||
a_dev_bufs,
|
||||
b_dev_bufs);
|
||||
}
|
||||
|
||||
for(auto& callable : callables)
|
||||
{
|
||||
auto kernel_run_result = callable(gemm_descs,
|
||||
ck_tile::stream_config{nullptr,
|
||||
true,
|
||||
setting_.log_,
|
||||
setting_.n_warmup_,
|
||||
setting_.n_repeat_,
|
||||
setting_.is_gpu_timer_,
|
||||
setting_.flush_cache_,
|
||||
setting_.rotating_count_},
|
||||
workspace.GetDeviceBuffer());
|
||||
process_result(problem, c_dev_bufs, c_host_results, c_dev_results, kernel_run_result);
|
||||
}
|
||||
}
|
||||
|
||||
void process_result(const GroupedGemmProblem& problem,
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>>& c_dev_bufs,
|
||||
std::vector<ck_tile::HostTensor<CDataType>>& c_host_results,
|
||||
std::vector<ck_tile::HostTensor<CDataType>>& c_dev_results,
|
||||
const std::tuple<std::string, float>& kernel_run_result)
|
||||
{
|
||||
auto [name, avg_time] = kernel_run_result;
|
||||
|
||||
KernelInstance kernel_instance{name, problem, {-1.0f, -1.0f, -1.0f}};
|
||||
|
||||
// Compute performance metrics (sum across all groups)
|
||||
std::size_t flop = 0;
|
||||
std::size_t num_byte = 0;
|
||||
for(int i = 0; i < problem.group_count_; ++i)
|
||||
{
|
||||
flop += std::size_t(2) * problem.Ms_[i] * problem.Ns_[i] * problem.Ks_[i];
|
||||
num_byte += sizeof(ADataType) * problem.Ms_[i] * problem.Ks_[i] +
|
||||
sizeof(BDataType) * problem.Ks_[i] * problem.Ns_[i] +
|
||||
sizeof(CDataType) * problem.Ms_[i] * problem.Ns_[i];
|
||||
}
|
||||
|
||||
// update
|
||||
kernel_instance.perf_result_.latency_ = avg_time;
|
||||
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
|
||||
|
||||
if(setting_.log_ > 0 && !setting_.json_output_)
|
||||
{
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
|
||||
// Copy results back from device and verify per-group
|
||||
for(int i = 0; i < problem.group_count_; ++i)
|
||||
{
|
||||
c_dev_bufs[i]->FromDevice(c_dev_results[i].data());
|
||||
}
|
||||
|
||||
bool verified_correct =
|
||||
!setting_.verify_ || compare_grouped(name, problem, c_dev_results, c_host_results);
|
||||
|
||||
if(verified_correct)
|
||||
{
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
||||
}
|
||||
|
||||
// Clear device tensors
|
||||
for(int i = 0; i < problem.group_count_; ++i)
|
||||
{
|
||||
c_dev_bufs[i]->SetZero();
|
||||
c_dev_results[i].SetZero();
|
||||
}
|
||||
}
|
||||
|
||||
KernelInstance select_best_instance(Metric metric)
|
||||
{
|
||||
if(kernel_instances_.empty())
|
||||
throw std::runtime_error("Empty instances");
|
||||
|
||||
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
|
||||
kernel_instances_.end(),
|
||||
[metric](const auto& a, const auto& b) {
|
||||
return PerformanceResult::compare(
|
||||
b.perf_result_, a.perf_result_, metric);
|
||||
});
|
||||
|
||||
if(setting_.json_output_)
|
||||
{
|
||||
// Output clean JSON only
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "Current kernel performance is: " << kernel_instance << std::endl;
|
||||
std::cout << "**********************************" << std::endl;
|
||||
}
|
||||
|
||||
if(!setting_.csv_filename_.empty())
|
||||
{
|
||||
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
|
||||
|
||||
if(!file.is_open())
|
||||
{
|
||||
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(file.tellp() == 0)
|
||||
{
|
||||
file << "rocm_version,device_name," << "group_count,kbatch,"
|
||||
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
|
||||
<< "name," << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
|
||||
}
|
||||
|
||||
const auto& problem = kernel_instance.problem_;
|
||||
const auto& name = kernel_instance.name_;
|
||||
const auto& perf = kernel_instance.perf_result_;
|
||||
|
||||
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
|
||||
<< problem.group_count_ << "," << problem.kbatch_ << "," << problem.dtype_a_
|
||||
<< "," << problem.dtype_b_ << "," << problem.dtype_acc_ << ","
|
||||
<< problem.dtype_c_ << "," << problem.layout_a_ << "," << problem.layout_b_
|
||||
<< "," << problem.layout_c_ << "," << name << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
|
||||
<< "\n";
|
||||
|
||||
if(!file)
|
||||
{
|
||||
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return kernel_instance;
|
||||
}
|
||||
|
||||
GroupedGemmProfiler(const GroupedGemmProfiler&) = delete;
|
||||
GroupedGemmProfiler& operator=(const GroupedGemmProfiler&) = delete;
|
||||
|
||||
private:
|
||||
~GroupedGemmProfiler() { kernel_instances_.clear(); }
|
||||
GroupedGemmProfiler(Setting setting) : setting_(setting) {}
|
||||
|
||||
Setting setting_;
|
||||
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
Reference in New Issue
Block a user