GEMM Multi D for CK Tile Engine (#2660)

* Readme for GEMM Multi D

* GEMM Multi D partial Progress

* GEMM Multi D partial Progress!

* CK Tile Engine GEMM Multi D : All Python files generated

* Partial Progress

* Partial Progress

* Partial Progress

* Partial Progress : Incorrect Result

* Partial Progress : Debugging

* Partial Progress : Correct Results

* Partial Progress - Incorrect Results

* Partial Progress - Commenting Passthrough bypass logic

* Changing Passthrough to MultiplyMultiply

* Correct Results!

* Fix and debug the pass through feature

* Sample commit

* Correct Results : MultiplyMultiply

* Code Cleanup

* Removing Failed Instances

* Working code before Unary element support

* Custom Elementwise Function support and working implementation for Mul and Add

* Updating README

* Working for Passthrough

* Review Comments : Minor Fixes

* Review Comments : Minor Fixes

* Readme Updated

* Partial Changes after Rebase

* Working Code : Changes after Rebase

* Updating Jenkins file

* Removing default value changed while testing

* Configuration changes in config files

* Tile Handler changes in GEMM Multi D Tile Engine

* Tile Handler changes in GEMM Multi D Example

* Change log for Gemm Multi D in CK Tile Engine

* Configuration changes in config files

---------

Co-authored-by: ThomasNing <thomasning@amd.com>
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-08-12 18:05:05 -05:00
committed by GitHub
parent 30dafe8281
commit 3f57ec3d2d
18 changed files with 2547 additions and 291 deletions

View File

@@ -0,0 +1,152 @@
set(GEMM_MULTI_D_DATATYPE "fp16" CACHE STRING "List of datatypes for GEMM Multi D (semicolon-separated)")
set(GEMM_MULTI_D_LAYOUT "rcrr" CACHE STRING "List of layout for GEMM Multi D(semicolon-separated)")
set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION "mul" CACHE STRING "Elementwise function")
function(build_gemm_multi_d_for_datatype_layout datatype layout)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Comment this if-else block when using user_provided_config
if(layout STREQUAL "rcrr")
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
else()
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json")
endif()
# uncomment this if you want to use user_provided_config.json
# set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json")
# Generate kernel list
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
--config_json ${json_blob}
--list_blobs
RESULT_VARIABLE ret
)
if(NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}")
endif()
file(STRINGS "${working_path}/gemm_multi_d_instance_blobs.txt" codegen_blobs)
file(STRINGS "${working_path}/gemm_multi_d_instance_blobs_range.txt" codegen_blobs_range)
# Generate the blobs
add_custom_command(
OUTPUT ${codegen_blobs}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
--working_path "${working_path}"
--datatype ${datatype}
--layout ${layout}
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
--config_json "${json_blob}"
--gen_blobs
COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}"
)
add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs})
set(intermediate_libs)
list(LENGTH codegen_blobs codegen_blobs_len)
foreach(blob IN LISTS codegen_blobs_range)
string(STRIP "${blob}" stripped_blob)
separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}")
# Each line is: <trait_name> <first_index_inclusive> <last_index_exclusive>
list(GET spilit_blob 0 name)
list(GET spilit_blob 1 first)
list(GET spilit_blob 2 last)
math(EXPR total_files "${last} - ${first}")
if(total_files EQUAL 0)
continue() # nothing for this trait
endif()
# Object libraries (chunked) per trait
set(sub_intermediate_libs)
set(chunk_size 3)
math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}")
math(EXPR num_chunks_minus_1 "${num_chunks} - 1")
foreach(i RANGE 0 ${num_chunks_minus_1})
math(EXPR start "${first} + ${i} * ${chunk_size} ")
math(EXPR end "${start} + ${chunk_size} - 1")
set(chunk_files)
foreach(j RANGE ${start} ${end})
if(j LESS ${last} AND j LESS ${codegen_blobs_len})
list(GET codegen_blobs ${j} f)
list(APPEND chunk_files "${f}")
endif()
endforeach()
#list(LENGTH chunk_files chunk_files_len)
#if(chunk_files_len AND chunk_files_len GREATER 1)
if(chunk_files)
set(sub_intermediate_lib_name "gemm_multi_d_objlib_${name}_${i}_${datatype}_${layout}")
add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files})
list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name})
endif()
endforeach()
# ------------------ Bundle the object libs into one static lib ---------
#list(LENGTH sub_intermediate_libs sub_intermediate_libs_len)
#if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1)
if(sub_intermediate_libs)
set(intermediate_lib_name "gemm_multi_d_staticlib_${name}_${datatype}_${layout}")
# Collect the $<TARGET_OBJECTS:...> expressions
set(obj_exprs)
foreach(objlib IN LISTS sub_intermediate_libs)
list(APPEND obj_exprs $<TARGET_OBJECTS:${objlib}>)
endforeach()
add_library(${intermediate_lib_name} STATIC ${obj_exprs})
add_dependencies(${intermediate_lib_name} gemm_multi_d_gen_${datatype}_${layout})
#foreach(objlib IN LISTS sub_intermediate_libs)
# target_sources(${intermediate_lib_name} PRIVATE $<TARGET_OBJECTS:${objlib}>)
#endforeach()
list(APPEND intermediate_libs ${intermediate_lib_name})
endif()
endforeach()
# Interface library for instances
add_library(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE)
add_dependencies(gemm_multi_d_template_instances_${datatype}_${layout} gemm_multi_d_gen_${datatype}_${layout})
target_link_libraries(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs})
target_include_directories(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE
${CMAKE_CURRENT_LIST_DIR}
"${working_path}"
)
set_target_properties(gemm_multi_d_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX)
# Host API interface library
add_library(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE)
target_link_libraries(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE gemm_multi_d_template_instances_${datatype}_${layout})
target_include_directories(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE
${CMAKE_CURRENT_LIST_DIR}
"${working_path}"
)
# Executable per datatype
set(exec_name "benchmark_gemm_multi_d_${datatype}_${layout}")
add_executable(${exec_name} benchmark_gemm_multi_d.cpp)
target_link_libraries(${exec_name} PRIVATE gemm_multi_d_host_api_${datatype}_${layout})
target_compile_options(${exec_name} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress
)
endfunction()
# Process each datatype in isolation
foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE)
foreach(l IN LISTS GEMM_MULTI_D_LAYOUT)
build_gemm_multi_d_for_datatype_layout(${dt} ${l})
endforeach()
endforeach()

View File

@@ -0,0 +1,110 @@
CK Tile Engine for GEMM Multi D is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues while able to give custom datatype and Layout selections
# Kernel Configurations
# User Specific
Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time.
For reference please see `./configs/user_provided_config.json`.
# Default
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json`
If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark.
## Build Instructions
``` bash
# in the root of composable kernel create build directory
mkdir build && cd build
# build composable kernel
# replace [Arch] with the appropriate architecture or leave blank and
# replace [Datatype] in comma separated datatypes string (possible datatypes are [fp16])
# replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr])
# replace "mul" with either of mul,add,passthrough for Elementwise function as Multiply, Add or Passthrough respectively. If this is not specified it is considered as mul by default.
sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_MULTI_D_DATATYPE="[Datatype]" -DGEMM_MULTI_D_LAYOUT="[Layout1;Layout2]" -DGEMM_MULTI_D_ELEMENTWISE_FUNCTION="mul"
# generate different executable for each passed datatype
make benchmark_gemm_multi_d_[Datatype]_[Layout1] -j
make benchmark_gemm_multi_d_[Datatype]_[Layout2] -j
```
`benchmark_gemm_multi_d_[Datatype]_[Layout]` will be located in the `./bin/` directory.
`benchmark_gemm_multi_d_[Datatype]_[Layout]` must be rebuilt everytime if configuration file is modified.
``` bash
rm -rf tile_engine/ && make benchmark_gemm_multi_d_[Datatype]_[Layout] -j # rebuild
```
## For eaxmple build for gfx942 for datatype with rcr layout
``` bash
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_MULTI_D_DATATYPE="fp16" -DGEMM_MULTI_D_LAYOUT="rcrr"
make benchmark_gemm_multi_d_fp16_rcrr -j
## benchmark_gemm inputs
```
-m The value for m dimension. Default is 3840.
-n The value for n dimension. Default is 4096.
-k The value for k dimension. Default is 2048.
-stride_a The stride value for tensor A. Default is 0.
-stride_b The stride value for tensor B. Default is 0.
-stride_ds The stride value for tensor Ds. Default is 0.
-stride_e The stride value for tensor E. Default is 0.
-split_k The split value for k dimension. Default is 1.
-verify The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 1, validation on CPU, as validation on GPU is not supported.
-log Wether output kernel instance information or not. Possible values are true or false. Default is false.
-warmup The number of iterations before benchmark the kernel. Default is 50.
-repeat The number of iterations to benchmark the kernel. Default is 100.
-timer Whether if the timer is gpu timer or not. Possible values are false or true. Default is true.
-init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random.
-flush_cache To flush cache, possible values are true or false. Default is false.
-rotating_count Number of iterations to rotate the cache. Default is 5.
-metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency.
-csv_filename The filename of benchmark result. Default is gemm_multi_d_kernel.
-pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.
-scheduler The type of scheduler. Possible values are intrawave. Default is intrawave.
-epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.
-pad_m Whether pad or not in m direction. Possible values are true or false. Default is false.
-pad_n Whether pad or not in n direction. Possible values are true or false. Default is false.
-pad_k Whether pad or not in k direction. Possible values are true or false. Default is false.
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json
```
Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above.
## Example
The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes.
```json
{
/// other parameters ///
"tile_m": {
"values": [256]
},
"tile_n": {
"values": [256]
},
"tile_k": {
"values": [64, 32]
},
/// other parameters ///
"pipeline": {
"values": ["compv3", "compv4", "mem"]
},
"scheduler": {
"values": ["intrawave", "interwave"]
},
"epilogue": {
"values": ["cshuffle"]
}
}
```
At runtime, a specific subset of the generated kernels can be selected using command-line arguments.
``` bash
./bin/benchmark_gemm_multi_d_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=cshuffle
```
The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and cshuffle epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings.

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <tuple>
#include <exception>
#include "benchmark_gemm_multi_d.hpp"
#include "gemm_multi_d_profiler.hpp"
void benchmark_gemm_multi_d(const ck_tile::ArgParser& arg_parser)
{
GemmMultiDProblem gemm_multi_d_problem{arg_parser.get_int("split_k"),
arg_parser.get_int("m"),
arg_parser.get_int("n"),
arg_parser.get_int("k"),
arg_parser.get_int("stride_a"),
arg_parser.get_int("stride_b"),
arg_parser.get_int("stride_ds"),
arg_parser.get_int("stride_ds"),
arg_parser.get_int("stride_e"),
DataTypeTraits<ADataType>::name,
DataTypeTraits<BDataType>::name,
DataTypeTraits<D0DataType>::name,
DataTypeTraits<D1DataType>::name,
DataTypeTraits<AccDataType>::name,
DataTypeTraits<EDataType>::name,
ALayout::name,
BLayout::name,
D0Layout::name,
D1Layout::name,
ELayout::name};
Setting setting{arg_parser.get_int("warmup"),
arg_parser.get_int("repeat"),
arg_parser.get_bool("timer"),
arg_parser.get_int("verify"),
arg_parser.get_int("init"),
arg_parser.get_bool("log"),
arg_parser.get_str("csv_filename"),
arg_parser.get_bool("flush_cache"),
arg_parser.get_int("rotating_count")};
auto& profiler = GemmMultiDProfiler::instance(setting);
try
{
auto kernel_func = get_kernel_func_by_trait(arg_parser);
profiler.benchmark(gemm_multi_d_problem, kernel_func);
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
}
catch(const std::exception& e)
{
std::cerr << "Benchmark failed: " << e.what() << std::endl;
}
}
int main(int argc, char* argv[])
{
try
{
auto [result, parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
benchmark_gemm_multi_d(parser);
return 0;
}
catch(const std::exception& e)
{
std::cerr << "Error: " << e.what() << "\n";
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,218 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include <fstream>
#include <stdexcept>
#include "gemm_multi_d_host_api.hpp"
struct GemmMultiDProblem
{
int split_k_;
int m_, n_, k_;
int stride_a_, stride_b_, stride_d0_, stride_d1_, stride_e_;
std::string dtype_a_, dtype_b_, dtype_d0_, dtype_d1_, dtype_acc_, dtype_e_;
std::string layout_a_, layout_b_, layout_d0_, layout_d1_, layout_e_;
friend std::ostream& operator<<(std::ostream& os, const GemmMultiDProblem& problem)
{
os << "{\n"
<< " \"split_k\":" << problem.split_k_ << ",\n"
<< " \"m\":" << problem.m_ << ",\n"
<< " \"n\":" << problem.n_ << ",\n"
<< " \"k\":" << problem.k_ << ",\n"
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
<< " \"stride_d0\":" << problem.stride_d0_ << ",\n"
<< " \"stride_d1\":" << problem.stride_d1_ << ",\n"
<< " \"stride_e\":" << problem.stride_e_ << ",\n"
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
<< " \"dtype_d0\":\"" << problem.dtype_d0_ << "\",\n"
<< " \"dtype_d1\":\"" << problem.dtype_d1_ << "\",\n"
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
<< " \"dtype_e\":\"" << problem.dtype_e_ << "\",\n"
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
<< " \"layout_d0\":\"" << problem.layout_d0_ << "\",\n"
<< " \"layout_d1\":\"" << problem.layout_d1_ << "\",\n"
<< " \"layout_e\":\"" << problem.layout_e_ << "\"\n"
<< "}";
return os;
}
};
struct Setting
{
int n_warmup_;
int n_repeat_;
bool is_gpu_timer_;
int verify_;
int init_method_;
bool log_;
std::string csv_filename_;
bool flush_cache_;
int rotating_count_;
};
// @brief Function to get the kernel output with reference implementation on CPU
void gemm_multi_d_host_reference(int verify,
ck_tile::HostTensor<ADataType>& a_m_k,
ck_tile::HostTensor<BDataType>& b_k_n,
ck_tile::HostTensor<D0DataType>& d0_m_n,
ck_tile::HostTensor<D1DataType>& d1_m_n,
ck_tile::HostTensor<EDataType>& e_m_n_host_result)
{
if(verify > 0)
{
// Currently supporting on CPU verification for Gemm Multi D
// e_m_n_host_result.SetZero();
ck_tile::reference_gemm_multiple_d<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
ElementWiseFn>(
a_m_k, b_k_n, {d0_m_n, d1_m_n}, e_m_n_host_result);
}
}
enum class Metric
{
LATENCY = 0,
TFLOPS = 1,
BANDWIDTH = 2
};
inline constexpr auto get_metric_name(Metric m)
{
switch(m)
{
case Metric::LATENCY: return "latency";
case Metric::TFLOPS: return "tflops";
case Metric::BANDWIDTH: return "bandwidth";
default: throw std::invalid_argument("Unsupported metric type");
}
}
struct PerformanceResult
{
double latency_;
double tflops_;
double bandwidth_;
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
{
switch(m)
{
case Metric::LATENCY: return a.latency_ < b.latency_;
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
default: throw std::invalid_argument("Unsupported metric type");
}
}
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
{
os << "{\n"
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
<< ",\n"
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
<< "}";
return os;
}
};
struct KernelInstance
{
std::string name_;
GemmMultiDProblem problem_;
PerformanceResult perf_result_;
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
{
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
}
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
{
os << "{\n"
<< " \"name\": \"" << "{\n"
<< obj.name_ << "\n}" << "\",\n"
<< " \"problem\": \"" << obj.problem_ << "\",\n"
<< " \"perf_result\": " << obj.perf_result_ << "\n"
<< "}";
return os;
}
};
inline std::string get_rocm_version()
{
std::ifstream version_file("/opt/rocm/.info/version");
if(version_file.is_open())
{
std::string version;
std::getline(version_file, version);
return version;
}
return "Unknown";
}
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeTypeAB =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
using ComputeType =
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
/// @brief Function to compare the results of the device and host computations
bool compare(std::string instanceName,
ck_tile::index_t K,
ck_tile::HostTensor<EDataType>& e_m_n_dev_result,
ck_tile::HostTensor<EDataType>& e_m_n_host_result)
{
const float max_accumulated_value =
*std::max_element(e_m_n_host_result.mData.begin(), e_m_n_host_result.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
bool pass = ck_tile::check_err(e_m_n_dev_result,
e_m_n_host_result,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "For " << instanceName << " Relative error threshold is "
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
return pass;
}

View File

@@ -0,0 +1,80 @@
{
"tile_config": {
"tile_m": {
"values": [
256 ]
},
"tile_n": {
"values": [
128
]
},
"tile_k": {
"values": [
32
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -0,0 +1,84 @@
{
"tile_config": {
"tile_m": {
"values": [
256
]
},
"tile_n": {
"values": [
128
]
},
"tile_k": {
"values": [
32
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -0,0 +1,81 @@
{
"tile_config": {
"tile_m": {
"values": [
256
]
},
"tile_n": {
"values": [
256
]
},
"tile_k": {
"values": [
64
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
32
]
},
"warp_tile_n": {
"values": [
32
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -0,0 +1,229 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
Mappings and utility functions for kernel code generation.
"""
import subprocess
import re
from functools import lru_cache
DATA_TYPE_MAP = {
"fp32": "float",
"fp16": "ck_tile::half_t",
"bf16": "ck_tile::bf16_t",
"int8": "ck_tile::int8_t",
"fp8": "ck_tile::fp8_t",
"bf8": "ck_tile::bf8_t",
"int4": "ck_tile::pk_int4_t",
"int32": "ck_tile::int32_t",
}
LAYOUT_MAP = {
"r": "ck_tile::tensor_layout::gemm::RowMajor",
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
}
# TODO THIS IS NOT SUPPORTED FOR MULTI D AS OF NOW
# DEFAULT_EPILOGUE = """
# using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
# ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
# BDataType,
# AccDataType,
# CDataType,
# CLayout,
# kPadM,
# kPadN,
# WarpTileM,
# WarpTileN,
# WarpTileK,
# UniversalGemmProblem::TransposeC,
# true,
# memory_operation>>;
# """
CSHUFFLE_EPILOGUE = """
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
WarpM,
WarpN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC,
memory_operation>>;
"""
PIPELINE_MAP = {
"mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
"compv3": [
"ck_tile::BaseGemmPipelineAgBgCrCompV3",
"ck_tile::GemmPipelineAgBgCrCompV3",
],
"compv4": [
"ck_tile::BaseGemmPipelineAgBgCrCompV4",
"ck_tile::GemmPipelineAgBgCrCompV4",
],
}
SCHEDULER_MAP = {
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
}
# EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}
EPILOGUE_MAP = {"cshuffle": CSHUFFLE_EPILOGUE}
def BOOL_MAP(b_):
return {True: "true", False: "false"}[bool(b_)]
# Can add some more supported combinations
warp_tile_supported_combinations = {
"gfx90a": {
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
},
"gfx942": {
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
},
"gfx950": {
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 32],
[16, 16, 64],
[16, 16, 128],
[32, 32, 64],
],
"bf8_bf8_fp16": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 64],
[16, 16, 32],
[16, 16, 128],
[32, 32, 64],
],
},
}
# Remove some unsupported combinations
trait_unsupported_combinations = {
("compv3", "cshuffle", "interwave"),
("compv3", "default", "interwave"),
("compv4", "cshuffle", "interwave"),
("compv4", "default", "interwave"),
}
ELEMENT_SIZE_MAP = {
"fp16": 2,
"bf16": 2,
"int8": 1,
"fp8": 1,
"bf8": 1,
"int4": 0.5,
"int32": 4,
}
def element_size(data_type: str) -> float:
"""Calculate the size (in bytes) of a single element for given data type."""
data_type = data_type.lower()
if data_type not in ELEMENT_SIZE_MAP:
raise ValueError(f"Unsupported data type: {data_type}")
return ELEMENT_SIZE_MAP[data_type]
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
try:
output = subprocess.check_output(
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
)
if matches := GPU_NAME_PATTERN.finditer(output):
gpu_list = [m.group(1) for m in matches]
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
return ""
except subprocess.CalledProcessError as e:
print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
except FileNotFoundError:
print("ROCm tools not installed (requires rocminfo)")
except subprocess.TimeoutExpired:
print("GPU query timeout (5s)")
except Exception as e:
print(f"GPU detection error: {str(e)}")
return ""

View File

@@ -0,0 +1,250 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
Handles loading, parsing, and validation of JSON and Argument configuration parameters.
"""
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional, Union, Type
import json
@dataclass
class EnumConfigParam:
"""Represents an enumeration-type configuration parameter"""
values: List[Union[int, str, bool]]
@dataclass
class RangeConfigParam:
"""Represents a numeric range-type configuration parameter"""
min: int
max: int
step: int
exclude: Optional[List[int]]
def generate_candidates(self) -> List[int]:
"""Generates valid candidates after applying range constraints"""
if self.min > self.max:
raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
if self.step <= 0:
raise ValueError(f"Step must be positive, got {self.step}")
candidates = list(range(self.min, self.max + 1, self.step))
if hasattr(self, "exclude") and self.exclude:
if not isinstance(self.exclude, list):
raise TypeError("exclude must be list type")
exclude_set = set(self.exclude)
candidates = [x for x in candidates if x not in exclude_set]
if not candidates:
raise ValueError(
f"No valid candidates for range [{self.min}-{self.max}] "
f"with step {self.step} and excludes {self.exclude}"
)
return candidates
@dataclass
class DataType:
"""Configuration class for data type parameter."""
a_datatype: str
b_datatype: str
e_datatype: str
d0_datatype: str
d1_datatype: str
ds_datatype: List[str]
@dataclass
class Layout:
"""Configuration class for Layout parameter."""
a_layout: str
b_layout: str
e_layout: str
d0_layout: str
d1_layout: str
ds_layout: List[str]
@dataclass
class ArgumentConfig:
"""Configuration class for Argument parameter."""
datatypes: DataType
layouts: Layout
function_name: str
@classmethod
def from_args(
cls: Type["ArgumentConfig"],
datatype: str,
layout: str,
elementwise_function: str,
) -> "ArgumentConfig":
"""configuration loader with validation controls"""
datatypes = DataType(
a_datatype=datatype,
b_datatype=datatype,
e_datatype=datatype,
d0_datatype=datatype,
d1_datatype=datatype,
ds_datatype=[datatype, datatype],
)
layout_parts = layout.lower()
assert len(layout_parts) == 4, (
f"Invalid layout string: {layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
)
assert layout_parts[0] in ("r", "c"), (
f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)"
)
assert layout_parts[1] in ("r", "c"), (
f"Invalid matrix_b layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)"
)
assert layout_parts[2] == "r", (
f"Invalid matrix_e layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
)
assert layout_parts[3] == "r", (
f"Invalid D dimension layout: {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
)
layouts = Layout(
a_layout=layout[0],
b_layout=layout[1],
e_layout=layout[2],
d0_layout=layout[3],
d1_layout=layout[3],
ds_layout=[layout[3], layout[3]],
)
# Elementwise function name validation
valid_functions = ["mul", "add", "passthrough"]
if elementwise_function not in valid_functions:
raise ValueError(
f"Invalid elementwise function: {elementwise_function}. "
f"Valid options are: {', '.join(valid_functions)}"
)
# Set the function name based on the elementwise function
if elementwise_function == "mul":
function_name = "MultiDMultiply"
elif elementwise_function == "add":
function_name = "MultiDAdd"
elif elementwise_function == "passthrough":
function_name = "PassThrough" # TODO Change this
return cls(datatypes=datatypes, layouts=layouts, function_name=function_name)
@dataclass
class TileConfig:
"""Configuration class for tile parameter."""
tile_m: Union[EnumConfigParam, RangeConfigParam]
tile_n: Union[EnumConfigParam, RangeConfigParam]
tile_k: Union[EnumConfigParam, RangeConfigParam]
warp_m: Union[EnumConfigParam, RangeConfigParam]
warp_n: Union[EnumConfigParam, RangeConfigParam]
warp_k: Union[EnumConfigParam, RangeConfigParam]
warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
@dataclass
class TraitConfig:
"""Configuration class for kernel traits."""
pipeline: EnumConfigParam
scheduler: EnumConfigParam
epilogue: EnumConfigParam
pad_m: EnumConfigParam
pad_n: EnumConfigParam
pad_k: EnumConfigParam
@dataclass
class JsonConfig:
"""Configuration class for JSON parameter."""
tile_config: TileConfig
trait_config: TraitConfig
@classmethod
def from_json(cls: Type["JsonConfig"], filepath: str) -> "JsonConfig":
"""JSON configuration loader with validation controls"""
config_path = Path(filepath)
try:
if not config_path.exists():
raise FileNotFoundError(f"Config file {filepath} not found")
with config_path.open("r") as f:
config_dict = json.load(f)
# Parse tile config
def create_param(param_dict):
if "values" in param_dict:
return EnumConfigParam(values=param_dict["values"])
else:
return RangeConfigParam(
min=param_dict["min"],
max=param_dict["max"],
step=param_dict["step"],
exclude=param_dict.get("exclude", []),
)
tile_config = TileConfig(
tile_m=create_param(config_dict["tile_config"]["tile_m"]),
tile_n=create_param(config_dict["tile_config"]["tile_n"]),
tile_k=create_param(config_dict["tile_config"]["tile_k"]),
warp_m=create_param(config_dict["tile_config"]["warp_m"]),
warp_n=create_param(config_dict["tile_config"]["warp_n"]),
warp_k=create_param(config_dict["tile_config"]["warp_k"]),
warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
)
# Parse trait config
trait_config = TraitConfig(
pipeline=EnumConfigParam(
values=config_dict["trait_config"]["pipeline"]["values"]
),
scheduler=EnumConfigParam(
values=config_dict["trait_config"]["scheduler"]["values"]
),
epilogue=EnumConfigParam(
values=config_dict["trait_config"]["epilogue"]["values"]
),
pad_m=EnumConfigParam(
values=config_dict["trait_config"]["pad_m"]["values"]
),
pad_n=EnumConfigParam(
values=config_dict["trait_config"]["pad_n"]["values"]
),
pad_k=EnumConfigParam(
values=config_dict["trait_config"]["pad_k"]["values"]
),
)
return cls(tile_config=tile_config, trait_config=trait_config)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {str(e)}")
except KeyError as e:
raise KeyError(f"Missing required configuration field: {str(e)}")

View File

@@ -0,0 +1,164 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstring>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "gemm_multi_d_dispatcher.hpp"
#include "gemm_multi_d_common.hpp"
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
inline auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
.insert("n", "4096", "The value for n dimension. Default is 4096.")
.insert("k", "2048", "The value for k dimension. Default is 2048.")
.insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
.insert("stride_ds", "0", "The stride value for tensor Ds Default is 0.")
.insert("stride_e", "0", "The stride value for tensor E Default is 0.")
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
.insert("verify",
"1",
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
"for validation on GPU. Default is 1, validation on CPU, as validation on GPU is "
"not supported.")
.insert("log",
"false",
"Wether output kernel instance information or not. Possible values are true or "
"false. Default is false")
.insert("warmup",
"50",
"The number of iterations before benchmarking the kernel. Default is 50.")
.insert("repeat",
"100",
"The number of iterations for benchmarking the kernel. Default is 100.")
.insert("timer",
"true",
"Indicates whether the timer is a GPU timer. Possible values are true or false. "
"Default is true.")
.insert("init",
"0",
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
"for constant(1). Default is 0, random.")
.insert("flush_cache",
"false",
"To flush cache, possible values are true or false. "
"Default is false.")
.insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.")
.insert("metric",
"0",
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
"tflops, or 2 for bandwidth. Default is 0, latency.")
.insert("csv_filename",
"gemm_multi_d_kernel",
"The filename of benchmark result. Default is set to gemm_multi_d_kernel.")
.insert(
"pipeline",
"compv3",
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.")
.insert("scheduler",
"intrawave",
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is "
"compv3.")
.insert(
"epilogue",
"cshuffle",
"The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.")
.insert("pad_m",
"false",
"Whether pad or not in m direction. Possible values are true or false. Default is "
"false.")
.insert("pad_n",
"false",
"Whether pad or not in n direction. Possible values are true or false. Default is "
"false.")
.insert("pad_k",
"false",
"Whether pad or not in k direction. Possible values are true or false. Default is "
"false.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser)
{
KernelTraits trait;
trait.pipeline = arg_parser.get_str("pipeline");
trait.scheduler = arg_parser.get_str("scheduler");
trait.epilogue = arg_parser.get_str("epilogue");
trait.pad_m = arg_parser.get_bool("pad_m");
trait.pad_n = arg_parser.get_bool("pad_n");
trait.pad_k = arg_parser.get_bool("pad_k");
return GemmMultiDDispatcher::dispatch(trait);
}

View File

@@ -0,0 +1,755 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
generate kernel instances to speed up compilation
"""
import argparse
import itertools
from pathlib import Path
from typing import List, Optional
from gemm_multi_d_config import JsonConfig, ArgumentConfig, RangeConfigParam
from gemm_multi_d_codegen_utils import (
DATA_TYPE_MAP,
LAYOUT_MAP,
PIPELINE_MAP,
SCHEDULER_MAP,
EPILOGUE_MAP,
BOOL_MAP,
warp_tile_supported_combinations,
trait_unsupported_combinations,
element_size,
get_gpu_name_by_id,
)
import logging
logging.basicConfig(level=logging.INFO)
class GemmMultiDCodeGenerator:
"""GEMM (General Matrix Multiplication) Multi D code generator."""
def __init__(
self,
args: argparse.Namespace,
user_provided_config: Optional[JsonConfig] = None,
):
self.output_dir = Path(args.working_path)
self.output_dir.mkdir(parents=True, exist_ok=True)
if user_provided_config is not None:
self.config = user_provided_config
else:
config_path = (
Path(__file__).resolve().parent / "configs" / "default_config.json"
)
self.config = JsonConfig.from_json(config_path)
self.args = ArgumentConfig.from_args(
args.datatype, args.layout, args.elementwise_function
)
self.valid_trait_names: List[str] = []
self.valid_trait_tile_combinations: map[str, list[tuple[int]]] = {}
def list_all_trait_names(self):
"""List all possible kernel trait names into file."""
w_p = Path(self.output_dir)
file_path = w_p / "gemm_multi_d_instance_blobs.txt"
self._generate_all_traits()
self._get_valid_trait_tile_combinations()
file_range_map = {}
# Write all file paths to the header file
files_listed = 0
with file_path.open("w") as f:
# Core files
core_files = [
"gemm_multi_d_common.hpp",
"gemm_multi_d_instances.hpp",
"gemm_multi_d_dispatcher.hpp",
]
for core_file in core_files:
f.write(str(w_p / core_file) + "\n")
files_listed += 1
# Trait header files
for trait in self.valid_trait_names:
trait_file = f"gemm_multi_d_{trait}.hpp"
f.write(str(w_p / trait_file) + "\n")
files_listed += 1
file_name = set()
# Instance source files
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
start_idx = files_listed
for tile in tile_valid_params:
for (
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
_,
_,
_,
) in tile:
instance_name = f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
if instance_name not in file_name:
file_name.add(instance_name)
f.write(str(w_p / instance_name) + "\n")
files_listed += 1
file_range_map[trait] = (start_idx, files_listed)
file_path = w_p / "gemm_multi_d_instance_blobs_range.txt"
with file_path.open("w") as f:
for name, ranges in file_range_map.items():
start, last = ranges
f.write(name + " " + f"{start}" + " " + f"{last}" + "\n")
def _generate_all_traits(self):
"""Generate all possible kernel traits names."""
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
# Generate all unique_combinations
_unique = set(
itertools.product(
*[getattr(self.config.trait_config, param).values for param in params]
)
)
for combo in _unique:
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
current_combination = (pipeline, epilogue, scheduler)
if current_combination not in trait_unsupported_combinations:
trait_name = (
f"{pipeline}_{epilogue}_{scheduler}_"
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}"
)
self.valid_trait_names.append(trait_name)
else:
logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}")
def _get_valid_trait_tile_combinations(self):
def get_tile_value(tile_param):
return (
tile_param.generate_candidates()
if isinstance(tile_param, RangeConfigParam)
else tile_param.values
)
tile_group = list(
itertools.product(
get_tile_value(self.config.tile_config.tile_m),
get_tile_value(self.config.tile_config.tile_n),
get_tile_value(self.config.tile_config.tile_k),
)
)
warp_group = list(
itertools.product(
get_tile_value(self.config.tile_config.warp_m),
get_tile_value(self.config.tile_config.warp_n),
get_tile_value(self.config.tile_config.warp_k),
)
)
warp_tile_group = list(
itertools.product(
get_tile_value(self.config.tile_config.warp_tile_m),
get_tile_value(self.config.tile_config.warp_tile_n),
get_tile_value(self.config.tile_config.warp_tile_k),
)
)
tile_params = {
t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group
}
for trait in self.valid_trait_names:
tile_valid_params = [
tile for tile in tile_params if self.is_tile_valid(tile, trait)
]
if trait not in self.valid_trait_tile_combinations:
self.valid_trait_tile_combinations[trait] = []
self.valid_trait_tile_combinations[trait].append(tile_valid_params)
def is_tile_valid(self, tile: tuple, trait: str) -> bool:
"""Check if the tile configuration is valid for the given trait."""
(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
) = tile
pipeline, *_ = trait.split("_")
# Parameter validity check
invalid_params = []
if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]:
invalid_params.append(
f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})"
)
if (warp_m * warp_tile_m) == 0:
invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})")
if (warp_n * warp_tile_n) == 0:
invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})")
if (warp_k * warp_tile_k) == 0:
invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")
if invalid_params:
logging.debug(
f"Trait: [{trait}], Invalid warp configuration: {', '.join(invalid_params)}. "
f"Parameter combination: warp=({warp_m},{warp_n},{warp_k}), "
f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})"
)
return False
# Dimension alignment check
alignment_issues = []
if tile_m % (warp_m * warp_tile_m) != 0:
alignment_issues.append(
f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}"
)
if tile_n % (warp_n * warp_tile_n) != 0:
alignment_issues.append(
f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}"
)
if tile_k % (warp_k * warp_tile_k) != 0:
alignment_issues.append(
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}"
)
if alignment_issues:
logging.debug(
f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. "
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
)
return False
# LDS capacity verification
matrix_a_size = (tile_m * tile_k) * element_size(self.args.datatypes.a_datatype)
matrix_b_size = (tile_n * tile_k) * element_size(self.args.datatypes.b_datatype)
total_tile_in_lds = matrix_a_size + matrix_b_size
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
if total_tile_in_lds > max_tile_size:
logging.debug(
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
)
return False
# Warp combination validation
warp_tile_key = f"{self.args.datatypes.a_datatype}_{self.args.datatypes.b_datatype}_{self.args.datatypes.e_datatype}"
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
gpu_name = get_gpu_name_by_id(0)
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
if not gpu_warp_tile_key:
logging.debug(
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
)
return False
allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, [])
if not allowed_combinations:
logging.debug(
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
)
return False
if current_combination not in allowed_combinations:
logging.debug(
f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. "
f"Valid combinations for data type '{warp_tile_key}': {allowed_combinations}"
)
return False
return True
def generate_all_instance_files(self):
"""Generate all kernel instances files."""
self._generate_common_header_file()
self._generate_all_trait_files()
self._generate_dispatcher_file()
def _generate_common_header_file(self):
"""Generate common header file with datatypes and layout."""
acc_type = "float" # As we are currently supporting only fp16
content = f"""
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Data types
using ADataType = {DATA_TYPE_MAP[self.args.datatypes.a_datatype]};
using BDataType = {DATA_TYPE_MAP[self.args.datatypes.b_datatype]};
using AccDataType = {acc_type};
using D0DataType = {DATA_TYPE_MAP[self.args.datatypes.d0_datatype]};
using D1DataType = {DATA_TYPE_MAP[self.args.datatypes.d1_datatype]};
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
using EDataType = {DATA_TYPE_MAP[self.args.datatypes.e_datatype]};
// Layout configurations
using ALayout = {LAYOUT_MAP[self.args.layouts.a_layout]};
using BLayout = {LAYOUT_MAP[self.args.layouts.b_layout]};
using D0Layout = {LAYOUT_MAP[self.args.layouts.d0_layout]};
using D1Layout = {LAYOUT_MAP[self.args.layouts.d1_layout]};
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
using ELayout = {LAYOUT_MAP[self.args.layouts.e_layout]};
// Element-wise function for D
using ElementWiseFn = ck_tile::element_wise::{self.args.function_name};
"""
(self.output_dir / "gemm_multi_d_common.hpp").write_text(content)
def _generate_all_trait_files(self):
"""Generate all kernel traits into files."""
if not self.valid_trait_names:
self._generate_all_traits()
self._get_valid_trait_tile_combinations()
for trait in self.valid_trait_names:
self._generate_trait_file(trait)
self._generate_instantiation_source_files()
self._generate_common_instance_header_file()
def _generate_trait_file(self, trait: str):
"""Generate a trait with all tile/warp combinations."""
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_")
filename = f"gemm_multi_d_{trait}.hpp"
content = f"""
#pragma once
#include "gemm_multi_d_common.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/host.hpp"
namespace {trait} {{
"""
# Add template struct with configuration
content += self._generate_kernel_struct(
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k
)
content += f"\n}} // namespace {trait}\n"
(self.output_dir / filename).write_text(content)
def _generate_kernel_struct(
self,
pipeline: str,
epilogue: str,
scheduler: str,
pad_m: str,
pad_n: str,
pad_k: str,
) -> str:
"""Generate the code block of kernel struct"""
return f"""
template <int TileM, int TileN, int TileK,
int WarpM, int WarpN, int WarpK,
int WarpTileM, int WarpTileN, int WarpTileK,
typename CDEElementWise = ElementWiseFn>
struct GemmKernelMultiD {{
static constexpr bool kPadM = {pad_m};
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};
static float launch(ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {{
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
static constexpr bool TransposeC = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<TileM, TileN, TileK>,
ck_tile::sequence<WarpM, WarpN, WarpK>,
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
TileParitionerM01>;
using Traits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, ELayout>;
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
ALayout, BLayout, ELayout, TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * TileK;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{{0}};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {SCHEDULER_MAP[scheduler]};
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = {PIPELINE_MAP[pipeline][1]}<UniversalGemmProblem>;
{EPILOGUE_MAP[epilogue]}
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
if(stream.log_level_ > 0)
{{
std::cout << "Launching kernel with args:"
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}
ave_time = ck_tile::launch_kernel(stream,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
Kernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
if(args.k_batch == 1) {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{{}});
}} else {{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{{}});
}}
}};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}}
static std::string get_name() {{
return std::string("gemm_multi_d_") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) +
"_" + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + "_" +
std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + "_" +
"{pad_m}" + "_" +
"{pad_n}" + "_" +
"{pad_k}" + "_" +
"{pipeline}" + "_" +
"{epilogue}" + "_" +
"{scheduler}";
}}
}};
"""
def _generate_instantiation_source_files(self):
"""Generate kernel instance instantiation source files"""
tile_map = {}
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
for tile in tile_valid_params:
for (
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
) in tile:
key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}"
value = f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
if key not in tile_map:
tile_map[key] = set()
tile_map[key].add(value)
files_listed = 0
for trait, _ in self.valid_trait_tile_combinations.items():
for block_tile, warp_tiles in tile_map.items():
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map(
int, block_tile.split("x")
)
content = f"""
#include "gemm_multi_d_{trait}.hpp"
"""
for warp_tile in warp_tiles:
warp_tile_m, warp_tile_n, warp_tile_k = map(
int, warp_tile.split("x")
)
files_listed = files_listed + 1
content = (
content
+ f"""
template struct {trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>;"""
)
content += """
"""
(
self.output_dir
/ f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
).write_text(content)
print(f"Generated {files_listed} kernel instances in total.")
def _generate_common_instance_header_file(self):
"""Generate common instance header into file."""
content = """
#pragma once
"""
for trait in self.valid_trait_names:
content += f'#include "gemm_multi_d_{trait}.hpp"\n'
(self.output_dir / "gemm_multi_d_instances.hpp").write_text(content)
def _generate_dispatcher_file(self):
"""Generate the code block of dispatch mechanism."""
content = """
#pragma once
#include <unordered_map>
#include <functional>
#include <vector>
#include "gemm_multi_d_common.hpp"
#include "gemm_multi_d_instances.hpp"
/// @brief Defines the configuration parameters for a GEMM Multi D operation, enabling the selection of a
/// specific kernel instance based on the provided settings.
struct KernelTraits
{
/// @brief The name of the pipeline.
std::string pipeline;
/// @brief The name of the scheduler (e.g., "intrawave", "interwave").
std::string scheduler;
/// @brief The name of the epilogue (e.g., "cshuffle", "default").
std::string epilogue;
/// @brief Indicates whether padding is applied to the M dimension.
bool pad_m;
/// @brief Indicates whether padding is applied to the N dimension.
bool pad_n;
/// @brief Indicates whether padding is applied to the K dimension.
bool pad_k;
};
struct GemmMultiDDispatcher {
static auto& get_kernel_map() {
// Use a static local variable
static std::unordered_map<
std::string,
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>>
kernel_map;
return kernel_map;
}
static void init() {
auto& kernel_map = get_kernel_map();
if(!kernel_map.empty()) return;
\n"""
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
content += f""" kernel_map["{trait}"] = {{"""
for _, tile in enumerate(tile_valid_params):
for j in range(len(tile)):
(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
) = tile[j]
content += """[=](ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) { """
content += f"""
return run_kernel<{trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>>(args, stream);"""
if j == len(tile) - 1:
content += """
} """
else:
content += """
}, """
content += """
};\n """
content += """ }
template <typename Kernel>
static std::tuple<std::string, float> run_kernel(ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream)
{
std::string name = Kernel::get_name();
float avg_time = Kernel::launch(args, stream);
return std::make_tuple(name, avg_time);
}
static auto dispatch(const KernelTraits& trait) {
init();
const std::string key = assemble_key(trait);
auto& kernel_map = get_kernel_map();
if(auto it = kernel_map.find(key); it != kernel_map.end())
{
return it->second;
}
throw std::runtime_error("No suitable kernel found: " + key);
}
private:
static std::string assemble_key(const KernelTraits &trait) {
return std::string(trait.pipeline) + "_" +
trait.epilogue + "_" +
trait.scheduler + "_" +
(trait.pad_m ? "true" : "false") + "_" +
(trait.pad_n ? "true" : "false") + "_" +
(trait.pad_k ? "true" : "false");
}
};
"""
(self.output_dir / "gemm_multi_d_dispatcher.hpp").write_text(content)
def do_list_blobs(
args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None
):
generator = GemmMultiDCodeGenerator(args, user_provide_config)
generator.list_all_trait_names()
def do_gen_blobs(
args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None
):
generator = GemmMultiDCodeGenerator(args, user_provide_config)
generator.generate_all_instance_files()
def main(args):
gemm_multi_d_config = JsonConfig.from_json(args.config_json)
if args.list_blobs:
do_list_blobs(args, gemm_multi_d_config)
elif args.gen_blobs:
do_gen_blobs(args, gemm_multi_d_config)
else:
logging.warning(
"No mode specified (use --list_blobs or --gen_blobs). Generating by default..."
)
do_gen_blobs(args, gemm_multi_d_config)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen API for CK gemm multi D kernel",
)
parser.add_argument(
"-w",
"--working_path",
default="./",
required=False,
help="The path where all the blobs are going to be generated",
)
parser.add_argument(
"-j",
"--config_json",
required=False,
help="Path to the json which contains the configurations that user provide",
)
parser.add_argument(
"-d",
"--datatype",
required=True,
help="Specify what datatype to use for the kernel generation, e.g. fp16",
)
parser.add_argument(
"-ly",
"--layout",
required=True,
help="Specify what layout to use for the kernel generation, e.g. rcrr, rrrr",
)
parser.add_argument(
"-ef",
"--elementwise_function",
required=True,
help="Specify what element wise function for D, e.g. mul, add, passthrough",
)
parser.add_argument(
"-l",
"--list_blobs",
action="store_true",
help="List all kernel instances to file",
)
parser.add_argument(
"-g",
"--gen_blobs",
action="store_true",
help="Generate all kernel instances into different files",
)
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,278 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <fstream>
#include <iomanip>
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "benchmark_gemm_multi_d.hpp"
class GemmMultiDProfiler
{
public:
static GemmMultiDProfiler& instance(Setting setting)
{
static GemmMultiDProfiler instance{setting};
return instance;
}
void benchmark(
GemmMultiDProblem& gemm_multi_d_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>&
callables)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
const D0Layout layout_d0 = D0Layout{};
const D1Layout layout_d1 = D1Layout{};
const ELayout layout_e = ELayout{};
gemm_multi_d_problem.stride_a_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
gemm_multi_d_problem.k_,
gemm_multi_d_problem.stride_a_,
is_row_major(layout_a));
gemm_multi_d_problem.stride_b_ = ck_tile::get_default_stride(gemm_multi_d_problem.k_,
gemm_multi_d_problem.n_,
gemm_multi_d_problem.stride_b_,
is_row_major(layout_b));
gemm_multi_d_problem.stride_d0_ =
ck_tile::get_default_stride(gemm_multi_d_problem.m_,
gemm_multi_d_problem.n_,
gemm_multi_d_problem.stride_d0_,
is_row_major(layout_d0));
gemm_multi_d_problem.stride_d1_ =
ck_tile::get_default_stride(gemm_multi_d_problem.m_,
gemm_multi_d_problem.n_,
gemm_multi_d_problem.stride_d1_,
is_row_major(layout_d1));
gemm_multi_d_problem.stride_e_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
gemm_multi_d_problem.n_,
gemm_multi_d_problem.stride_e_,
is_row_major(layout_e));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
gemm_multi_d_problem.k_,
gemm_multi_d_problem.stride_a_,
is_row_major(layout_a)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.k_,
gemm_multi_d_problem.n_,
gemm_multi_d_problem.stride_b_,
is_row_major(layout_b)));
ck_tile::HostTensor<D0DataType> d0_m_n(
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
gemm_multi_d_problem.n_,
gemm_multi_d_problem.stride_d0_,
is_row_major(layout_d0)));
ck_tile::HostTensor<D1DataType> d1_m_n(
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
gemm_multi_d_problem.n_,
gemm_multi_d_problem.stride_d1_,
is_row_major(layout_d1)));
ck_tile::HostTensor<EDataType> e_m_n_device_result(
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
gemm_multi_d_problem.n_,
gemm_multi_d_problem.stride_e_,
is_row_major(layout_e)));
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(d1_m_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.mData.data());
b_k_n_dev_buf.ToDevice(b_k_n.mData.data());
d0_m_n_dev_buf.ToDevice(d0_m_n.mData.data());
d1_m_n_dev_buf.ToDevice(d1_m_n.mData.data());
e_m_n_dev_buf.SetZero();
e_m_n_device_result.SetZero();
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
d1_m_n_dev_buf.GetDeviceBuffer()};
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {
gemm_multi_d_problem.stride_d0_, gemm_multi_d_problem.stride_d1_};
ck_tile::GemmMultiDHostArgs<DsDataType::size()> gemm_multi_d_args = {
a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
ds_ptr_buf,
e_m_n_dev_buf.GetDeviceBuffer(),
gemm_multi_d_problem.split_k_,
gemm_multi_d_problem.m_,
gemm_multi_d_problem.n_,
gemm_multi_d_problem.k_,
gemm_multi_d_problem.stride_a_,
gemm_multi_d_problem.stride_b_,
stridesDs,
gemm_multi_d_problem.stride_e_,
};
ck_tile::HostTensor<EDataType> e_m_n_host_result(
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
gemm_multi_d_problem.n_,
gemm_multi_d_problem.stride_e_,
is_row_major(layout_e)));
if(setting_.verify_)
{
gemm_multi_d_host_reference(
setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, e_m_n_host_result);
}
for(auto& callable : callables)
{
auto kernel_run_result =
callable(gemm_multi_d_args,
ck_tile::stream_config{
nullptr, true, setting_.log_, setting_.n_warmup_, setting_.n_repeat_});
auto [kernel_name, execution_time] = kernel_run_result;
process_result(gemm_multi_d_problem,
e_m_n_dev_buf,
e_m_n_host_result,
e_m_n_device_result,
kernel_run_result);
}
}
void process_result(const GemmMultiDProblem& gemm_multi_d_problem,
ck_tile::DeviceMem& e_m_n_dev_buf,
ck_tile::HostTensor<EDataType>& e_m_n_host_result,
ck_tile::HostTensor<EDataType>& e_m_n_dev_result,
const std::tuple<std::string, float>& kernel_run_result)
{
auto [name, avg_time] = kernel_run_result;
KernelInstance kernel_instance{name, gemm_multi_d_problem, {-1.0f, -1.0f, -1.0f}};
static constexpr ck_tile::index_t NumDTensor = DsDataType::size();
std::size_t flop = 0, num_byte = 0;
flop += std::size_t(2) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_ *
gemm_multi_d_problem.k_;
ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) {
num_byte += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
flop += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
});
num_byte += sizeof(ADataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.k_ +
sizeof(BDataType) * gemm_multi_d_problem.k_ * gemm_multi_d_problem.n_ +
sizeof(EDataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
kernel_instance.perf_result_.latency_ = avg_time;
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
if(setting_.log_ > 0)
{
std::cout << kernel_instance << std::endl;
}
e_m_n_dev_buf.FromDevice(e_m_n_dev_result.data());
bool verified_correct =
!setting_.verify_ ||
compare(name, gemm_multi_d_problem.k_, e_m_n_dev_result, e_m_n_host_result);
if(verified_correct)
{
kernel_instances_.emplace_back(kernel_instance);
}
else
{
std::cout << "Verification failed, skip kernel: " << name << std::endl;
}
e_m_n_dev_buf.SetZero();
e_m_n_dev_result.SetZero();
}
KernelInstance select_best_instance(Metric metric)
{
if(kernel_instances_.empty())
throw std::runtime_error("Empty instances");
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
kernel_instances_.end(),
[metric](const auto& a, const auto& b) {
return PerformanceResult::compare(
b.perf_result_, a.perf_result_, metric);
});
std::cout << "**********************************" << std::endl;
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
<< "The best kernel instance is: " << kernel_instance << std::endl;
std::cout << "**********************************" << std::endl;
if(!setting_.csv_filename_.empty())
{
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
if(!file.is_open())
{
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
}
else
{
if(file.tellp() == 0)
{
file << "rocm_version,device_name,"
<< "split_k,m,n,k,stride_a,stride_b,stride_c,"
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
<< "structured_sparsity," << "name,"
<< "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
}
const auto& problem = kernel_instance.problem_;
const auto& name = kernel_instance.name_;
const auto& perf = kernel_instance.perf_result_;
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
<< problem.stride_d0_ << "," << problem.stride_d1_ << "," << problem.stride_e_
<< "," << problem.dtype_a_ << "," << problem.dtype_b_ << ","
<< problem.dtype_d0_ << "," << problem.dtype_d1_ << "," << problem.dtype_acc_
<< "," << problem.dtype_e_ << "," << problem.layout_a_ << ","
<< problem.layout_b_ << "," << problem.layout_d0_ << "," << problem.layout_d1_
<< "," << problem.layout_e_ << "," << "," << name << "," << std::fixed
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
<< "\n";
if(!file)
{
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
}
}
}
return kernel_instance;
}
GemmMultiDProfiler(const GemmMultiDProfiler&) = delete;
GemmMultiDProfiler& operator=(const GemmMultiDProfiler&) = delete;
private:
~GemmMultiDProfiler() { kernel_instances_.clear(); }
GemmMultiDProfiler(Setting setting) : setting_(setting) {}
Setting setting_;
std::vector<KernelInstance> kernel_instances_;
};