only support profile

This commit is contained in:
Yanxing-Shi
2025-05-01 11:05:27 +00:00
parent 82186ae503
commit d3d32843b5
7 changed files with 81 additions and 577 deletions

View File

@@ -1,36 +1,34 @@
# generate a list of kernels, but not actually emit files at config stage
# execute_process(
# COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
# --working_path ${CMAKE_CURRENT_BINARY_DIR}
# --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
# --list_blobs
# RESULT_VARIABLE ret
# )
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
--list_blobs
RESULT_VARIABLE ret
)
# if(ret AND NOT ret EQUAL 0)
# message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
# endif()
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
endif()
# file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)
# add_custom_command(
# OUTPUT ${GEMM_CODEGEN_BLOBS}
# COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
# --working_path ${CMAKE_CURRENT_BINARY_DIR}
# --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
# --gen_blobs
# DEPENDS ${GEMM_CODEGEN_BLOBS}
# )
add_custom_command(
OUTPUT ${GEMM_CODEGEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
--gen_blobs
DEPENDS ${GEMM_CODEGEN_BLOBS}
)
set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
message("adding example ${EXECUTABLE_GEMM_INSTANCE}")
# use build as include directory
find_package(SQLite3 REQUIRED)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp)
target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_link_libraries(${EXECUTABLE_GEMM_INSTANCE} SQLite::SQLite3)
target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS})
set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)

View File

@@ -28,10 +28,11 @@ make tile_engine_gemm -j
-stride_c Tensor C stride (default:0)
-split_k SplitK value (default:1)
-v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2)
-metric The metric value of kernel performance - latency: 0, tflops: 1, bandwidth: 2 (default:0)
-warmup Number of iterations before benchmark the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0)
-init Value for initializing tensor - random: 0, linear: 1, constant(1): 2 (default:0)
-pipeline possible values are: compv3, compv4, mem (default:compv3)
-scheduler possible values are: intrawave, interwave (default:intrawave)
-epilogue possible values are: cshuffle, default (default:cshuffle)

View File

@@ -8,17 +8,15 @@
#include <filesystem>
#include <memory>
#include "ck/version.h"
#include "ck/host_utility/device_prop.hpp"
#include "profile_cache.hpp"
class Executor
class Profiler
{
public:
~Executor() { kernel_instances_.clear(); }
static Executor& instance(bool enable_profile_cache = true, bool flush_profile_cache = false)
static Profiler& instance()
{
static Executor instance{enable_profile_cache, flush_profile_cache};
static Profiler instance;
return instance;
}
@@ -61,108 +59,39 @@ class Executor
KernelInstance kernel_instance{environment_, description, problem, {-1.0f, -1.0f, -1.0f}};
auto launch_kernel = [&] {
float avg_time = Kernel::launch(args, s);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
float avg_time = Kernel::launch(args, s);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
sizeof(BDataType) * args.N * args.K +
sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_byte / 1.E6 / avg_time;
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
sizeof(BDataType) * args.N * args.K +
sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_byte / 1.E6 / avg_time;
kernel_instance.perf_result.latency = avg_time;
kernel_instance.perf_result.tflops = tflops;
kernel_instance.perf_result.bandwidth = gb_per_sec;
kernel_instance.perf_result.latency = avg_time;
kernel_instance.perf_result.tflops = tflops;
kernel_instance.perf_result.bandwidth = gb_per_sec;
std::cout << kernel_instance << std::endl;
std::cout << kernel_instance << std::endl;
bool verified_correct =
!verify || compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
bool verified_correct =
!verify || compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
if(verified_correct)
{
kernel_instances_.emplace_back(kernel_instance);
}
else
{
std::cout << "Verification failed, skip kernel: " << description << std::endl;
}
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
};
if(enable_profile_cache_)
if(verified_correct)
{
if(!cache_db_->check_if_record(kernel_instance))
{
launch_kernel();
cache_db_->insert_batch({kernel_instance});
}
else
{
auto perf_result = cache_db_->query_performance_result(kernel_instance);
kernel_instance.perf_result.latency = perf_result.latency;
kernel_instance.perf_result.tflops = perf_result.tflops;
kernel_instance.perf_result.bandwidth = perf_result.bandwidth;
std::cout << "Skip this kernel for " << description
<< ", Because it has already been recorded in the cache database"
<< std::endl;
kernel_instances_.emplace_back(kernel_instance);
}
kernel_instances_.emplace_back(kernel_instance);
}
else
{
launch_kernel();
std::cout << "Verification failed, skip kernel: " << description << std::endl;
}
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
}
void export_perf_to_csv(const std::vector<KernelInstance>& instances,
const std::string& filename)
{
std::ostringstream buffer;
buffer << "ROCmVersion,CommitID,DeviceName,KernelName,SplitK,M,N,K,"
<< "StrideA,StrideB,StrideC,ADataType,BDataType,AccDataType,"
<< "CDataType,ALayout,BLayout,CLayout,Latency(ms),TFLOPS,Bandwidth\n";
for(const auto& instance : instances)
{
const auto& env = instance.env;
const auto& p = instance.problem;
const auto& perf = instance.perf_result;
std::string sanitized_name = instance.name;
std::replace(sanitized_name.begin(), sanitized_name.end(), '\"', '\'');
buffer << env.rocm_version << "," << env.commit_id << "," << env.device_name << ","
<< "\"" << sanitized_name << "\"," << p.split_k << "," << p.m << "," << p.n
<< "," << p.k << "," << p.stride_a << "," << p.stride_b << "," << p.stride_c
<< "," << p.dtype_a << "," << p.dtype_b << "," << p.dtype_acc << "," << p.dtype_c
<< "," << p.layout_a << "," << p.layout_b << "," << p.layout_c << ","
<< std::fixed << std::setprecision(6) << perf.latency << "," << std::scientific
<< perf.tflops << "," << std::fixed << perf.bandwidth << "\n";
}
std::ofstream csv_file(filename, std::ios::trunc);
if(!csv_file)
{
throw std::runtime_error("Failed to open CSV file: " + filename);
}
csv_file << buffer.str();
csv_file.close();
if(csv_file.fail())
{
throw std::runtime_error("Incomplete write to CSV file: " + filename);
}
}
KernelInstance select_best_instance(Metric metric, const std::string& csv_path = "")
KernelInstance select_best_instance(Metric metric)
{
if(kernel_instances_.empty())
throw std::runtime_error("Empty instances");
@@ -174,126 +103,28 @@ class Executor
b.perf_result, a.perf_result, metric);
});
std::cout << "**********************************" << std::endl;
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
<< "The best kernel instance is: " << kernel_instance << std::endl;
if(!csv_path.empty())
{
try
{
export_perf_to_csv(kernel_instances_, csv_path);
}
catch(const std::exception& e)
{
std::cerr << "CSV export failed: " << e.what() << std::endl;
}
}
std::cout << "**********************************" << std::endl;
return kernel_instance;
}
private:
Executor(bool enable_profile_cache = true, bool flush_profile_cache = false)
: enable_profile_cache_(enable_profile_cache), flush_profile_cache_(flush_profile_cache)
Profiler()
{
environment_ = Environment{
get_rocm_version(),
"89f",
ck::get_device_name(),
};
std::cout << "Init gemm bechmark on device: " << environment_.device_name << std::endl;
initialize_profile_cache();
}
~Profiler() { kernel_instances_.clear(); }
void initialize_profile_cache()
{
// Init cache if enable profile cache
if(enable_profile_cache_)
{
// get profile cache path
std::filesystem::path cache_db_prefix_path =
std::filesystem::current_path() / ".tile_engine";
if(!create_cache_directory(cache_db_prefix_path))
{
std::cerr << "Error: Failed to create cache directory" << std::endl;
return;
}
std::filesystem::path cache_db_path =
cache_db_prefix_path / ("tile_engine_" + environment_.device_name + ".db");
// remove cache if flush_profile_cache
handle_cache_flush(cache_db_path);
// load profile cache
initialize_cache_db(cache_db_path);
}
else
{
std::cout << "Executor disable profile cache! " << std::endl;
}
}
bool create_cache_directory(const std::filesystem::path& cache_db_prefix_path)
{
std::error_code ec;
bool created = std::filesystem::create_directories(cache_db_prefix_path, ec);
if(ec)
{
std::cerr << "Error creating directory " << cache_db_prefix_path << ": " << ec.message()
<< std::endl;
return false;
}
if(created)
{
std::cout << "Created cache directory: " << cache_db_prefix_path << std::endl;
}
else
{
std::cout << "Using existing cache directory: " << cache_db_prefix_path << std::endl;
}
return true;
}
void handle_cache_flush(const std::filesystem::path& cache_db_path) const
{
if(flush_profile_cache_ && std::filesystem::exists(cache_db_path))
{
std::error_code ec;
if(std::filesystem::remove(cache_db_path, ec))
{
std::cout << "Successfully flushed cache: " << cache_db_path << std::endl;
}
else
{
std::cerr << "Error flushing cache: " << ec.message() << std::endl;
}
}
}
void initialize_cache_db(const std::filesystem::path& path)
{
try
{
cache_db_ = std::make_unique<ProfileCacheDB>(path);
std::cout << "Loaded profile cache from " << path << std::endl;
}
catch(const std::exception& e)
{
std::cerr << "Failed to initialize profile cache: " << e.what() << std::endl;
}
}
Executor(const Executor&) = delete;
Executor& operator=(const Executor&) = delete;
Profiler(const Profiler&) = delete;
Profiler& operator=(const Profiler&) = delete;
Environment environment_;
bool enable_profile_cache_;
bool flush_profile_cache_;
std::unique_ptr<ProfileCacheDB> cache_db_;
std::vector<KernelInstance> kernel_instances_;
};

View File

@@ -11,22 +11,12 @@ void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify,
int metric,
bool enable_profile_cache,
bool flush_profile_cache,
KernelTraits& trait,
ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s)
{
return GemmDispatcher::dispatch(c_m_n_dev_buf,
c_m_n_host_result,
c_m_n_dev_result,
verify,
metric,
enable_profile_cache,
flush_profile_cache,
trait,
args,
s);
return GemmDispatcher::dispatch(
c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, metric, trait, args, s);
}
template <typename ADataType,
@@ -55,8 +45,6 @@ void run(const ck_tile::ArgParser& arg_parser)
int verify = arg_parser.get_int("v");
ck_tile::index_t init_method = arg_parser.get_int("init");
int metric = arg_parser.get_int("metric");
bool enable_profile_cache = arg_parser.get_bool("enable_profile_cache");
bool flush_profile_cache = arg_parser.get_bool("flush_profile_cache");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -168,8 +156,6 @@ void run(const ck_tile::ArgParser& arg_parser)
c_m_n_dev_result,
verify,
metric,
enable_profile_cache,
flush_profile_cache,
trait,
gemm_args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

View File

@@ -94,7 +94,6 @@ struct PerformanceResult
struct Environment
{
std::string rocm_version;
std::string commit_id;
std::string device_name;
std::string serialize() const
@@ -102,7 +101,6 @@ struct Environment
std::ostringstream oss;
oss << "{"
<< "\"rocm_version\":\"" << rocm_version << "\","
<< "\"commit_id\":\"" << commit_id << "\","
<< "\"device_name\":\"" << device_name << "\""
<< "}";
return oss.str();
@@ -239,13 +237,7 @@ inline auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("metric", "latency", "latency, tflops, bandwidth")
.insert("enable_profile_cache",
"false",
"whether use profile cache or not when benchmark kernel")
.insert("flush_profile_cache",
"false",
"whether flush profile cache or not when benchmark kernel")
.insert("metric", "0", "0:latency, 1:tflops, 2:bandwidth")
.insert("pipeline", "compv3", "compv3, compv4, mem")
.insert("scheduler", "intrawave", "intrawave, interwave")
.insert("epilogue", "cshuffle", "cshuffle, default")

View File

@@ -476,21 +476,27 @@ struct GemmKernel {{
content = """// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_common.hpp"
#include "gemm_instances.hpp"
#include "gemm_host_api.hpp"
#include <unordered_map>
#include <functional>
#include <vector>
#include "gemm_common.hpp"
#include "gemm_instances.hpp"
#include "gemm_host_api.hpp"
#include "benchmark_gemm.hpp"
struct GemmDispatcher {
static auto& get_kernel_map() {
// Use a static local variable
static std::unordered_map<std::string,
std::function<void(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map;
static std::unordered_map<std::string,
std::function<void(Profiler&,
ck_tile::DeviceMem&,
ck_tile::HostTensor<CDataType>&,
ck_tile::HostTensor<CDataType>&,
int,
ck_tile::GemmHostArgs&,
const ck_tile::stream_config&)>>
kernel_map;
return kernel_map;
}
@@ -513,7 +519,8 @@ struct GemmDispatcher {
for group in self.all_kernels:
content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf,
content += f""" kernel_map["{group}"] = [](Profiler& profiler,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args,
@@ -526,46 +533,29 @@ struct GemmDispatcher {
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
content += f"""
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
profiler.benchmark_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);"""
content += f"""
}};\n"""
content += """ }
template <typename Kernel>
static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
float avg_time = Kernel::launch(args, s);
std::string description = Kernel::get_name();
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_byte / 1.E6 / avg_time;
std::cout << "Performance for " << description << " : " << avg_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
if(verify)
compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result);
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
}
static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args,
int verify,
int metric,
const KernelTraits& trait,
ck_tile::GemmHostArgs& gemm_args,
const ck_tile::stream_config& s) {
init();
const std::string key = assemble_key(trait);
auto& kernel_map = get_kernel_map();
auto& profiler = Profiler::instance();
if(auto it = kernel_map.find(key); it != kernel_map.end()) {
return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s);
it->second(
profiler, c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, s);
profiler.select_best_instance(static_cast<Metric>(metric));
return;
}
throw std::runtime_error("No suitable kernel found: " + key);
}

View File

@@ -1,294 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <sqlite3.h>
#include "gemm_host_api.hpp"
#define CHECK_SQLITE3(expr, db) \
do \
{ \
int result_code = (expr); \
if(result_code != SQLITE_OK) \
{ \
const char* err = sqlite3_errmsg(db); \
throw std::runtime_error("SQLite error[" + std::to_string(result_code) + \
"]: " + (err ? err : "unknown error") + " at " + \
std::string(__FILE__) + ":" + std::to_string(__LINE__)); \
} \
} while(0)
#define CHECK_SQLITE3_RC(expr, db, rc) \
do \
{ \
rc = (expr); \
if(rc != SQLITE_OK && rc != SQLITE_ROW && rc != SQLITE_DONE) \
{ \
const char* err = sqlite3_errmsg(db); \
throw std::runtime_error("SQLite error[" + std::to_string(rc) + \
"]: " + (err ? err : "unknown error") + " at " + \
std::string(__FILE__) + ":" + std::to_string(__LINE__)); \
} \
} while(0)
class StmtWrapper
{
public:
explicit StmtWrapper(sqlite3* db, const char* sql)
: stmt_(
[db, sql] {
sqlite3_stmt* stmt = nullptr;
CHECK_SQLITE3(sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr), db);
return stmt;
}(),
&sqlite3_finalize)
{
}
operator sqlite3_stmt*() const { return stmt_.get(); }
private:
std::unique_ptr<sqlite3_stmt, decltype(&sqlite3_finalize)> stmt_;
};
class ProfileCacheDB
{
public:
explicit ProfileCacheDB(const std::filesystem::path& path)
: db_ptr_(
[path] {
sqlite3* raw_db_ptr = nullptr;
CHECK_SQLITE3(sqlite3_open_v2(path.string().c_str(),
&raw_db_ptr,
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
nullptr),
raw_db_ptr);
return raw_db_ptr;
}(),
&sqlite3_close)
{
try
{
exec_direct("PRAGMA journal_mode = WAL");
exec_direct("PRAGMA synchronous = NORMAL");
exec_direct("PRAGMA foreign_keys = ON");
constexpr const char* schema = R"sql(
CREATE TABLE IF NOT EXISTS gemm (
id INTEGER PRIMARY KEY AUTOINCREMENT,
rocm_version TEXT NOT NULL CHECK(length(rocm_version) > 0),
commit_id TEXT NOT NULL CHECK(length(commit_id) > 0),
device_name TEXT NOT NULL CHECK(length(device_name) > 0),
instance_name TEXT NOT NULL CHECK(length(instance_name) > 0),
problem TEXT NOT NULL CHECK(json_valid(problem)),
latency REAL CHECK(latency > 0),
tflops REAL CHECK(tflops > 0),
bandwidth REAL CHECK(bandwidth > 0)
);
CREATE INDEX IF NOT EXISTS idx_latency ON gemm(latency);
CREATE INDEX IF NOT EXISTS idx_tflops_desc ON gemm(tflops DESC);
CREATE INDEX IF NOT EXISTS idx_bandwidth_desc ON gemm(bandwidth DESC);
)sql";
exec_direct(schema);
}
catch(...)
{
throw;
}
}
bool check_if_record(const KernelInstance& kernel_instance)
{
constexpr const char* sql = R"sql(
SELECT 1 FROM gemm
WHERE rocm_version=? AND commit_id=? AND device_name=?
AND instance_name=? AND problem=?
LIMIT 1)sql";
StmtWrapper stmt(db_ptr_.get(), sql);
sqlite3_stmt* raw_stmt = stmt;
int idx = 1;
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, kernel_instance.env.rocm_version.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, kernel_instance.env.commit_id.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, kernel_instance.env.device_name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, kernel_instance.name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt,
idx++,
kernel_instance.problem.serialize().c_str(),
kernel_instance.problem.serialize().size(),
SQLITE_TRANSIENT),
db_ptr_.get());
int rc;
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
return (rc == SQLITE_ROW);
}
PerformanceResult query_performance_result(const KernelInstance& kernel_instance)
{
constexpr const char* sql = R"sql(
SELECT latency, tflops, bandwidth FROM gemm
WHERE rocm_version=? AND commit_id=? AND device_name=?
AND instance_name=? AND problem=?
LIMIT 1)sql";
StmtWrapper stmt(db_ptr_.get(), sql);
sqlite3_stmt* raw_stmt = stmt;
int idx = 1;
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, kernel_instance.env.rocm_version.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, kernel_instance.env.commit_id.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, kernel_instance.env.device_name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, kernel_instance.name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_text(stmt,
idx++,
kernel_instance.problem.serialize().c_str(),
kernel_instance.problem.serialize().size(),
SQLITE_TRANSIENT),
db_ptr_.get());
int rc;
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
if(rc == SQLITE_ROW)
{
return {sqlite3_column_double(raw_stmt, 0),
sqlite3_column_double(raw_stmt, 1),
sqlite3_column_double(raw_stmt, 2)};
}
else if(rc != SQLITE_DONE)
{
throw std::runtime_error(sqlite3_errmsg(db_ptr_.get()));
}
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
return {-1.0f, -1.0f, -1.0f};
}
void insert_batch(const std::vector<KernelInstance>& data)
{
exec_direct("BEGIN TRANSACTION");
try
{
constexpr const char* sql = R"sql(
INSERT INTO gemm
(rocm_version, commit_id, device_name,
instance_name, problem,
latency, tflops, bandwidth)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8))sql";
StmtWrapper stmt(db_ptr_.get(), sql);
sqlite3_stmt* raw_stmt = stmt;
for(const auto& item : data)
{
int idx = 1;
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, item.env.rocm_version.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, item.env.commit_id.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, item.env.device_name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, item.name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt,
idx++,
item.problem.serialize().c_str(),
item.problem.serialize().size(),
SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result.latency),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result.tflops),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result.bandwidth),
db_ptr_.get());
int rc;
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
}
exec_direct("COMMIT");
}
catch(...)
{
exec_direct("ROLLBACK");
throw;
}
}
// std::vector<KernelInstance> query_top(Metric metric, int limit = 5)
// {
// if(limit <= 0)
// throw invalid_argument("Limit must be positive");
// const char* order = nullptr;
// switch(metric)
// {
// case LATENCY: order = "latency ASC"; break;
// case TFLOPS: order = "tflops DESC"; break;
// case BANDWIDTH: order = "bandwidth DESC"; break;
// default: throw invalid_argument("Invalid metric");
// }
// string sql = "SELECT name, latency, tflops, bandwidth FROM kernels "
// "ORDER BY " +
// string(order) + " LIMIT ?";
// StmtWrapper stmt(db, sql.c_str());
// sqlite3_bind_int(stmt, 1, limit);
// std::vector<KernelInstance> results;
// while(sqlite3_step(stmt) == SQLITE_ROW)
// {
// results.emplace_back(reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0)),
// PerformanceResult{sqlite3_column_double(stmt, 1),
// sqlite3_column_double(stmt, 2),
// sqlite3_column_double(stmt, 3)});
// }
// return results;
// }
private:
void exec_direct(const char* sql)
{
CHECK_SQLITE3(sqlite3_exec(db_ptr_.get(), sql, nullptr, nullptr, nullptr), db_ptr_.get());
}
std::unique_ptr<sqlite3, decltype(&sqlite3_close)> db_ptr_;
};