Tile engine for streamk (#3157)

* [CK TILE STREAMK] Introduce initial support for tile engine in streamk GEMM.

- This commit lays the groundwork for integrating the tile engine into streamk GEMM.
  It focuses on creating benchmark executables for streamk GEMM.
- Additional scripts like test_benchmark.sh and gemm_benchmark.py will be added once
  the streamk implementation reaches stability.

* [CK TILE STREAMK] Enable CI to execute tile engine benchmarks for StreamK GEMM

* [CK TILE STREAMK] Refactor: Extract common utility functions.

* [CK TILE STREAMK] Revise tile engine of streamk to align with the updated implementation

* Add pre-commit

* [CK TILE STREAMK] Add 'dp_persistent' and 'reduction_strategy' in output of CK TILE STREAMK

* [CK TILE STREAMK] Fix a bug about value of 'dp_persistent' of CK TILE STREAMK

* [CK TILE STREAMK] Update Jenkinsfile

* [CK TILE Engine] Update StreamK tile engine help message

Remove default value messages as they are automatically printed

* [CK TILE Engine] Update StreamK tile engine

- Remove namespace reboot

* [CK TILE Engine] Update StreamK tile engine

- Fix merge error
This commit is contained in:
Cong Ma
2025-11-27 15:49:57 -07:00
committed by GitHub
parent 24d88d2472
commit 30727c48fc
15 changed files with 2530 additions and 19 deletions

View File

@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
#pragma once
/// @brief Derive the relative and absolute tolerances used to validate a GEMM result.
///
/// Two candidate tolerance pairs are computed with the ck_tile threshold helpers:
/// one for the accumulation performed inside a single K-slice (ceil(K / kbatch)
/// accumulations) and one for combining the kbatch partial results. The larger
/// value of each kind is returned.
///
/// @tparam ADataType   Element type of A.
/// @tparam BDataType   Element type of B.
/// @tparam AccDataType Accumulator type.
/// @tparam CDataType   Element type of C.
/// @param K                      GEMM K extent.
/// @param kbatch                 Number of K-slices.
/// @param max_accumulated_value  Largest value of the reference output.
/// @return ck_tile tuple holding (rtol, atol).
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // The narrower of the two input types drives the error model (BDataType on a tie).
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;

    // Number of accumulations performed within one K-slice.
    const auto k_per_batch = ck_tile::integer_divide_ceil(K, kbatch);

    // Tolerances for the accumulation inside one K-slice.
    const auto slice_rtol =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(k_per_batch);
    const auto slice_atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_batch);

    // Tolerances for summing the kbatch partial results, carried out in CDataType.
    const auto merge_rtol =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto merge_atol = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);

    // Report the looser bound of each kind.
    const auto final_rtol = std::max(slice_rtol, merge_rtol);
    const auto final_atol = std::max(slice_atol, merge_atol);
    return ck_tile::make_tuple(final_rtol, final_atol);
}
/// @brief Function to compare the results of the device and host computations
///
/// Tolerances come from calculate_rtol_atol, seeded with the largest element of
/// the host reference. The thresholds and the verdict are printed to std::cout.
///
/// Declared `inline` because this is a function definition living in a header:
/// without it, including the header from more than one translation unit produces
/// multiple-definition link errors.
///
/// @param instanceName       Kernel instance name, only used in the printed message.
/// @param K                  GEMM K extent.
/// @param kbatch             Number of K-slices.
/// @param c_m_n_dev_result   Output produced by the device kernel.
/// @param c_m_n_host_result  Reference output.
/// @return The value returned by ck_tile::check_err for the two tensors.
inline bool compare(std::string instanceName,
                    ck_tile::index_t K,
                    ck_tile::index_t kbatch,
                    ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                    ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
    // NOTE(review): assumes the reference tensor is non-empty; max_element on an
    // empty range returns end() and must not be dereferenced.
    const float max_accumulated_value =
        *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
    const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
        K, kbatch, max_accumulated_value);
    bool pass = ck_tile::check_err(c_m_n_dev_result,
                                   c_m_n_host_result,
                                   "Error: Incorrect results!",
                                   rtol_atol.at(ck_tile::number<0>{}),
                                   rtol_atol.at(ck_tile::number<1>{}));
    std::cout << "For " << instanceName << " Relative error threshold is "
              << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
              << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
    std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
    return pass;
}

View File

@@ -1,3 +1,4 @@
# Tile engine operations; each source directory may be added only once.
add_subdirectory(gemm)
add_subdirectory(gemm_multi_d)
add_subdirectory(gemm_preshuffle)
add_subdirectory(gemm_streamk)

View File

@@ -0,0 +1,295 @@
# User-tunable cache entries for the StreamK GEMM tile engine build.
# Datatypes and layouts are semicolon-separated lists; one group of benchmark
# targets is generated per (datatype, layout) pair.
set(GEMM_STREAMK_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
set(GEMM_STREAMK_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
# Optional config file name; when empty, configs/default_config.json is used.
set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
# ccache is opt-in.
option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF)
# Store the directory path for use in functions
set(GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
# Function to create individual GEMM targets
#
# Creates one benchmark executable for a single generated StreamK GEMM kernel
# and attaches it to the collection targets.
#
# Arguments:
#   datatype    - datatype tag; used in target/file names and passed as --datatype
#   layout      - layout tag; used in target/file names and passed as --layout
#   trait       - underscore-separated trait string; its first three fields are
#                 read as pipeline, epilogue and scheduler
#   tile_config - three underscore-separated "AxBxC" groups (tile, warp, warp tile)
#   config_json - config file forwarded to the generator script as --config_json
#
# Relies on GEMM_GPU_TARGETS_INDIVIDUAL, GEMM_SOURCE_DIR, Python3_EXECUTABLE and
# on the benchmark_gemm_streamk_* collection targets already existing.
function(create_individual_gemm_target datatype layout trait tile_config config_json)
# Use the parent scope GEMM_GPU_TARGETS_INDIVIDUAL variable
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
return()
endif()
# Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
# First split by underscore to get three groups
# NOTE(review): the values parsed below (tile_m ... warp_tile_k) are not referenced
# again in this function; list(GET) still fails on a malformed tile_config.
string(REPLACE "_" ";" config_groups ${tile_config})
list(GET config_groups 0 tile_dims) # e.g., 256x256x32
list(GET config_groups 1 warp_dims) # e.g., 4x1x1
list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16
# Parse tile dimensions
string(REPLACE "x" ";" tile_parts ${tile_dims})
list(GET tile_parts 0 tile_m)
list(GET tile_parts 1 tile_n)
list(GET tile_parts 2 tile_k)
# Parse warp dimensions
string(REPLACE "x" ";" warp_parts ${warp_dims})
list(GET warp_parts 0 warp_m)
list(GET warp_parts 1 warp_n)
list(GET warp_parts 2 warp_k)
# Parse warp tile dimensions
string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims})
list(GET warp_tile_parts 0 warp_tile_m)
list(GET warp_tile_parts 1 warp_tile_n)
list(GET warp_tile_parts 2 warp_tile_k)
# Target name and per-(datatype, layout) output directory
set(target_name "benchmark_gemm_streamk_${datatype}_${layout}_${trait}_${tile_config}")
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Generate the single instance header for this kernel
set(instance_header "${working_path}/gemm_streamk_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
# Add custom command to generate the header file at build time
# (re-runs when the generator script or the config file changes)
add_custom_command(
OUTPUT ${instance_header}
COMMAND ${Python3_EXECUTABLE} ${GEMM_SOURCE_DIR}/gemm_streamk_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${config_json}
--gen_single
--kernel_name "${datatype}_${layout}_${trait}_${tile_config}"
--tile_config "${tile_config}"
--trait_combo "${trait}"
DEPENDS ${GEMM_SOURCE_DIR}/gemm_streamk_instance_builder.py ${config_json}
COMMENT "Generating ${instance_header}"
)
# Create the executable
# Listing the generated header as a source ties the custom command to this target.
add_executable(${target_name}
${GEMM_SOURCE_DIR}/gemm_streamk_benchmark_single.cpp
${instance_header}
)
# Set GPU architectures
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL})
# Set compile definitions
target_compile_definitions(${target_name} PRIVATE
GEMM_SINGLE_INSTANCE_HPP="${instance_header}"
)
# Include directories
target_include_directories(${target_name} PRIVATE
${GEMM_SOURCE_DIR}
${working_path}
)
# Compile options
# The generated header is force-included (-include) ahead of the benchmark source.
target_compile_options(${target_name} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress
-include ${instance_header}
)
# Add to collection targets
add_dependencies(benchmark_gemm_streamk_all ${target_name})
add_dependencies(benchmark_gemm_streamk_${datatype} ${target_name})
add_dependencies(benchmark_gemm_streamk_${layout} ${target_name})
add_dependencies(benchmark_gemm_streamk_${datatype}_${layout} ${target_name})
# Add to trait-specific targets
# Only the first three trait fields are used; a value with no matching
# collection target makes add_dependencies fail.
string(REPLACE "_" ";" trait_parts ${trait})
list(GET trait_parts 0 pipeline)
list(GET trait_parts 1 epilogue)
list(GET trait_parts 2 scheduler)
add_dependencies(benchmark_gemm_streamk_${pipeline} ${target_name})
add_dependencies(benchmark_gemm_streamk_${epilogue} ${target_name})
add_dependencies(benchmark_gemm_streamk_${scheduler} ${target_name})
endfunction()
# Function to build individual GEMM targets
#
# For one (datatype, layout) pair: picks the config file, runs the generator
# script with --list_kernels at configure time, then creates one target per
# line of the resulting kernel list.
#
# Arguments:
#   datatype - datatype tag passed to the generator and used in target names
#   layout   - layout tag passed to the generator and used in target names
function(build_individual_gemm_targets datatype layout)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Choose config file
# Priority order:
# 1. Environment variable GEMM_CONFIG_FILE
# 2. CMake variable GEMM_CONFIG_FILE
# 3. Default based on layout
# (the default is currently the same file, default_config.json, for every layout)
# Check environment variable first
if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "")
set(config_filename "$ENV{GEMM_CONFIG_FILE}")
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
message(STATUS " Using config from environment variable: ${config_filename}")
elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "")
# Use CMake variable if set
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}")
message(STATUS " Using custom config: ${GEMM_CONFIG_FILE}")
else()
# Use default config for all layouts
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
message(STATUS " Using default config for layout ${layout}")
endif()
# Check if config file exists
if(NOT EXISTS ${json_blob})
message(FATAL_ERROR "Config file not found: ${json_blob}")
endif()
# Determine number of workers for parallel generation
# NOTE(review): num_workers is only printed below; it is not passed to the
# generator script anywhere in this function.
if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
else()
# Use processor count but limit to avoid memory issues
cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
math(EXPR num_workers "${num_cores}")
if(num_workers GREATER 8)
set(num_workers 8)
endif()
endif()
# Generate individual kernel files using parallel version
message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
message(STATUS " Working path: ${working_path}")
message(STATUS " Config file: ${json_blob}")
message(STATUS " Python executable: ${Python3_EXECUTABLE}")
message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_streamk_instance_builder.py")
# Create working directory first
file(MAKE_DIRECTORY ${working_path})
# First, just list the kernels (fast operation)
# This runs at configure time; the script is expected to leave
# gemm_kernel_count.txt and gemm_kernel_list.txt in working_path.
message(STATUS " Listing kernel configurations...")
execute_process(
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_streamk_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--config_json ${json_blob}
--list_kernels
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
OUTPUT_VARIABLE list_output
ERROR_VARIABLE list_error
)
if(NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
endif()
# Read kernel count
if(EXISTS ${working_path}/gemm_kernel_count.txt)
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
string(STRIP "${kernel_count}" kernel_count)
message(STATUS " Found ${kernel_count} kernel configurations")
else()
message(FATAL_ERROR "Kernel count file not found")
endif()
# Read kernel list and create targets
if(EXISTS ${working_path}/gemm_kernel_list.txt)
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
foreach(line IN LISTS kernel_lines)
# Parse line: kernel_name|tile_config|trait_combo
# A line with fewer than three '|' fields makes list(GET) fail.
string(REPLACE "|" ";" parts "${line}")
list(GET parts 0 kernel_name)
list(GET parts 1 tile_config)
list(GET parts 2 trait_combo)
# Create individual target
create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
endforeach()
else()
message(FATAL_ERROR "Kernel list file not found")
endif()
endfunction()
# Main build logic - only individual (one executable per kernel) builds are supported
message(STATUS "=== Starting Tile Engine StreamK GEMM Configuration ===")
message(STATUS "GEMM_STREAMK_DATATYPE: ${GEMM_STREAMK_DATATYPE}")
message(STATUS "GEMM_STREAMK_LAYOUT: ${GEMM_STREAMK_LAYOUT}")
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")

# Keep only the GPU targets listed in DESIRED_TARGETS
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942") # TODO: Add gfx950 when supported
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
    if(target IN_LIST DESIRED_TARGETS)
        list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target})
        message(STATUS " Adding GPU target: ${target}")
    endif()
endforeach()

# Skip build if no matching targets found
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
    # Build the list shown to the user from DESIRED_TARGETS so the warning can
    # never advertise a target (previously gfx950) that is not actually accepted.
    string(REPLACE ";" ", " desired_targets_text "${DESIRED_TARGETS}")
    message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (${desired_targets_text}) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
    message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")

    # Enable parallel compilation optimizations
    # Set up job pools for better parallel compilation control
    set_property(GLOBAL PROPERTY JOB_POOLS
        compile_heavy=4 # Limit heavy compilations to prevent OOM
        compile_normal=16 # Allow more parallel normal compilations
    )

    # Enable compiler cache if available and explicitly requested
    # Disabled by default due to permission issues in CI environments
    if(ENABLE_CCACHE_GEMM)
        find_program(CCACHE_PROGRAM ccache)
        if(CCACHE_PROGRAM)
            set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
            message(STATUS "Using ccache for faster compilation")
        else()
            message(WARNING "ccache requested but not found")
        endif()
    else()
        message(STATUS "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)")
    endif()

    # Create master collection targets
    add_custom_target(benchmark_gemm_streamk_all)

    # Create datatype collection targets
    foreach(dt IN LISTS GEMM_STREAMK_DATATYPE)
        add_custom_target(benchmark_gemm_streamk_${dt})
    endforeach()

    # Create layout collection targets
    foreach(l IN LISTS GEMM_STREAMK_LAYOUT)
        add_custom_target(benchmark_gemm_streamk_${l})
    endforeach()

    # Create combined collection targets
    foreach(dt IN LISTS GEMM_STREAMK_DATATYPE)
        foreach(l IN LISTS GEMM_STREAMK_LAYOUT)
            add_custom_target(benchmark_gemm_streamk_${dt}_${l})
        endforeach()
    endforeach()

    # Create trait-based collection targets
    # These are common trait components used across all GEMM kernels
    set(GEMM_PIPELINES "mem;compv3;compv4")
    set(GEMM_EPILOGUES "default;cshuffle")
    set(GEMM_SCHEDULERS "intrawave;interwave")
    foreach(pipeline IN LISTS GEMM_PIPELINES)
        add_custom_target(benchmark_gemm_streamk_${pipeline})
    endforeach()
    foreach(epilogue IN LISTS GEMM_EPILOGUES)
        add_custom_target(benchmark_gemm_streamk_${epilogue})
    endforeach()
    foreach(scheduler IN LISTS GEMM_SCHEDULERS)
        add_custom_target(benchmark_gemm_streamk_${scheduler})
    endforeach()

    # Build individual targets for each datatype/layout combination
    # (must come after the collection targets above, which the per-kernel
    # targets are attached to)
    foreach(dt IN LISTS GEMM_STREAMK_DATATYPE)
        foreach(l IN LISTS GEMM_STREAMK_LAYOUT)
            build_individual_gemm_targets(${dt} ${l})
        endforeach()
    endforeach()
endif()

View File

@@ -0,0 +1,105 @@
{
"problem": {
},
"tile_config": {
"tile_m": {
"max": 256,
"min": 64,
"step": 64
},
"tile_n": {
"max": 256,
"min": 64,
"step": 64
},
"tile_k": {
"max": 256,
"min": 64,
"step": 64
},
"warp_m": {
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
4,
16,
32
]
},
"warp_tile_n": {
"values": [
16,
32,
64
]
},
"warp_tile_k": {
"values": [
8,
16,
32,
64,
128
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
},
"persistent": {
"values": [
false, true
]
},
"reduction_strategy": {
"values": [
"reduction", "atomic"
]
}
}
}

View File

@@ -0,0 +1,201 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include <fstream>
#include <stdexcept>
#include <iomanip>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "gemm_streamk_common.hpp"
#include "utility/validation.hpp"
// Data types and Layouts are defined by the generated kernel headers
// No hardcoded type definitions here to avoid conflicts
/// Performance figure used to rank benchmark results.
enum class Metric
{
    LATENCY   = 0,
    TFLOPS    = 1,
    BANDWIDTH = 2
};

/// @brief Map a Metric to its lower-case display name.
/// @param m Metric to describe.
/// @return "latency", "tflops" or "bandwidth".
/// @throws std::invalid_argument if @p m is not one of the enumerators.
inline constexpr auto get_metric_name(Metric m)
{
    if(m == Metric::LATENCY)
    {
        return "latency";
    }
    if(m == Metric::TFLOPS)
    {
        return "tflops";
    }
    if(m == Metric::BANDWIDTH)
    {
        return "bandwidth";
    }
    // Reached only for values outside the enumerator set (e.g. a cast integer).
    throw std::invalid_argument("Unsupported metric type");
}
/// Description of one GEMM problem: sizes, strides and the names of the data
/// types and layouts involved. Field order matters because the struct is
/// brace-initialized positionally.
struct GemmProblem
{
    int split_k_;
    int m_, n_, k_;
    int stride_a_, stride_b_, stride_c_;
    std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_;
    std::string layout_a_, layout_b_, layout_c_;
    bool structured_sparsity_;

    /// Write the problem as a JSON-style object, one member per line.
    /// String values are emitted verbatim (no escaping).
    friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
    {
        // Emits ` "key":<number>,` followed by a newline.
        const auto put_number = [&os](const char* key, int value) {
            os << " \"" << key << "\":" << value << ",\n";
        };
        // Emits ` "key":"<text>",` followed by a newline.
        const auto put_text = [&os](const char* key, const std::string& value) {
            os << " \"" << key << "\":\"" << value << "\",\n";
        };

        os << "{\n";
        put_number("split_k", problem.split_k_);
        put_number("m", problem.m_);
        put_number("n", problem.n_);
        put_number("k", problem.k_);
        put_number("stride_a", problem.stride_a_);
        put_number("stride_b", problem.stride_b_);
        put_number("stride_c", problem.stride_c_);
        put_text("dtype_a", problem.dtype_a_);
        put_text("dtype_b", problem.dtype_b_);
        put_text("dtype_acc", problem.dtype_acc_);
        put_text("dtype_c", problem.dtype_c_);
        put_text("layout_a", problem.layout_a_);
        put_text("layout_b", problem.layout_b_);
        put_text("layout_c", problem.layout_c_);

        // Last member: no trailing comma.
        const char* sparsity_text = problem.structured_sparsity_ ? "true" : "false";
        os << " \"structured_sparsity\":" << sparsity_text << "\n";
        os << "}";
        return os;
    }
};
/// Measured performance of one kernel run. Field order matters for positional
/// brace-initialization.
struct PerformanceResult
{
    double latency_;
    double tflops_;
    double bandwidth_;

    /// @brief Tell whether @p a is better than @p b for the chosen metric.
    /// Lower latency wins; higher tflops / bandwidth wins.
    /// @throws std::invalid_argument for a Metric value outside the enumerators.
    static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
    {
        switch(m)
        {
        case Metric::LATENCY: return a.latency_ < b.latency_;
        case Metric::TFLOPS: return a.tflops_ > b.tflops_;
        case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
        default: throw std::invalid_argument("Unsupported metric type");
        }
    }

    /// Write the result as a JSON-style object with two fixed decimals.
    ///
    /// std::fixed and std::setprecision are sticky, so the stream's previous
    /// flags and precision are saved and restored; otherwise every later
    /// floating-point value written to the same stream (for example std::cout)
    /// would silently be printed with two decimals as well.
    friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
    {
        const std::ios_base::fmtflags saved_flags = os.flags();
        const std::streamsize saved_precision     = os.precision();

        os << "{\n"
           << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
           << ",\n"
           << " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
           << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
           << "}";

        os.flags(saved_flags);
        os.precision(saved_precision);
        return os;
    }
};
struct KernelInstance
{
std::string name_;
std::string dp_persistent_;
std::string reduction_strategy_;
GemmProblem problem_;
PerformanceResult perf_result_;
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
{
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
}
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
{
os << "{\n"
<< " \"name\": \"" << obj.name_ << "\",\n"
<< " \"dp_persistent\": \"" << obj.dp_persistent_ << "\",\n"
<< " \"reduction_strategy\": \"" << obj.reduction_strategy_ << "\",\n"
<< " \"problem\": " << obj.problem_ << ",\n"
<< " \"perf_result\": " << obj.perf_result_ << "\n"
<< "}";
return os;
}
};
// Benchmark run settings. The benchmark driver fills this struct positionally
// from the command-line arguments, so the field order must not change.
struct Setting
{
int n_warmup_;             // "warmup": iterations run before benchmarking
int n_repeat_;             // "repeat": iterations that are benchmarked
bool is_gpu_timer_;        // "timer": whether the GPU timer is used
int verify_;               // "verify": 0 = none, 1 = CPU validation, 2 = GPU validation
int init_method_;          // "init": 0 = random, 1 = linear, 2 = constant(1)
bool log_;                 // "log": whether kernel instance information is output
std::string csv_filename_; // "csv_filename": benchmark result file; empty = no CSV output
bool flush_cache_;         // "flush_cache": whether to flush the cache
int rotating_count_;       // "rotating_count": iterations to rotate the cache
bool json_output_;         // "json_output": output results in JSON format only
};
/// @brief Read the ROCm version string from /opt/rocm/.info/version.
/// @return The first line of that file, or "Unknown" when it cannot be opened.
inline std::string get_rocm_version()
{
    std::ifstream version_file("/opt/rocm/.info/version");
    if(!version_file.is_open())
    {
        return "Unknown";
    }

    // Only the first line is used; it stays empty if the file has no content.
    std::string first_line;
    std::getline(version_file, first_line);
    return first_line;
}
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
///
/// verify == 1 runs ck_tile::reference_gemm on the host tensors.
/// verify == 2 runs ck_tile::reference_gemm_gpu on the device buffers and copies
/// the result back into @p c_m_n_host_result.
/// Any other value leaves @p c_m_n_host_result untouched.
///
/// Declared `inline` (like get_rocm_version in this header) because it is a
/// function definition in a header: without it, including the header from more
/// than one translation unit produces multiple-definition link errors.
///
/// @param verify             Validation mode (1 = CPU, 2 = GPU).
/// @param a_m_k              Host copy of A.
/// @param b_k_n              Host copy of B.
/// @param c_m_n_host_result  Receives the reference result.
/// @param a_m_k_dev_buf      Device buffer holding A (used for verify == 2).
/// @param b_k_n_dev_buf      Device buffer holding B (used for verify == 2).
/// @param M,N,K              GEMM extents.
/// @param stride_A,stride_B,stride_C  Strides forwarded to the GPU reference.
inline void gemm_host_reference(int verify,
                                ck_tile::HostTensor<ADataType>& a_m_k,
                                ck_tile::HostTensor<BDataType>& b_k_n,
                                ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                                ck_tile::DeviceMem& a_m_k_dev_buf,
                                ck_tile::DeviceMem& b_k_n_dev_buf,
                                ck_tile::index_t M,
                                ck_tile::index_t N,
                                ck_tile::index_t K,
                                ck_tile::index_t stride_A,
                                ck_tile::index_t stride_B,
                                ck_tile::index_t stride_C)
{
    if(verify == 1)
    {
        c_m_n_host_result.SetZero();
        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_m_k, b_k_n, c_m_n_host_result);
    }
    else if(verify == 2)
    {
        if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
        {
            // Restore input for B for gpu reference
            b_k_n_dev_buf.ToDevice(b_k_n.data());
        }

        // Scratch device buffer for the reference output, sized like the host tensor.
        ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes());
        c_m_n_host_result.SetZero();
        c_m_n_gpu_buf_ref.SetZero();

        ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
        BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
        CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());

        ck_tile::reference_gemm_gpu<ADataType,
                                    BDataType,
                                    AccDataType,
                                    CDataType,
                                    ALayout,
                                    BLayout,
                                    CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);

        // Copy the GPU reference result back to the host tensor.
        c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data());
    }
}

View File

@@ -0,0 +1,169 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <functional>
#include <tuple>
#include <exception>
#include <sstream>
#include <vector>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "gemm_streamk_profiler.hpp"
#include "gemm_streamk_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are now defined in gemm_streamk_common.hpp
// Create argument parser
//
// Registers every command-line option of the single-kernel benchmark with its
// default value and help text, then parses argv.
// Returns a tuple (parse succeeded, parser).
// Help texts do not repeat default values: the rotating_count text used to
// claim "default is 5" although the registered default is 1000.
inline auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3840", "The value for m dimension.")
        .insert("n", "4096", "The value for n dimension.")
        .insert("k", "2048", "The value for k dimension.")
        .insert("stride_a", "0", "The stride value for tensor A.")
        .insert("stride_b", "0", "The stride value for tensor B.")
        .insert("stride_c", "0", "The stride value for tensor C.")
        .insert("split_k", "1", "The split value for k dimension.")
        .insert("verify",
                "0",
                "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
                "for validation on GPU.")
        .insert("log",
                "false",
                "Whether output kernel instance information or not. Possible values are true or "
                "false.")
        .insert("warmup", "50", "The number of iterations before benchmark the kernel.")
        .insert("repeat", "100", "The number of iterations to benchmark the kernel.")
        .insert("timer",
                "true",
                "Whether if the timer is gpu timer or not. Possible values are false or true. "
                "")
        .insert("init",
                "0",
                "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
                "for constant(1).")
        .insert("flush_cache", "true", "To flush cache, possible values are true or false.")
        .insert("rotating_count", "1000", "number of iterations to rotate the cache.")
        .insert("metric",
                "0",
                "Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
                "tflops, or 2 for bandwidth.")
        .insert("csv_filename",
                "",
                "The filename of benchmark result. Default is empty (no CSV output).")
        .insert("structured_sparsity",
                "false",
                "Whether use sparsity kernel or not. Possible values are true or false.")
        .insert(
            "json_output",
            "false",
            "Whether to output results in JSON format only. Possible values are true or false.");
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
/// @brief Force a single, un-warmed run when verification is requested.
///
/// Per the original note on this function, repeated runs accumulate into the
/// same output buffer, so a verified run must execute the kernel exactly once.
/// @param setting Benchmark settings; n_repeat_ and n_warmup_ are overwritten
///                when verify_ is non-zero.
void repeat_once_if_verify(Setting& setting)
{
    if(!setting.verify_)
    {
        return; // no verification: keep the requested warmup/repeat counts
    }

    setting.n_warmup_ = 0;
    setting.n_repeat_ = 1;
}
/// @brief Benchmark the single kernel compiled into this executable.
///
/// Builds the problem description and run settings from the parsed command
/// line, then hands a wrapper around SelectedKernel::launch to the profiler.
/// Exceptions thrown while benchmarking are reported on std::cerr and not
/// propagated.
/// @param arg_parser Parsed command-line arguments.
void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser)
{
    // Sizes and strides come from the command line; the type and layout names
    // come from the generated header (ADataType, ALayout, ... via DataTypeTraits
    // and the layout types' `name` members).
    GemmProblem gemm_problem{arg_parser.get_int("split_k"),
                             arg_parser.get_int("m"),
                             arg_parser.get_int("n"),
                             arg_parser.get_int("k"),
                             arg_parser.get_int("stride_a"),
                             arg_parser.get_int("stride_b"),
                             arg_parser.get_int("stride_c"),
                             DataTypeTraits<ADataType>::name,
                             DataTypeTraits<BDataType>::name,
                             DataTypeTraits<AccDataType>::name,
                             DataTypeTraits<CDataType>::name,
                             ALayout::name,
                             BLayout::name,
                             CLayout::name,
                             arg_parser.get_bool("structured_sparsity")};

    // Run settings, in the field order of Setting.
    Setting setting{arg_parser.get_int("warmup"),
                    arg_parser.get_int("repeat"),
                    arg_parser.get_bool("timer"),
                    arg_parser.get_int("verify"),
                    arg_parser.get_int("init"),
                    arg_parser.get_bool("log"),
                    arg_parser.get_str("csv_filename"),
                    arg_parser.get_bool("flush_cache"),
                    arg_parser.get_int("rotating_count"),
                    arg_parser.get_bool("json_output")};
    repeat_once_if_verify(setting);

    auto& profiler = GemmProfiler::instance(setting);

    try
    {
        // Adapter from the profiler's callback shape to the generated kernel.
        auto launch_selected = [](const ck_tile::StreamKHostArgs& host_args,
                                  const ck_tile::stream_config& config) {
            return SelectedKernel::launch(host_args, config);
        };

        profiler.benchmark(gemm_problem, launch_selected);

        // The metric is read only after the benchmark has run.
        const auto ranking_metric = static_cast<Metric>(arg_parser.get_int("metric"));
        profiler.select_best_instance(ranking_metric);
    }
    catch(const std::exception& e)
    {
        std::cerr << "Benchmark failed: " << e.what() << std::endl;
    }
}
/// Entry point: parse the command line and run the single-kernel benchmark.
/// Returns EXIT_FAILURE when parsing fails (after printing the parser's usage)
/// or when an exception reaches this function, EXIT_SUCCESS otherwise.
int main(int argc, char* argv[])
{
    int exit_code = EXIT_FAILURE;
    try
    {
        auto [parsed_ok, cli] = create_args(argc, argv);
        if(parsed_ok)
        {
            benchmark_gemm_single(cli);
            exit_code = EXIT_SUCCESS;
        }
        else
        {
            cli.print();
        }
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << "\n";
    }
    return exit_code;
}

View File

@@ -0,0 +1,145 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
// DataTypeTraits for all supported types
//
// Maps an element type to a short name string through a static `name` member.
// The primary template is declared but not defined, so using an unsupported
// type is a compile-time error. The benchmark driver reads `name` to fill the
// dtype fields of the problem description.
template <typename T>
struct DataTypeTraits;
// 32-bit float
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
// 64-bit float
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
// ck_tile half type
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
// ck_tile bf16 type
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
// ck_tile fp8 type
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
// ck_tile bf8 type
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
// ck_tile 8-bit integer
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
// ck_tile 32-bit integer
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
// ck_tile packed int4 type; note the name keeps the "_t" suffix, unlike the others
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
// Helper function to determine if a layout is row-major
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Structure to hold kernel traits for dispatcher
struct KernelTraits
{
    std::string pipeline;  // compv3, compv4, mem
    std::string scheduler; // intrawave, interwave
    std::string epilogue;  // cshuffle, default
    bool pad_m;
    bool pad_n;
    bool pad_k;
    bool persistent;

    // Constructor with defaults
    KernelTraits()
        : pipeline("compv3"),
          scheduler("intrawave"),
          epilogue("cshuffle"),
          pad_m(false),
          pad_n(false),
          pad_k(false),
          persistent(false)
    {
    }
};

// Helper to extract traits from kernel name
//
// Pipeline and scheduler are detected by substring search. The epilogue is
// "default" when the name contains "default" as a whole '_'-delimited field,
// otherwise "cshuffle".
//
// The previous check refused "default" whenever "default_" occurred anywhere in
// the name. Kernel names join their fields with '_', so a "default" epilogue in
// the middle of a name was always followed by '_' and was never detected.
//
// Padding and persistent flags are not derived from the name and keep their
// defaults (false).
inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
{
    const auto contains = [&kernel_name](const char* needle) {
        return kernel_name.find(needle) != std::string::npos;
    };

    // True when `token` occurs bounded on both sides by '_' or the string ends.
    const auto has_token = [&kernel_name](const std::string& token) {
        for(std::size_t pos = kernel_name.find(token); pos != std::string::npos;
            pos             = kernel_name.find(token, pos + 1))
        {
            const std::size_t end = pos + token.size();
            const bool starts_field = (pos == 0) || (kernel_name[pos - 1] == '_');
            const bool ends_field = (end == kernel_name.size()) || (kernel_name[end] == '_');
            if(starts_field && ends_field)
            {
                return true;
            }
        }
        return false;
    };

    KernelTraits traits;

    // Extract pipeline (the constructor default is kept when none matches)
    if(contains("compv3"))
    {
        traits.pipeline = "compv3";
    }
    else if(contains("compv4"))
    {
        traits.pipeline = "compv4";
    }
    else if(contains("mem"))
    {
        traits.pipeline = "mem";
    }

    // Extract scheduler
    traits.scheduler = contains("interwave") ? "interwave" : "intrawave";

    // Extract epilogue
    traits.epilogue = has_token("default") ? "default" : "cshuffle";

    // Padding flags would need to be extracted from the kernel configuration
    // For now, we'll leave them as false
    return traits;
}

View File

@@ -0,0 +1,905 @@
#!/usr/bin/env python
import os
import json
import argparse
import itertools
import multiprocessing
import concurrent.futures
from pathlib import Path
import logging
from typing import Optional
from gemm_streamk_validation_utils import (
is_tile_config_valid,
is_trait_combination_valid,
)
logging.basicConfig(level=logging.INFO)
class GemmKernelBuilder:
def __init__(self, working_path, datatype, layout, config_json=None):
self.working_path = Path(working_path)
self.datatype = datatype
self.layout = layout
self.config_json = config_json
# Create working directory if it doesn't exist
self.working_path.mkdir(parents=True, exist_ok=True)
# Load configuration
if config_json and os.path.exists(config_json):
with open(config_json, "r") as f:
self.config = json.load(f)
else:
self.config = self._get_default_config()
def _get_default_config(self):
"""Return default configuration if no config file is provided"""
# Define base tile configurations that work for all layouts
base_fp16_configs = [
{
"tile_m": 256,
"tile_n": 256,
"tile_k": 32,
"warp_m": 2,
"warp_n": 2,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 32,
},
{
"tile_m": 256,
"tile_n": 128,
"tile_k": 32,
"warp_m": 2,
"warp_n": 2,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 16,
},
]
base_fp8_configs = [
{
"tile_m": 256,
"tile_n": 256,
"tile_k": 32,
"warp_m": 4,
"warp_n": 1,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 32,
},
{
"tile_m": 256,
"tile_n": 128,
"tile_k": 32,
"warp_m": 1,
"warp_n": 4,
"warp_k": 1,
"warp_tile_m": 16,
"warp_tile_n": 16,
"warp_tile_k": 32,
},
]
# Create configurations for all supported layouts
all_layouts = ["rcr", "rrr", "ccr", "crr"]
tile_configs = {}
for datatype, base_configs in [
("fp16", base_fp16_configs),
("fp8", base_fp8_configs),
]:
tile_configs[datatype] = {}
for layout in all_layouts:
tile_configs[datatype][layout] = base_configs
return {
"tile_configs": tile_configs,
"traits": {
"pipelines": ["mem", "compv3", "compv4"],
"epilogues": ["default", "cshuffle"],
"schedulers": ["intrawave", "interwave"],
},
"structured_sparsity": ["false"],
"padding": {"pad_m": ["false"], "pad_n": ["false"], "pad_k": ["false"]},
"persistent": ["false"],
"reduction_strategy": ["reduction"],
}
def _get_tile_configs(self, fast_mode=False):
"""Get tile configurations for the current datatype and layout"""
if "tile_configs" in self.config:
# Old format
return (
self.config["tile_configs"].get(self.datatype, {}).get(self.layout, [])
)
elif "tile_config" in self.config:
# New format - generate combinations from individual parameter values
tile_config = self.config["tile_config"]
# Get all possible values for each parameter
tile_m_values = tile_config.get("tile_m", {}).get("values", [256])
tile_n_values = tile_config.get("tile_n", {}).get("values", [256])
tile_k_values = tile_config.get("tile_k", {}).get("values", [32])
warp_m_values = tile_config.get("warp_m", {}).get("values", [2])
warp_n_values = tile_config.get("warp_n", {}).get("values", [2])
warp_k_values = tile_config.get("warp_k", {}).get("values", [1])
warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32])
warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32])
warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32])
# Generate all combinations
configs = []
for tile_m in tile_m_values:
for tile_n in tile_n_values:
for tile_k in tile_k_values:
for warp_m in warp_m_values:
for warp_n in warp_n_values:
for warp_k in warp_k_values:
for warp_tile_m in warp_tile_m_values:
for warp_tile_n in warp_tile_n_values:
for warp_tile_k in warp_tile_k_values:
# Validate configuration
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
fast_mode=fast_mode,
):
configs.append(
{
"tile_m": tile_m,
"tile_n": tile_n,
"tile_k": tile_k,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_k": warp_k,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"warp_tile_k": warp_tile_k,
}
)
return configs
else:
# Fallback to default
return []
def _validate_tile_config(
self,
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
pipeline="mem", # Default pipeline for validation
fast_mode=False, # Add fast mode option
):
"""Validate that tile configuration is reasonable"""
if fast_mode:
# Fast validation for listing - only basic sanity checks
if tile_m <= 0 or tile_n <= 0 or tile_k <= 0:
return False
if warp_m <= 0 or warp_n <= 0 or warp_k <= 0:
return False
if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0:
return False
# Basic divisibility check
if tile_m % (warp_m * warp_tile_m) != 0:
return False
if tile_n % (warp_n * warp_tile_n) != 0:
return False
if tile_k % (warp_k * warp_tile_k) != 0:
return False
return True
else:
# Full validation for generation
# Determine data types for validation
a_datatype = self.datatype
b_datatype = self.datatype
c_datatype = self.datatype
# Special handling for certain data types
if self.datatype in ["fp8", "bf8"]:
c_datatype = "fp16"
# Use the comprehensive validation function
return is_tile_config_valid(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
a_datatype,
b_datatype,
c_datatype,
pipeline,
)
def _generate_trait_combinations(self):
"""Generate all combinations of traits"""
if "trait_config" in self.config:
# New format
trait_config = self.config["trait_config"]
pipelines = trait_config.get("pipeline", {}).get("values", ["mem"])
epilogues = trait_config.get("epilogue", {}).get("values", ["default"])
schedulers = trait_config.get("scheduler", {}).get("values", ["intrawave"])
pad_m_values = trait_config.get("pad_m", {}).get("values", [False])
pad_n_values = trait_config.get("pad_n", {}).get("values", [False])
pad_k_values = trait_config.get("pad_k", {}).get("values", [False])
persistent_values = trait_config.get("persistent", {}).get(
"values", [False]
)
reduction_strategy_value = trait_config.get("reduction_strategy", {}).get(
"values", ["reduction"]
)
all_combinations = list(
itertools.product(
pipelines,
epilogues,
schedulers,
reduction_strategy_value,
pad_m_values,
pad_n_values,
pad_k_values,
persistent_values,
)
)
# Filter out unsupported trait combinations
combinations = []
for combo in all_combinations:
pipeline, epilogue, scheduler, reduction_strategy = combo[:4]
if is_trait_combination_valid(
pipeline, epilogue, scheduler, reduction_strategy
):
combinations.append(combo)
else:
logging.debug(
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}-{reduction_strategy}"
)
else:
# Fallback to minimal default
combinations = [
(
"compv3",
"cshuffle",
"intrawave",
"reduction_strategy",
False,
False,
False,
False,
)
]
return combinations
def _get_dtype_string(self):
"""Get C++ type string for datatype"""
dtype_map = {
"fp16": "ck_tile::fp16_t",
"fp8": "ck_tile::fp8_t",
"bf16": "ck_tile::bf16_t",
"fp32": "float",
"fp64": "double",
}
return dtype_map.get(self.datatype, "float")
_LAYOUT_MAP = {
"r": "ck_tile::tensor_layout::gemm::RowMajor",
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
}
def _get_abc_layouts(self, layout_code: Optional[str] = None):
"""
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
If layout_code is None, use self.layout.
"""
if layout_code is None:
# fall back to the instance field
layout_code = getattr(self, "layout", "")
code = str(layout_code).strip().lower()
if len(code) != 3 or any(ch not in self._LAYOUT_MAP for ch in code):
raise ValueError(
f"Invalid layout '{layout_code}'. "
"Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)."
)
a_layout = self._LAYOUT_MAP[code[0]]
b_layout = self._LAYOUT_MAP[code[1]]
c_layout = self._LAYOUT_MAP[code[2]]
return a_layout, b_layout, c_layout
def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True):
"""Generate a single kernel instance"""
(
pipeline,
epilogue,
scheduler,
reduction_strategy,
pad_m,
pad_n,
pad_k,
persistent,
) = trait_combo
# Create kernel name with proper boolean capitalization
kernel_name = f"{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{reduction_strategy}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
# Create tile configuration string
tile_str = (
f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
)
tile_str += (
f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
)
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
kernel_name += f"_{tile_str}"
# Map pipeline names to the correct pipeline implementation
pipeline_impl_map = {
"mem": "ck_tile::GemmPipelineAgBgCrMem",
"compv3": "ck_tile::GemmPipelineAgBgCrCompV3",
"compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
}
reduction_strategy_map = {
"atomic": "ck_tile::StreamKReductionStrategy::Atomic",
"reduction": "ck_tile::StreamKReductionStrategy::Reduction",
}
# Determine accumulator type based on datatype
acc_type = "float"
if self.datatype in ["int8", "int4"]:
acc_type = "ck_tile::int32_t"
# Determine output type
c_type = self._get_dtype_string()
if self.datatype in ["fp8", "bf8"]:
c_type = "ck_tile::fp16_t"
# Determine layouts based on self.layout
a_layout, b_layout, c_layout = self._get_abc_layouts()
# Generate kernel instance code using the correct API
pragma_line = "#pragma once\n" if is_header else ""
instance_code = f"""// Generated kernel instance for {kernel_name}
{pragma_line}
#include <cstdint>
#include <utility>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
using ADataType = {self._get_dtype_string()};
using BDataType = {self._get_dtype_string()};
using AccDataType = {acc_type};
using CDataType = {c_type};
using ALayout = {a_layout};
using BLayout = {b_layout};
using CLayout = {c_layout};
// Kernel name for display
constexpr const char* KERNEL_NAME = "{kernel_name}";
// Wrapper for simplified launch interface
struct SelectedKernel {{
// Tile configuration
static constexpr ck_tile::index_t BlockSize = 256;
static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};
// Traits
static constexpr bool kPadM = {"true" if pad_m == "true" else "false"};
static constexpr bool kPadN = {"true" if pad_n == "true" else "false"};
static constexpr bool kPadK = {"true" if pad_k == "true" else "false"};
static constexpr bool Preshuffle = false;
static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"};
static constexpr int kBlockPerCu = 1;
static constexpr bool StructuredSparsity = false;
static constexpr bool NumWaveGroup = 1;
static constexpr bool TransposeC = false;
static constexpr bool UsePersistentKernel = {"true" if str(persistent).lower() == "true" else "false"};
static constexpr bool UseStructuredSparsity = false;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr ck_tile::StreamKReductionStrategy reduction_strategy = {reduction_strategy_map.get(reduction_strategy, "ck_tile::StreamKReductionStrategy::Reduction")};
// Tile shape
using TileShape = ck_tile::TileGemmShape<
ck_tile::sequence<TileM, TileN, TileK>,
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;
// Tile partitioner
using TilePartitioner = ck_tile::StreamKTilePartitioner<TileShape, reduction_strategy, UsePersistentKernel>;
// Traits
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC,
StructuredSparsity,
UsePersistentKernel,
NumWaveGroup,
Preshuffle>;
// Pipeline problem
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
ADataType,
BDataType,
AccDataType,
TileShape,
GemmUniversalTraits>;
static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{
const auto Run = [&](const auto memory_operation_) {{
constexpr auto memory_operation = memory_operation_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
TileShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::GemmPipelineAgBgCrCompV3")}<UniversalGemmProblem>;
// Epilogue
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock, // kM_
TilePartitioner::NPerBlock, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
memory_operation, // MemoryOperation_
NumWaveGroups>; // kNumWaveGroups_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
// Kernel type
using GemmKernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Make kernel arguments
auto kargs = GemmKernel::MakeKernelArgs(args);
const auto workspace_size = GemmKernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
// Get grid and block sizes
const dim3 grids = GemmKernel::GridSize(kargs.tile_partitioner);
const dim3 blocks = GemmKernel::BlockSize();
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << "\\n"
<< "shape: " << TileShape::GetName() << "\\n"
<< "problem: " << UniversalGemmProblem::GetName() << "\\n"
<< "pipeline: " << GemmPipeline::GetName() << "\\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}}
auto reset_data_buffers = [&]() {{
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
{{
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
}}
else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
{{
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
}}
}};
// Launch kernel
float ave_time = ck_tile::launch_kernel_time_mask(
stream,
reset_data_buffers,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return ave_time;
// ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
// return std::make_tuple(ave_time, num_wgs_per_tile);
}};
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == reduction_strategy)
{{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
else // We are using ck_tile::StreamKReductionStrategy::Reduction
{{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}}
}}
}};
"""
return kernel_name, instance_code
def generate_individual(self, num_workers=None):
    """Generate one kernel header per (tile config, trait combination) pair.

    The pairs are handed to a process pool that runs
    ``_generate_single_kernel_individual``; successful results are sorted by
    kernel name and passed to ``_generate_cmake_individual_targets``.

    Args:
        num_workers: Size of the process pool.  Defaults to the CPU count,
            capped at 8.
    """
    if num_workers is None:
        # Limit to avoid memory issues
        num_workers = min(multiprocessing.cpu_count(), 8)

    tile_configs = self._get_tile_configs()
    trait_combos = self._generate_trait_combinations()

    # Prepare work items for parallel processing.  Each item carries
    # everything the worker needs to rebuild a builder in its own process.
    work_items = [
        (tile, traits, self.working_path, self.datatype, self.layout)
        for tile in tile_configs
        for traits in trait_combos
    ]
    total = len(work_items)

    print(
        f"Generating {total} individual kernel files using {num_workers} workers..."
    )
    print(f" Tile configs: {len(tile_configs)}")
    print(f" Trait combinations: {len(trait_combos)}")
    print(f" Total kernels: {total}")

    # Show first few work items for debugging
    if work_items:
        sample_tile, sample_traits = work_items[0][:2]
        print(" First work item example:")
        print(f" Tile config: {sample_tile}")
        print(f" Trait combo: {sample_traits[:3]}")  # Show first 3 traits

    # Process work items in parallel
    generated = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as pool:
        # Submit all work items
        print(f" Submitting {total} tasks to executor...")
        pending = {}
        for item in work_items:
            pending[pool.submit(_generate_single_kernel_individual, item)] = item
        print(" All tasks submitted, waiting for completion...")

        # Collect results with progress reporting
        finished_futures = concurrent.futures.as_completed(pending)
        for done_count, future in enumerate(finished_futures, start=1):
            if done_count % 100 == 0 or done_count == total:
                print(f" Progress: {done_count}/{total} kernels generated")
            try:
                outcome = future.result()
            except Exception as exc:
                print(f"Kernel generation failed for {pending[future]}: {exc}")
            else:
                # The worker returns None when it failed internally.
                if outcome:
                    generated.append(outcome)

    # Sort kernel list for consistent ordering (by kernel name)
    generated.sort(key=lambda entry: entry[0])

    # Generate CMake include file for individual targets
    self._generate_cmake_individual_targets(generated)
    print(
        f"Generated {len(generated)} individual kernel files in {self.working_path}"
    )
def _generate_cmake_individual_targets(self, kernel_list):
"""Generate CMake include file that creates individual targets"""
cmake_code = f"""# Generated CMake file for individual GEMM targets
# Datatype: {self.datatype}, Layout: {self.layout}
"""
for kernel_name, trait_combo, tile_config in kernel_list:
pipeline, epilogue, scheduler = trait_combo[:3]
# Format tile config for CMake function
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join(
str(x) for x in trait_combo[3:]
)
cmake_code += f'create_individual_gemm_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n'
# Write CMake include file
with open(self.working_path / "gemm_individual_targets.cmake", "w") as f:
f.write(cmake_code)
def write_kernel_list(self):
    """Write the kernel list and kernel count files for CMake to read.

    Tile configurations are obtained with full (non-fast) validation.  Two
    files are written into ``self.working_path``:

    * ``gemm_kernel_count.txt`` - the number of kernels.
    * ``gemm_kernel_list.txt`` - one ``kernel_name|tile_config|trait_combo``
      line per kernel.

    Trait tuples are ordered ``(pipeline, epilogue, scheduler,
    reduction_strategy, pad_m, pad_n, pad_k, persistent)``, as produced by
    ``_generate_trait_combinations``; the kernel name follows that order so
    that it matches the name used for the generated kernel instance (with a
    ``gemm_`` prefix).
    """

    def _tile_label(tile_config):
        """Format a tile config as '<tile>_<warps>_<warp tile>'."""
        return (
            f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
            f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
            f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
        )

    # Get configurations using comprehensive validation
    tile_configs = self._get_tile_configs(fast_mode=False)
    trait_combos = self._generate_trait_combinations()

    kernel_list = []
    for tile_config in tile_configs:
        for trait_combo in trait_combos:
            # Unpack in the order the tuples are actually built;
            # reduction_strategy is the fourth element, not the last.
            (
                pipeline,
                epilogue,
                scheduler,
                reduction_strategy,
                pad_m,
                pad_n,
                pad_k,
                persistent,
            ) = trait_combo

            # Create kernel name with proper boolean capitalization
            kernel_name = (
                f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{reduction_strategy}"
                f"_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}"
                f"_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
                f"_{_tile_label(tile_config)}"
            )
            kernel_list.append(
                {
                    "name": kernel_name,
                    "tile_config": tile_config,
                    "trait_combo": trait_combo,
                }
            )

    # Write kernel count
    with open(self.working_path / "gemm_kernel_count.txt", "w") as f:
        f.write(str(len(kernel_list)))

    # Write kernel list
    with open(self.working_path / "gemm_kernel_list.txt", "w") as f:
        for kernel in kernel_list:
            # Format: kernel_name|tile_config|trait_combo
            trait_combo = kernel["trait_combo"]
            tile_str = _tile_label(kernel["tile_config"])
            trait_str = (
                f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_"
                + "_".join(str(x) for x in trait_combo[3:])
            )
            f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")

    print(f"Listed {len(kernel_list)} kernel configurations")
def run(self, num_workers=None):
    """Run the builder to generate individual kernel files.

    Args:
        num_workers: Passed through unchanged to ``generate_individual``.
    """
    # Generate individual kernel files
    self.generate_individual(num_workers)
def _generate_single_kernel_individual(work_item):
    """Worker entry point: generate and write one kernel header.

    Args:
        work_item: ``(tile_config, trait_combo, working_path, datatype,
            layout)`` as prepared by ``generate_individual``.

    Returns:
        ``(kernel_name, trait_combo, tile_config)`` on success, ``None`` if
        anything raised while generating or writing (the error is printed).
    """
    tile_config, trait_combo, working_path, datatype, layout = work_item
    # Each worker builds its own temporary builder instance.
    worker_builder = GemmKernelBuilder(working_path, datatype, layout)
    name_prefix = "gemm_"
    try:
        kernel_name, source_text = worker_builder._generate_kernel_instance(
            tile_config, trait_combo
        )
        # The filename carries the kernel name without a leading "gemm_".
        if kernel_name.startswith(name_prefix):
            file_stem = kernel_name[len(name_prefix):]
        else:
            file_stem = kernel_name
        # Write individual header file
        target_path = working_path / f"gemm_streamk_single_{file_stem}.hpp"
        with open(target_path, "w") as header:
            header.write(source_text)
    except Exception as e:
        print(f"Error generating individual kernel: {e}")
        return None
    return (kernel_name, trait_combo, tile_config)
def main():
    """Command-line entry point for the StreamK GEMM instance builder.

    Exactly one mode is acted on, checked in this order:

    * ``--list_kernels``: write the kernel list/count files only.
    * ``--gen_single``: generate one header from ``--tile_config``
      (``MxNxK_MxNxK_MxNxK``: block tile, warps per block, warp tile) and
      ``--trait_combo`` (``pipeline_epilogue_scheduler_reduction_padM_padN_
      padK_persistent``).  ``--kernel_name`` must be supplied as well.
    * ``--gen_individual``: generate every kernel header in parallel.

    Invalid or missing arguments are reported through ``parser.error``,
    which exits the process.
    """
    parser = argparse.ArgumentParser(
        description="GEMM kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16", "fp8", "bf16", "fp32", "fp64"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcr", "rrr", "ccr", "crr"],
        help="Matrix layout",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_individual", action="store_true", help="Generate individual kernel files"
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )
    args = parser.parse_args()

    # Create builder
    builder = GemmKernelBuilder(
        args.working_path, args.datatype, args.layout, args.config_json
    )

    if args.list_kernels:
        # Listing mode - just write kernel list without generating files
        builder.write_kernel_list()
    elif args.gen_single:
        # Generate a single kernel file
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )

        # Parse tile config: three '_'-separated groups of three 'x'-separated ints
        tile_groups = [group.split("x") for group in args.tile_config.split("_")]
        if len(tile_groups) != 3 or any(len(group) != 3 for group in tile_groups):
            parser.error(
                "--tile_config must have the form MxNxK_MxNxK_MxNxK "
                "(tile, warps per block, warp tile)"
            )
        tile_dims, warp_dims, warp_tile_dims = tile_groups
        tile_config = {
            "tile_m": int(tile_dims[0]),
            "tile_n": int(tile_dims[1]),
            "tile_k": int(tile_dims[2]),
            "warp_m": int(warp_dims[0]),
            "warp_n": int(warp_dims[1]),
            "warp_k": int(warp_dims[2]),
            "warp_tile_m": int(warp_tile_dims[0]),
            "warp_tile_n": int(warp_tile_dims[1]),
            "warp_tile_k": int(warp_tile_dims[2]),
        }

        # Parse trait combo
        trait_parts = args.trait_combo.split("_")
        if len(trait_parts) < 8:
            parser.error(
                "--trait_combo must have the form "
                "pipeline_epilogue_scheduler_reduction_padM_padN_padK_persistent"
            )

        def _parse_flag(text):
            """A padding flag is enabled only when spelled 'true' (any case)."""
            return text.strip().lower() == "true"

        trait_combo = (
            trait_parts[0],  # pipeline
            trait_parts[1],  # epilogue
            trait_parts[2],  # scheduler
            trait_parts[3],  # reduction_strategy
            _parse_flag(trait_parts[4]),  # pad_m
            _parse_flag(trait_parts[5]),  # pad_n
            _parse_flag(trait_parts[6]),  # pad_k
            trait_parts[7],  # persistent (kept as text; interpreted downstream)
        )

        # Generate the kernel
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo
        )

        # Write the file (filename drops a leading "gemm_" from the kernel name)
        simplified_name = kernel_name
        if simplified_name.startswith("gemm_"):
            simplified_name = simplified_name[5:]
        header_file = (
            builder.working_path / f"gemm_streamk_single_{simplified_name}.hpp"
        )
        with open(header_file, "w") as f:
            f.write(instance_code)
        print(f"Generated {header_file}")
    elif args.gen_individual:
        # Generate all individual kernel files
        builder.run(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_individual, or --gen_single"
        )


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,296 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <fstream>
#include <iomanip>
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "gemm_streamk_benchmark.hpp"
class GemmProfiler
{
public:
static GemmProfiler& instance(Setting setting)
{
static GemmProfiler instance{setting};
return instance;
}
// Overload for single kernel benchmarking
void benchmark(GemmProblem& gemm_problem,
std::function<float(const ck_tile::StreamKHostArgs&,
const ck_tile::stream_config&)> kernel_func)
{
// Create a vector with a single callable that returns both name and time
std::vector<std::function<std::tuple<std::string, float>(ck_tile::StreamKHostArgs&,
const ck_tile::stream_config&)>>
callables;
callables.push_back(
[kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {
float time = kernel_func(args, stream);
return std::make_tuple(std::string(KERNEL_NAME), time);
});
benchmark(gemm_problem, callables);
}
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
const CLayout layout_c = CLayout{};
gemm_problem.stride_a_ = ck_tile::get_default_stride(
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a));
gemm_problem.stride_b_ = ck_tile::get_default_stride(
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b));
gemm_problem.stride_c_ = ck_tile::get_default_stride(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c));
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)));
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
if(setting_.init_method_ == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
else if(setting_.init_method_ == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
}
else if(setting_.init_method_ == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
}
if(gemm_problem.structured_sparsity_)
{
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
// permute_tensor_b<decltype(b_k_n_dev)>(b_k_n_dev);
permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
a_m_k_dev_buf.ToDevice(a_m_k.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::StreamKHostArgs gemm_args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
gemm_problem.m_,
gemm_problem.n_,
gemm_problem.k_,
gemm_problem.stride_a_,
gemm_problem.stride_b_,
gemm_problem.stride_c_};
ck_tile::HostTensor<CDataType> c_m_n_host_result(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
if(setting_.verify_)
{
gemm_host_reference(setting_.verify_,
a_m_k,
b_k_n,
c_m_n_host_result,
a_m_k_dev_buf,
b_k_n_dev_buf,
gemm_problem.m_,
gemm_problem.n_,
gemm_problem.k_,
gemm_problem.stride_a_,
gemm_problem.stride_b_,
gemm_problem.stride_c_);
}
for(auto& callable : callables)
{
auto kernel_run_result = callable(gemm_args,
ck_tile::stream_config{nullptr,
true,
setting_.log_,
setting_.n_warmup_,
setting_.n_repeat_,
setting_.is_gpu_timer_,
setting_.flush_cache_,
setting_.rotating_count_});
process_result(gemm_problem,
c_m_n_dev_buf,
c_m_n_host_result,
c_m_n_dev_result,
kernel_run_result);
}
}
void process_result(const GemmProblem& gemm_problem,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const std::tuple<std::string, float>& kernel_run_result)
{
auto [name, avg_time] = kernel_run_result;
auto dp_persistent =
SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel";
auto reduction_strategy =
SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic
? "Atomic"
: "Reduction";
KernelInstance kernel_instance{
name, dp_persistent, reduction_strategy, gemm_problem, {-1.0f, -1.0f, -1.0f}};
// compute performance metric
std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ +
sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ +
sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_;
// update
kernel_instance.perf_result_.latency_ = avg_time;
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
if(setting_.log_ > 0 && !setting_.json_output_)
{
std::cout << kernel_instance << std::endl;
}
// verify result
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool verified_correct =
!setting_.verify_ ||
compare(
name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result);
if(verified_correct)
{
kernel_instances_.emplace_back(kernel_instance);
}
else
{
std::cout << "Verification failed, skip kernel: " << name << std::endl;
}
// clear tensor
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
}
KernelInstance select_best_instance(Metric metric)
{
if(kernel_instances_.empty())
throw std::runtime_error("Empty instances");
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
kernel_instances_.end(),
[metric](const auto& a, const auto& b) {
return PerformanceResult::compare(
b.perf_result_, a.perf_result_, metric);
});
if(setting_.json_output_)
{
// Output clean JSON only
std::cout << kernel_instance << std::endl;
}
else
{
std::cout << "**********************************" << std::endl;
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
<< "Current kernel performance is: " << kernel_instance << std::endl;
std::cout << "**********************************" << std::endl;
}
if(!setting_.csv_filename_.empty())
{
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
if(!file.is_open())
{
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
}
else
{
if(file.tellp() == 0)
{
file << "rocm_version,device_name,"
<< "split_k,m,n,k,stride_a,stride_b,stride_c,"
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
<< "structured_sparsity," << "dp_persistent," << "reduction_strategy,"
<< "name," << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
}
const auto& problem = kernel_instance.problem_;
const auto& name = kernel_instance.name_;
const auto& dp_persistent = kernel_instance.dp_persistent_;
const auto& reduction_strategy = kernel_instance.reduction_strategy_;
const auto& perf = kernel_instance.perf_result_;
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
<< problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
<< "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
<< problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
<< "," << problem.structured_sparsity_ << "," << dp_persistent << ","
<< reduction_strategy << "," << name << "," << std::fixed
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
<< "\n";
if(!file)
{
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
}
}
}
return kernel_instance;
}
GemmProfiler(const GemmProfiler&) = delete;
GemmProfiler& operator=(const GemmProfiler&) = delete;
private:
~GemmProfiler() { kernel_instances_.clear(); }
GemmProfiler(Setting setting) : setting_(setting) {}
Setting setting_;
std::vector<KernelInstance> kernel_instances_;
};

View File

@@ -0,0 +1,350 @@
#!/usr/bin/env python
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
"""
Validation utilities for GEMM kernel generation.
Extracted from tile_engine_develop for consistency.
"""
import subprocess
import re
from functools import lru_cache
import logging
from typing import Tuple, List
# Size in bytes of a single element, keyed by data type name
# (int4 is listed as half a byte).
ELEMENT_SIZE_MAP = dict(
    fp16=2,
    bf16=2,
    int8=1,
    fp8=1,
    bf8=1,
    int4=0.5,
    int32=4,
    fp32=4,
    fp64=8,
)


def _half_precision_warp_tiles():
    """Return a fresh list of the warp tiles shared by fp16 and bf16 on every listed GPU."""
    return [
        [32, 32, 8],
        [16, 16, 16],
        [32, 32, 16],
        [16, 16, 32],
        [4, 64, 16],
        [64, 4, 16],
    ]


# Supported [warp_tile_m, warp_tile_n, warp_tile_k] combinations, keyed first by
# GPU architecture and then by "<a>_<b>_<c>" data type triple.
WARP_TILE_SUPPORTED_COMBINATIONS = {
    "gfx90a": {
        "fp16_fp16_fp16": _half_precision_warp_tiles(),
        "bf16_bf16_bf16": _half_precision_warp_tiles(),
        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
        "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
    },
    "gfx942": {
        "fp16_fp16_fp16": _half_precision_warp_tiles(),
        "bf16_bf16_bf16": _half_precision_warp_tiles(),
        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
        "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
        "int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
    },
    "gfx950": {
        "fp16_fp16_fp16": _half_precision_warp_tiles(),
        "bf16_bf16_bf16": _half_precision_warp_tiles(),
        "fp8_fp8_fp16": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 32],
            [16, 16, 64],
            [16, 16, 128],
            [32, 32, 64],
        ],
        "bf8_bf8_fp16": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 64],
            [16, 16, 32],
            [16, 16, 128],
            [32, 32, 64],
        ],
    },
}

# Unsupported trait combinations, as (pipeline, epilogue, scheduler,
# reduction_strategy): compv3 with the interwave scheduler, for either
# epilogue and either reduction strategy.
TRAIT_UNSUPPORTED_COMBINATIONS = {
    ("compv3", epilogue, "interwave", strategy)
    for strategy in ("reduction", "atomic")
    for epilogue in ("cshuffle", "default")
}
def element_size(data_type: str) -> float:
    """Return the size in bytes of one element of ``data_type``.

    The lookup is case-insensitive. The result can be fractional for
    sub-byte types.

    Raises:
        ValueError: if the (lowercased) data type is not in ELEMENT_SIZE_MAP.
    """
    data_type = data_type.lower()
    size = ELEMENT_SIZE_MAP.get(data_type)
    if size is None:
        raise ValueError(f"Unsupported data type: {data_type}")
    return size
# Matches rocminfo "Name:" lines whose value starts with "gfx" (e.g. gfx90a).
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")


@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
    """Retrieve GPU name (e.g. gfx90a) by device ID.

    Runs ``rocminfo`` (5 second timeout) and collects every name matched by
    GPU_NAME_PATTERN, in the order they appear in the output.

    Args:
        gpu_id: Zero-based index into the list of matched GPU names.

    Returns:
        The GPU name at ``gpu_id``, or an empty string when the index is out
        of range (including negative values) or when detection fails for any
        reason. Failures are logged at debug level and never raised.
    """
    try:
        output = subprocess.check_output(
            ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
        )
        # findall with a single capture group yields the captured names directly.
        gpu_list = GPU_NAME_PATTERN.findall(output)
        # Require a non-negative index: a negative id must not silently wrap
        # around and select a device from the end of the list.
        if 0 <= gpu_id < len(gpu_list):
            return gpu_list[gpu_id]
        return ""
    except subprocess.CalledProcessError as e:
        # stderr is normally captured text; guard against None so logging the
        # failure can never raise from inside the handler.
        logging.debug(
            f"GPU query failed (exit {e.returncode}): {(e.stderr or '').strip()}"
        )
    except FileNotFoundError:
        logging.debug("ROCm tools not installed (requires rocminfo)")
    except subprocess.TimeoutExpired:
        logging.debug("GPU query timeout (5s)")
    except Exception as e:
        # Best-effort detection: any other problem degrades to "unknown GPU".
        logging.debug(f"GPU detection error: {str(e)}")
    return ""
def is_trait_combination_valid(
    pipeline: str, epilogue: str, scheduler: str, reduction_strategy: str
) -> bool:
    """Return True unless the four traits form a known-unsupported combination.

    The traits are looked up, in this order, as a single tuple in
    TRAIT_UNSUPPORTED_COMBINATIONS.
    """
    combination = (pipeline, epilogue, scheduler, reduction_strategy)
    is_unsupported = combination in TRAIT_UNSUPPORTED_COMBINATIONS
    return not is_unsupported
def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool:
    """Return True if (warp_m, warp_n, warp_k) is one of the accepted layouts.

    Only three layouts are accepted: 1x4x1, 2x2x1 and 4x1x1.
    """
    accepted_layouts = ((1, 4, 1), (2, 2, 1), (4, 1, 1))
    requested_layout = (warp_m, warp_n, warp_k)
    return requested_layout in accepted_layouts
def validate_dimension_alignment(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
) -> Tuple[bool, List[str]]:
    """Check that each tile dimension is a multiple of warps x warp-tile size.

    For every axis (m, n, k) the tile size must be divisible by the product
    of the warp count and the warp tile size along that axis.

    Returns:
        A ``(ok, issues)`` pair. ``issues`` holds one human-readable message
        per misaligned axis, in m, n, k order, and is empty when ``ok`` is True.
    """
    axes = (
        ("tile_m", tile_m, warp_m, warp_tile_m),
        ("tile_n", tile_n, warp_n, warp_tile_n),
        ("tile_k", tile_k, warp_k, warp_tile_k),
    )
    problems = []
    for label, tile, warps, warp_tile in axes:
        remainder = tile % (warps * warp_tile)
        if remainder != 0:
            problems.append(f"{label}({tile}) % [{warps}x{warp_tile}] = {remainder}")
    return not problems, problems
def validate_lds_capacity(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    a_datatype: str,
    b_datatype: str,
    pipeline: str,
) -> Tuple[bool, str]:
    """Check that the A and B tiles together fit in the LDS budget.

    The A tile is ``tile_m x tile_k`` elements of ``a_datatype`` and the B
    tile is ``tile_n x tile_k`` elements of ``b_datatype``. The budget is
    2**15 bytes for the "compv4" pipeline and 2**16 bytes otherwise.

    Returns:
        ``(True, "")`` when the tiles fit, else ``(False, message)`` where the
        message gives the totals and a per-matrix breakdown.
    """
    # NOTE(review): compv4 gets half the budget, presumably because it keeps
    # two buffers resident - confirm against the pipeline implementation.
    max_tile_size = 2**15 if pipeline == "compv4" else 2**16

    matrix_a_size = element_size(a_datatype) * (tile_m * tile_k)
    matrix_b_size = element_size(b_datatype) * (tile_n * tile_k)
    total_tile_in_lds = matrix_a_size + matrix_b_size

    if total_tile_in_lds <= max_tile_size:
        return True, ""

    return False, (
        f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
        f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
        f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
        f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
    )
def validate_warp_tile_combination(
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    a_datatype: str,
    b_datatype: str,
    c_datatype: str,
    gpu_name: str = None,
) -> Tuple[bool, str]:
    """Check a warp tile shape against the per-GPU table of supported shapes.

    Args:
        warp_tile_m, warp_tile_n, warp_tile_k: Warp tile shape to check.
        a_datatype, b_datatype, c_datatype: Joined with underscores to form
            the data type key used for the table lookup.
        gpu_name: Target GPU name. When None, GPU 0 is auto-detected.

    Returns:
        ``(True, "")`` if the shape is listed, or if the GPU or the data type
        key has no entries in the table (the check is permissive for unknown
        targets). ``(False, message)`` if entries exist and the shape is not
        among them.
    """
    gpu_name = get_gpu_name_by_id(0) if gpu_name is None else gpu_name

    # Lookup key and the candidate triple (a list, to match the table entries).
    warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
    current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]

    # Unknown GPU: accept, but make it visible in the logs.
    per_gpu_table = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {})
    if not per_gpu_table:
        logging.warning(f"No warp tile combinations found for GPU: {gpu_name}")
        return True, ""

    # Known GPU but unlisted data type combination: accept quietly.
    allowed_combinations = per_gpu_table.get(warp_tile_key, [])
    if not allowed_combinations:
        logging.debug(
            f"No warp tile combinations found for data types: {warp_tile_key}"
        )
        return True, ""

    if current_combination in allowed_combinations:
        return True, ""

    return False, (
        f"Invalid warp tile combination: {current_combination} not in allowed list. "
        f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}"
    )
def is_tile_config_valid(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    a_datatype: str,
    b_datatype: str,
    c_datatype: str,
    pipeline: str,
    trait_name: str = None,
) -> bool:
    """
    Run every tile configuration check and report whether all of them pass.

    Checks run in this order and stop at the first failure:
      1. all sizes are strictly positive,
      2. warps x warp tile does not exceed the block tile on any axis,
      3. the warp layout is one of the accepted layouts,
      4. block tile dimensions are aligned to warps x warp tile,
      5. the A and B tiles fit in LDS for the given pipeline,
      6. the warp tile shape is supported for the detected GPU.

    ``trait_name`` is accepted but not used by any check.
    Returns True if configuration is valid, False otherwise; the reason for
    a failure in checks 3-6 is logged at debug level.
    """
    # 1. Basic sanity: every dimension must be strictly positive.
    all_dimensions = (
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
    )
    if any(value <= 0 for value in all_dimensions):
        return False

    # 2. The warps' combined footprint must fit inside the block tile.
    if (
        warp_m * warp_tile_m > tile_m
        or warp_n * warp_tile_n > tile_n
        or warp_k * warp_tile_k > tile_k
    ):
        return False

    # 3. Warp layout.
    if not validate_warp_configuration(warp_m, warp_n, warp_k):
        logging.debug(
            f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})"
        )
        return False

    # 4. Dimension alignment.
    is_aligned, alignment_issues = validate_dimension_alignment(
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
    )
    if not is_aligned:
        logging.debug(
            f"Dimension alignment failed: {', '.join(alignment_issues)}. "
            f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
            f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
        )
        return False

    # 5. LDS capacity.
    lds_valid, lds_error = validate_lds_capacity(
        tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline
    )
    if not lds_valid:
        logging.debug(f"LDS validation failed: {lds_error}")
        return False

    # 6. Warp tile shape (GPU is auto-detected by the callee).
    warp_tile_valid, warp_tile_error = validate_warp_tile_combination(
        warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype
    )
    if not warp_tile_valid:
        logging.debug(f"Warp tile validation failed: {warp_tile_error}")
        return False

    return True