mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
only support profile
This commit is contained in:
@@ -1,36 +1,34 @@
|
||||
# Ask the Python generator for the list of kernel blobs at configure time.
# Per the original note, no kernel sources are emitted at this stage; the list
# is presumably written to gemm_instance_blobs.txt, which is read back below.
execute_process(
  COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
          --working_path ${CMAKE_CURRENT_BINARY_DIR}
          --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
          --list_blobs
  RESULT_VARIABLE ret
)

if(ret AND NOT ret EQUAL 0)
  message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
endif()

file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)

# The blob list above is computed at configure time, so re-run configure
# whenever the generator script or its configuration changes.
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
  ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
  ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
)

# Build-time rule that actually emits the kernel sources. It depends on the
# generator inputs; listing the OUTPUT files as their own DEPENDS would make the
# rule depend on itself and would never notice a changed script or JSON file.
add_custom_command(
  OUTPUT ${GEMM_CODEGEN_BLOBS}
  COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
          --working_path ${CMAKE_CURRENT_BINARY_DIR}
          --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
          --gen_blobs
  DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
          ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
  VERBATIM
)

set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
message("adding example ${EXECUTABLE_GEMM_INSTANCE}")

# use build as include directory
find_package(SQLite3 REQUIRED)
# NOTE(review): kept directory-scoped because the rest of this file is not
# visible here and later targets may rely on it.
include_directories(${CMAKE_CURRENT_BINARY_DIR})
add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp)
target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# NOTE(review): keyword-less signature kept on purpose; mixing it with a
# PRIVATE/PUBLIC call elsewhere on the same target is an error in CMake.
target_link_libraries(${EXECUTABLE_GEMM_INSTANCE} SQLite::SQLite3)
target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS})

set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)
|
||||
@@ -28,10 +28,11 @@ make tile_engine_gemm -j
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-split_k SplitK value (default:1)
|
||||
-v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2)
|
||||
-metric The metric value of kernel performance - latency: 0, tflops: 1, bandwidth: 2 (default:0)
|
||||
-warmup Number of iterations before benchmark the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0)
|
||||
-init Value for initializing tensor - random: 0, linear: 1, constant(1): 2 (default:0)
|
||||
-pipeline possible values are: compv3, compv4, mem (default:compv3)
|
||||
-scheduler possible values are: intrawave, interwave (default:intrawave)
|
||||
-epilogue possible values are: cshuffle, default (default:cshuffle)
|
||||
|
||||
@@ -8,17 +8,15 @@
|
||||
#include <filesystem>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/version.h"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "profile_cache.hpp"
|
||||
|
||||
class Executor
|
||||
class Profiler
|
||||
{
|
||||
public:
|
||||
~Executor() { kernel_instances_.clear(); }
|
||||
|
||||
static Executor& instance(bool enable_profile_cache = true, bool flush_profile_cache = false)
|
||||
static Profiler& instance()
|
||||
{
|
||||
static Executor instance{enable_profile_cache, flush_profile_cache};
|
||||
static Profiler instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
@@ -61,108 +59,39 @@ class Executor
|
||||
|
||||
KernelInstance kernel_instance{environment_, description, problem, {-1.0f, -1.0f, -1.0f}};
|
||||
|
||||
auto launch_kernel = [&] {
|
||||
float avg_time = Kernel::launch(args, s);
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
float avg_time = Kernel::launch(args, s);
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
|
||||
sizeof(BDataType) * args.N * args.K +
|
||||
sizeof(CDataType) * args.M * args.N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / avg_time;
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
|
||||
sizeof(BDataType) * args.N * args.K +
|
||||
sizeof(CDataType) * args.M * args.N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / avg_time;
|
||||
|
||||
kernel_instance.perf_result.latency = avg_time;
|
||||
kernel_instance.perf_result.tflops = tflops;
|
||||
kernel_instance.perf_result.bandwidth = gb_per_sec;
|
||||
kernel_instance.perf_result.latency = avg_time;
|
||||
kernel_instance.perf_result.tflops = tflops;
|
||||
kernel_instance.perf_result.bandwidth = gb_per_sec;
|
||||
|
||||
std::cout << kernel_instance << std::endl;
|
||||
std::cout << kernel_instance << std::endl;
|
||||
|
||||
bool verified_correct =
|
||||
!verify || compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
|
||||
bool verified_correct =
|
||||
!verify || compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
|
||||
|
||||
if(verified_correct)
|
||||
{
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Verification failed, skip kernel: " << description << std::endl;
|
||||
}
|
||||
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
};
|
||||
|
||||
if(enable_profile_cache_)
|
||||
if(verified_correct)
|
||||
{
|
||||
if(!cache_db_->check_if_record(kernel_instance))
|
||||
{
|
||||
|
||||
launch_kernel();
|
||||
cache_db_->insert_batch({kernel_instance});
|
||||
}
|
||||
else
|
||||
{
|
||||
auto perf_result = cache_db_->query_performance_result(kernel_instance);
|
||||
kernel_instance.perf_result.latency = perf_result.latency;
|
||||
kernel_instance.perf_result.tflops = perf_result.tflops;
|
||||
kernel_instance.perf_result.bandwidth = perf_result.bandwidth;
|
||||
std::cout << "Skip this kernel for " << description
|
||||
<< ", Because it has already been recorded in the cache database"
|
||||
<< std::endl;
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
}
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
}
|
||||
else
|
||||
{
|
||||
launch_kernel();
|
||||
std::cout << "Verification failed, skip kernel: " << description << std::endl;
|
||||
}
|
||||
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
}
|
||||
|
||||
void export_perf_to_csv(const std::vector<KernelInstance>& instances,
|
||||
const std::string& filename)
|
||||
{
|
||||
|
||||
std::ostringstream buffer;
|
||||
|
||||
buffer << "ROCmVersion,CommitID,DeviceName,KernelName,SplitK,M,N,K,"
|
||||
<< "StrideA,StrideB,StrideC,ADataType,BDataType,AccDataType,"
|
||||
<< "CDataType,ALayout,BLayout,CLayout,Latency(ms),TFLOPS,Bandwidth\n";
|
||||
|
||||
for(const auto& instance : instances)
|
||||
{
|
||||
const auto& env = instance.env;
|
||||
const auto& p = instance.problem;
|
||||
const auto& perf = instance.perf_result;
|
||||
|
||||
std::string sanitized_name = instance.name;
|
||||
std::replace(sanitized_name.begin(), sanitized_name.end(), '\"', '\'');
|
||||
|
||||
buffer << env.rocm_version << "," << env.commit_id << "," << env.device_name << ","
|
||||
<< "\"" << sanitized_name << "\"," << p.split_k << "," << p.m << "," << p.n
|
||||
<< "," << p.k << "," << p.stride_a << "," << p.stride_b << "," << p.stride_c
|
||||
<< "," << p.dtype_a << "," << p.dtype_b << "," << p.dtype_acc << "," << p.dtype_c
|
||||
<< "," << p.layout_a << "," << p.layout_b << "," << p.layout_c << ","
|
||||
<< std::fixed << std::setprecision(6) << perf.latency << "," << std::scientific
|
||||
<< perf.tflops << "," << std::fixed << perf.bandwidth << "\n";
|
||||
}
|
||||
|
||||
std::ofstream csv_file(filename, std::ios::trunc);
|
||||
if(!csv_file)
|
||||
{
|
||||
throw std::runtime_error("Failed to open CSV file: " + filename);
|
||||
}
|
||||
csv_file << buffer.str();
|
||||
csv_file.close();
|
||||
|
||||
if(csv_file.fail())
|
||||
{
|
||||
throw std::runtime_error("Incomplete write to CSV file: " + filename);
|
||||
}
|
||||
}
|
||||
|
||||
KernelInstance select_best_instance(Metric metric, const std::string& csv_path = "")
|
||||
KernelInstance select_best_instance(Metric metric)
|
||||
{
|
||||
if(kernel_instances_.empty())
|
||||
throw std::runtime_error("Empty instances");
|
||||
@@ -174,126 +103,28 @@ class Executor
|
||||
b.perf_result, a.perf_result, metric);
|
||||
});
|
||||
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "The best kernel instance is: " << kernel_instance << std::endl;
|
||||
|
||||
if(!csv_path.empty())
|
||||
{
|
||||
try
|
||||
{
|
||||
export_perf_to_csv(kernel_instances_, csv_path);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "CSV export failed: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
std::cout << "**********************************" << std::endl;
|
||||
|
||||
return kernel_instance;
|
||||
}
|
||||
|
||||
private:
|
||||
Executor(bool enable_profile_cache = true, bool flush_profile_cache = false)
|
||||
: enable_profile_cache_(enable_profile_cache), flush_profile_cache_(flush_profile_cache)
|
||||
Profiler()
|
||||
{
|
||||
environment_ = Environment{
|
||||
get_rocm_version(),
|
||||
"89f",
|
||||
ck::get_device_name(),
|
||||
};
|
||||
std::cout << "Init gemm bechmark on device: " << environment_.device_name << std::endl;
|
||||
|
||||
initialize_profile_cache();
|
||||
}
|
||||
~Profiler() { kernel_instances_.clear(); }
|
||||
|
||||
void initialize_profile_cache()
|
||||
{
|
||||
// Init cache if enable profile cache
|
||||
if(enable_profile_cache_)
|
||||
{
|
||||
// get profile cache path
|
||||
std::filesystem::path cache_db_prefix_path =
|
||||
std::filesystem::current_path() / ".tile_engine";
|
||||
if(!create_cache_directory(cache_db_prefix_path))
|
||||
{
|
||||
std::cerr << "Error: Failed to create cache directory" << std::endl;
|
||||
return;
|
||||
}
|
||||
std::filesystem::path cache_db_path =
|
||||
cache_db_prefix_path / ("tile_engine_" + environment_.device_name + ".db");
|
||||
|
||||
// remove cache if flush_profile_cache
|
||||
handle_cache_flush(cache_db_path);
|
||||
|
||||
// load profile cache
|
||||
initialize_cache_db(cache_db_path);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Executor disable profile cache! " << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool create_cache_directory(const std::filesystem::path& cache_db_prefix_path)
|
||||
{
|
||||
std::error_code ec;
|
||||
bool created = std::filesystem::create_directories(cache_db_prefix_path, ec);
|
||||
|
||||
if(ec)
|
||||
{
|
||||
std::cerr << "Error creating directory " << cache_db_prefix_path << ": " << ec.message()
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(created)
|
||||
{
|
||||
std::cout << "Created cache directory: " << cache_db_prefix_path << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Using existing cache directory: " << cache_db_prefix_path << std::endl;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void handle_cache_flush(const std::filesystem::path& cache_db_path) const
|
||||
{
|
||||
if(flush_profile_cache_ && std::filesystem::exists(cache_db_path))
|
||||
{
|
||||
std::error_code ec;
|
||||
if(std::filesystem::remove(cache_db_path, ec))
|
||||
{
|
||||
std::cout << "Successfully flushed cache: " << cache_db_path << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Error flushing cache: " << ec.message() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void initialize_cache_db(const std::filesystem::path& path)
|
||||
{
|
||||
try
|
||||
{
|
||||
cache_db_ = std::make_unique<ProfileCacheDB>(path);
|
||||
std::cout << "Loaded profile cache from " << path << std::endl;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Failed to initialize profile cache: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
Executor(const Executor&) = delete;
|
||||
Executor& operator=(const Executor&) = delete;
|
||||
Profiler(const Profiler&) = delete;
|
||||
Profiler& operator=(const Profiler&) = delete;
|
||||
|
||||
Environment environment_;
|
||||
|
||||
bool enable_profile_cache_;
|
||||
bool flush_profile_cache_;
|
||||
|
||||
std::unique_ptr<ProfileCacheDB> cache_db_;
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
|
||||
@@ -11,22 +11,12 @@ void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify,
|
||||
int metric,
|
||||
bool enable_profile_cache,
|
||||
bool flush_profile_cache,
|
||||
KernelTraits& trait,
|
||||
ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
return GemmDispatcher::dispatch(c_m_n_dev_buf,
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
verify,
|
||||
metric,
|
||||
enable_profile_cache,
|
||||
flush_profile_cache,
|
||||
trait,
|
||||
args,
|
||||
s);
|
||||
return GemmDispatcher::dispatch(
|
||||
c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, metric, trait, args, s);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
@@ -55,8 +45,6 @@ void run(const ck_tile::ArgParser& arg_parser)
|
||||
int verify = arg_parser.get_int("v");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
int metric = arg_parser.get_int("metric");
|
||||
bool enable_profile_cache = arg_parser.get_bool("enable_profile_cache");
|
||||
bool flush_profile_cache = arg_parser.get_bool("flush_profile_cache");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
@@ -168,8 +156,6 @@ void run(const ck_tile::ArgParser& arg_parser)
|
||||
c_m_n_dev_result,
|
||||
verify,
|
||||
metric,
|
||||
enable_profile_cache,
|
||||
flush_profile_cache,
|
||||
trait,
|
||||
gemm_args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
@@ -94,7 +94,6 @@ struct PerformanceResult
|
||||
struct Environment
|
||||
{
|
||||
std::string rocm_version;
|
||||
std::string commit_id;
|
||||
std::string device_name;
|
||||
|
||||
std::string serialize() const
|
||||
@@ -102,7 +101,6 @@ struct Environment
|
||||
std::ostringstream oss;
|
||||
oss << "{"
|
||||
<< "\"rocm_version\":\"" << rocm_version << "\","
|
||||
<< "\"commit_id\":\"" << commit_id << "\","
|
||||
<< "\"device_name\":\"" << device_name << "\""
|
||||
<< "}";
|
||||
return oss.str();
|
||||
@@ -239,13 +237,7 @@ inline auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("metric", "latency", "latency, tflops, bandwidth")
|
||||
.insert("enable_profile_cache",
|
||||
"false",
|
||||
"whether use profile cache or not when benchmark kernel")
|
||||
.insert("flush_profile_cache",
|
||||
"false",
|
||||
"whether flush profile cache or not when benchmark kernel")
|
||||
.insert("metric", "0", "0:latency, 1:tflops, 2:bandwidth")
|
||||
.insert("pipeline", "compv3", "compv3, compv4, mem")
|
||||
.insert("scheduler", "intrawave", "intrawave, interwave")
|
||||
.insert("epilogue", "cshuffle", "cshuffle, default")
|
||||
|
||||
@@ -476,21 +476,27 @@ struct GemmKernel {{
|
||||
content = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_common.hpp"
|
||||
#include "gemm_instances.hpp"
|
||||
#include "gemm_host_api.hpp"
|
||||
#include <unordered_map>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include "gemm_common.hpp"
|
||||
#include "gemm_instances.hpp"
|
||||
#include "gemm_host_api.hpp"
|
||||
#include "benchmark_gemm.hpp"
|
||||
|
||||
struct GemmDispatcher {
|
||||
static auto& get_kernel_map() {
|
||||
// Use a static local variable
|
||||
static std::unordered_map<std::string,
|
||||
std::function<void(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
|
||||
static std::unordered_map<std::string,
|
||||
std::function<void(Profiler&,
|
||||
ck_tile::DeviceMem&,
|
||||
ck_tile::HostTensor<CDataType>&,
|
||||
ck_tile::HostTensor<CDataType>&,
|
||||
int,
|
||||
ck_tile::GemmHostArgs&,
|
||||
const ck_tile::stream_config&)>>
|
||||
kernel_map;
|
||||
return kernel_map;
|
||||
}
|
||||
|
||||
@@ -513,7 +519,8 @@ struct GemmDispatcher {
|
||||
|
||||
|
||||
for group in self.all_kernels:
|
||||
content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
content += f""" kernel_map["{group}"] = [](Profiler& profiler,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args,
|
||||
@@ -526,46 +533,29 @@ struct GemmDispatcher {
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
content += f"""
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
|
||||
profiler.benchmark_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
|
||||
content += f"""
|
||||
}};\n"""
|
||||
|
||||
content += """ }
|
||||
|
||||
template <typename Kernel>
|
||||
static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
float avg_time = Kernel::launch(args, s);
|
||||
std::string description = Kernel::get_name();
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Performance for " << description << " : " << avg_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
if(verify)
|
||||
compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
}
|
||||
|
||||
static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
|
||||
int verify,
|
||||
int metric,
|
||||
const KernelTraits& trait,
|
||||
ck_tile::GemmHostArgs& gemm_args,
|
||||
const ck_tile::stream_config& s) {
|
||||
init();
|
||||
const std::string key = assemble_key(trait);
|
||||
auto& kernel_map = get_kernel_map();
|
||||
auto& profiler = Profiler::instance();
|
||||
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
|
||||
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s);
|
||||
it->second(
|
||||
profiler, c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, s);
|
||||
profiler.select_best_instance(static_cast<Metric>(metric));
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("No suitable kernel found: " + key);
|
||||
}
|
||||
|
||||
@@ -1,294 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <sqlite3.h>
|
||||
|
||||
#include "gemm_host_api.hpp"
|
||||
|
||||
// Evaluates `expr` exactly once. Any status other than SQLITE_OK is converted
// into a std::runtime_error whose message carries the numeric status, the text
// from sqlite3_errmsg(db) (or "unknown error" when that is null) and the
// file/line of the call site.
#define CHECK_SQLITE3(expr, db) \
    do \
    { \
        const int ck_sqlite_status_ = (expr); \
        if(ck_sqlite_status_ == SQLITE_OK) \
            break; \
        const char* ck_sqlite_msg_ = sqlite3_errmsg(db); \
        std::string ck_sqlite_what_ = "SQLite error["; \
        ck_sqlite_what_ += std::to_string(ck_sqlite_status_); \
        ck_sqlite_what_ += "]: "; \
        ck_sqlite_what_ += (ck_sqlite_msg_ ? ck_sqlite_msg_ : "unknown error"); \
        ck_sqlite_what_ += " at "; \
        ck_sqlite_what_ += std::string(__FILE__); \
        ck_sqlite_what_ += ":"; \
        ck_sqlite_what_ += std::to_string(__LINE__); \
        throw std::runtime_error(ck_sqlite_what_); \
    } while(0)
|
||||
|
||||
// Variant of the status check for stepping statements: stores the status of
// `expr` in the caller-supplied lvalue `rc` and treats SQLITE_OK, SQLITE_ROW
// and SQLITE_DONE as success. Any other status raises std::runtime_error with
// the numeric status, sqlite3_errmsg(db) and the file/line of the call site.
#define CHECK_SQLITE3_RC(expr, db, rc) \
    do \
    { \
        rc = (expr); \
        const bool ck_sqlite_step_ok_ = \
            (rc == SQLITE_OK) || (rc == SQLITE_ROW) || (rc == SQLITE_DONE); \
        if(ck_sqlite_step_ok_) \
            break; \
        const char* ck_sqlite_msg_ = sqlite3_errmsg(db); \
        std::string ck_sqlite_what_ = "SQLite error["; \
        ck_sqlite_what_ += std::to_string(rc); \
        ck_sqlite_what_ += "]: "; \
        ck_sqlite_what_ += (ck_sqlite_msg_ ? ck_sqlite_msg_ : "unknown error"); \
        ck_sqlite_what_ += " at "; \
        ck_sqlite_what_ += std::string(__FILE__); \
        ck_sqlite_what_ += ":"; \
        ck_sqlite_what_ += std::to_string(__LINE__); \
        throw std::runtime_error(ck_sqlite_what_); \
    } while(0)
|
||||
|
||||
class StmtWrapper
|
||||
{
|
||||
public:
|
||||
explicit StmtWrapper(sqlite3* db, const char* sql)
|
||||
: stmt_(
|
||||
[db, sql] {
|
||||
sqlite3_stmt* stmt = nullptr;
|
||||
CHECK_SQLITE3(sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr), db);
|
||||
return stmt;
|
||||
}(),
|
||||
&sqlite3_finalize)
|
||||
{
|
||||
}
|
||||
|
||||
operator sqlite3_stmt*() const { return stmt_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<sqlite3_stmt, decltype(&sqlite3_finalize)> stmt_;
|
||||
};
|
||||
|
||||
class ProfileCacheDB
|
||||
{
|
||||
public:
|
||||
explicit ProfileCacheDB(const std::filesystem::path& path)
|
||||
: db_ptr_(
|
||||
[path] {
|
||||
sqlite3* raw_db_ptr = nullptr;
|
||||
CHECK_SQLITE3(sqlite3_open_v2(path.string().c_str(),
|
||||
&raw_db_ptr,
|
||||
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
|
||||
nullptr),
|
||||
raw_db_ptr);
|
||||
return raw_db_ptr;
|
||||
}(),
|
||||
&sqlite3_close)
|
||||
{
|
||||
|
||||
try
|
||||
{
|
||||
exec_direct("PRAGMA journal_mode = WAL");
|
||||
exec_direct("PRAGMA synchronous = NORMAL");
|
||||
exec_direct("PRAGMA foreign_keys = ON");
|
||||
|
||||
constexpr const char* schema = R"sql(
|
||||
CREATE TABLE IF NOT EXISTS gemm (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
rocm_version TEXT NOT NULL CHECK(length(rocm_version) > 0),
|
||||
commit_id TEXT NOT NULL CHECK(length(commit_id) > 0),
|
||||
device_name TEXT NOT NULL CHECK(length(device_name) > 0),
|
||||
instance_name TEXT NOT NULL CHECK(length(instance_name) > 0),
|
||||
problem TEXT NOT NULL CHECK(json_valid(problem)),
|
||||
latency REAL CHECK(latency > 0),
|
||||
tflops REAL CHECK(tflops > 0),
|
||||
bandwidth REAL CHECK(bandwidth > 0)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_latency ON gemm(latency);
|
||||
CREATE INDEX IF NOT EXISTS idx_tflops_desc ON gemm(tflops DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_bandwidth_desc ON gemm(bandwidth DESC);
|
||||
)sql";
|
||||
|
||||
exec_direct(schema);
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
bool check_if_record(const KernelInstance& kernel_instance)
|
||||
{
|
||||
constexpr const char* sql = R"sql(
|
||||
SELECT 1 FROM gemm
|
||||
WHERE rocm_version=? AND commit_id=? AND device_name=?
|
||||
AND instance_name=? AND problem=?
|
||||
LIMIT 1)sql";
|
||||
|
||||
StmtWrapper stmt(db_ptr_.get(), sql);
|
||||
sqlite3_stmt* raw_stmt = stmt;
|
||||
int idx = 1;
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, kernel_instance.env.rocm_version.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, kernel_instance.env.commit_id.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, kernel_instance.env.device_name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, kernel_instance.name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt,
|
||||
idx++,
|
||||
kernel_instance.problem.serialize().c_str(),
|
||||
kernel_instance.problem.serialize().size(),
|
||||
SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
|
||||
int rc;
|
||||
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
||||
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
||||
return (rc == SQLITE_ROW);
|
||||
}
|
||||
|
||||
PerformanceResult query_performance_result(const KernelInstance& kernel_instance)
|
||||
{
|
||||
constexpr const char* sql = R"sql(
|
||||
SELECT latency, tflops, bandwidth FROM gemm
|
||||
WHERE rocm_version=? AND commit_id=? AND device_name=?
|
||||
AND instance_name=? AND problem=?
|
||||
LIMIT 1)sql";
|
||||
|
||||
StmtWrapper stmt(db_ptr_.get(), sql);
|
||||
sqlite3_stmt* raw_stmt = stmt;
|
||||
|
||||
int idx = 1;
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, kernel_instance.env.rocm_version.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, kernel_instance.env.commit_id.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, kernel_instance.env.device_name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, kernel_instance.name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_text(stmt,
|
||||
idx++,
|
||||
kernel_instance.problem.serialize().c_str(),
|
||||
kernel_instance.problem.serialize().size(),
|
||||
SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
|
||||
int rc;
|
||||
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
||||
|
||||
if(rc == SQLITE_ROW)
|
||||
{
|
||||
return {sqlite3_column_double(raw_stmt, 0),
|
||||
sqlite3_column_double(raw_stmt, 1),
|
||||
sqlite3_column_double(raw_stmt, 2)};
|
||||
}
|
||||
else if(rc != SQLITE_DONE)
|
||||
{
|
||||
throw std::runtime_error(sqlite3_errmsg(db_ptr_.get()));
|
||||
}
|
||||
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
||||
|
||||
return {-1.0f, -1.0f, -1.0f};
|
||||
}
|
||||
|
||||
void insert_batch(const std::vector<KernelInstance>& data)
|
||||
{
|
||||
exec_direct("BEGIN TRANSACTION");
|
||||
try
|
||||
{
|
||||
constexpr const char* sql = R"sql(
|
||||
INSERT INTO gemm
|
||||
(rocm_version, commit_id, device_name,
|
||||
instance_name, problem,
|
||||
latency, tflops, bandwidth)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8))sql";
|
||||
|
||||
StmtWrapper stmt(db_ptr_.get(), sql);
|
||||
sqlite3_stmt* raw_stmt = stmt;
|
||||
|
||||
for(const auto& item : data)
|
||||
{
|
||||
int idx = 1;
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, item.env.rocm_version.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, item.env.commit_id.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, item.env.device_name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, item.name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt,
|
||||
idx++,
|
||||
item.problem.serialize().c_str(),
|
||||
item.problem.serialize().size(),
|
||||
SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result.latency),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result.tflops),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result.bandwidth),
|
||||
db_ptr_.get());
|
||||
|
||||
int rc;
|
||||
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
||||
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
||||
}
|
||||
exec_direct("COMMIT");
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
exec_direct("ROLLBACK");
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
// std::vector<KernelInstance> query_top(Metric metric, int limit = 5)
|
||||
// {
|
||||
// if(limit <= 0)
|
||||
// throw invalid_argument("Limit must be positive");
|
||||
|
||||
// const char* order = nullptr;
|
||||
// switch(metric)
|
||||
// {
|
||||
// case LATENCY: order = "latency ASC"; break;
|
||||
// case TFLOPS: order = "tflops DESC"; break;
|
||||
// case BANDWIDTH: order = "bandwidth DESC"; break;
|
||||
// default: throw invalid_argument("Invalid metric");
|
||||
// }
|
||||
|
||||
// string sql = "SELECT name, latency, tflops, bandwidth FROM kernels "
|
||||
// "ORDER BY " +
|
||||
// string(order) + " LIMIT ?";
|
||||
|
||||
// StmtWrapper stmt(db, sql.c_str());
|
||||
// sqlite3_bind_int(stmt, 1, limit);
|
||||
|
||||
// std::vector<KernelInstance> results;
|
||||
// while(sqlite3_step(stmt) == SQLITE_ROW)
|
||||
// {
|
||||
// results.emplace_back(reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0)),
|
||||
// PerformanceResult{sqlite3_column_double(stmt, 1),
|
||||
// sqlite3_column_double(stmt, 2),
|
||||
// sqlite3_column_double(stmt, 3)});
|
||||
// }
|
||||
// return results;
|
||||
// }
|
||||
|
||||
private:
|
||||
void exec_direct(const char* sql)
|
||||
{
|
||||
CHECK_SQLITE3(sqlite3_exec(db_ptr_.get(), sql, nullptr, nullptr, nullptr), db_ptr_.get());
|
||||
}
|
||||
|
||||
std::unique_ptr<sqlite3, decltype(&sqlite3_close)> db_ptr_;
|
||||
};
|
||||
Reference in New Issue
Block a user