mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
262 lines
11 KiB
C++
262 lines
11 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <sqlite3.h>

#include <cstddef>
#include <filesystem>
#include <iostream>
#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include "benchmark_gemm.hpp"
|
|
|
|
// Evaluates `expr` (an SQLite call returning a result code) exactly once, then `db`.
// Any code other than SQLITE_OK becomes a std::runtime_error of the form
//   "SQLite error[<code>]: <sqlite3_errmsg(db) or 'unknown error'> at <file>:<line>"
// where file/line are those of the macro invocation.
// The locals carry a macro-specific prefix so they cannot shadow names used inside `expr`.
#define CHECK_SQLITE3(expr, db)                                                            \
    do                                                                                     \
    {                                                                                      \
        const int check_sqlite3_status_ = (expr);                                          \
        if(check_sqlite3_status_ == SQLITE_OK)                                             \
            break;                                                                         \
        const char* const check_sqlite3_msg_ = sqlite3_errmsg(db);                         \
        throw std::runtime_error(std::string("SQLite error[") +                            \
                                 std::to_string(check_sqlite3_status_) + "]: " +           \
                                 (check_sqlite3_msg_ ? check_sqlite3_msg_                  \
                                                     : "unknown error") +                  \
                                 " at " + std::string(__FILE__) + ":" +                    \
                                 std::to_string(__LINE__));                                \
    } while(0)
|
|
|
|
// Like CHECK_SQLITE3, but stores the result code of `expr` into the caller's variable `rc`
// and also accepts SQLITE_ROW and SQLITE_DONE (the normal outcomes of sqlite3_step).
// Any other code throws std::runtime_error with the same message format as CHECK_SQLITE3;
// `rc` still holds the failing code when the exception propagates.
#define CHECK_SQLITE3_RC(expr, db, rc)                                                     \
    do                                                                                     \
    {                                                                                      \
        rc = (expr);                                                                       \
        const bool check_sqlite3_rc_accepted_ =                                            \
            (rc == SQLITE_OK) || (rc == SQLITE_ROW) || (rc == SQLITE_DONE);                \
        if(check_sqlite3_rc_accepted_)                                                     \
            break;                                                                         \
        const char* const check_sqlite3_rc_msg_ = sqlite3_errmsg(db);                      \
        throw std::runtime_error(std::string("SQLite error[") + std::to_string(rc) +       \
                                 "]: " +                                                   \
                                 (check_sqlite3_rc_msg_ ? check_sqlite3_rc_msg_            \
                                                        : "unknown error") +               \
                                 " at " + std::string(__FILE__) + ":" +                    \
                                 std::to_string(__LINE__));                                \
    } while(0)
|
|
|
|
class StmtWrapper
|
|
{
|
|
public:
|
|
explicit StmtWrapper(sqlite3* db, const char* sql)
|
|
: stmt_(
|
|
[db, sql] {
|
|
sqlite3_stmt* stmt = nullptr;
|
|
CHECK_SQLITE3(sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr), db);
|
|
return stmt;
|
|
}(),
|
|
&sqlite3_finalize)
|
|
{
|
|
}
|
|
|
|
operator sqlite3_stmt*() const { return stmt_.get(); }
|
|
|
|
private:
|
|
std::unique_ptr<sqlite3_stmt, decltype(&sqlite3_finalize)> stmt_;
|
|
};
|
|
|
|
class ProfileCacheDB
|
|
{
|
|
public:
|
|
explicit ProfileCacheDB(const std::filesystem::path& path)
|
|
: db_ptr_(
|
|
[path] {
|
|
sqlite3* raw_db_ptr = nullptr;
|
|
CHECK_SQLITE3(sqlite3_open_v2(path.string().c_str(),
|
|
&raw_db_ptr,
|
|
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
|
|
nullptr),
|
|
raw_db_ptr);
|
|
return raw_db_ptr;
|
|
}(),
|
|
&sqlite3_close)
|
|
{
|
|
|
|
try
|
|
{
|
|
exec_direct("PRAGMA journal_mode = WAL");
|
|
exec_direct("PRAGMA synchronous = NORMAL");
|
|
exec_direct("PRAGMA foreign_keys = ON");
|
|
|
|
constexpr const char* schema = R"sql(
|
|
CREATE TABLE IF NOT EXISTS gemm (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
rocm_version TEXT NOT NULL CHECK(length(rocm_version) > 0),
|
|
device_name TEXT NOT NULL CHECK(length(device_name) > 0),
|
|
problem TEXT NOT NULL CHECK(json_valid(problem)),
|
|
instance_name TEXT NOT NULL CHECK(length(instance_name) > 0),
|
|
latency REAL CHECK(latency > 0),
|
|
tflops REAL CHECK(tflops > 0),
|
|
bandwidth REAL CHECK(bandwidth > 0)
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_latency ON gemm(latency);
|
|
CREATE INDEX IF NOT EXISTS idx_tflops_desc ON gemm(tflops DESC);
|
|
CREATE INDEX IF NOT EXISTS idx_bandwidth_desc ON gemm(bandwidth DESC);
|
|
)sql";
|
|
|
|
exec_direct(schema);
|
|
}
|
|
catch(...)
|
|
{
|
|
throw;
|
|
}
|
|
}
|
|
|
|
bool check_if_record_problem(const std::string& rocm_version,
|
|
const std::string& device_name,
|
|
const GemmProblem& gemm_problem)
|
|
{
|
|
constexpr const char* sql = R"sql(
|
|
SELECT 1 FROM gemm
|
|
WHERE rocm_version=?
|
|
AND device_name=?
|
|
AND problem=?
|
|
LIMIT 1
|
|
)sql";
|
|
|
|
StmtWrapper stmt(db_ptr_.get(), sql);
|
|
sqlite3_stmt* raw_stmt = stmt;
|
|
int idx = 1;
|
|
CHECK_SQLITE3(
|
|
sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT),
|
|
db_ptr_.get());
|
|
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
|
|
db_ptr_.get());
|
|
std::ostringstream oss;
|
|
oss << gemm_problem;
|
|
auto problem_json = oss.str();
|
|
CHECK_SQLITE3(
|
|
sqlite3_bind_text(raw_stmt, idx++, problem_json.c_str(), -1, SQLITE_TRANSIENT),
|
|
db_ptr_.get());
|
|
|
|
int rc;
|
|
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
|
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
|
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
|
|
|
std::cout << "Query params:\n"
|
|
<< "rocm_version: " << rocm_version << "\n"
|
|
<< "device_name: " << device_name << "\n"
|
|
<< "problem_json: " << problem_json << std::endl;
|
|
|
|
if(rc == SQLITE_DONE)
|
|
{
|
|
std::cout << "No matching records found" << std::endl;
|
|
}
|
|
return (rc == SQLITE_ROW);
|
|
}
|
|
|
|
std::tuple<std::string, PerformanceResult> query_cache(const std::string& rocm_version,
|
|
const std::string& device_name,
|
|
const GemmProblem& gemm_problem)
|
|
{
|
|
constexpr const char* sql = R"sql(
|
|
SELECT latency, tflops, bandwidth FROM gemm
|
|
WHERE rocm_version=? AND device_name=?
|
|
AND problem=?
|
|
LIMIT 1
|
|
)sql";
|
|
|
|
StmtWrapper stmt(db_ptr_.get(), sql);
|
|
sqlite3_stmt* raw_stmt = stmt;
|
|
|
|
int idx = 1;
|
|
CHECK_SQLITE3(
|
|
sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT),
|
|
db_ptr_.get());
|
|
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
|
|
db_ptr_.get());
|
|
std::ostringstream oss;
|
|
oss << gemm_problem;
|
|
auto problem_json = oss.str();
|
|
CHECK_SQLITE3(sqlite3_bind_text(
|
|
stmt, idx++, problem_json.c_str(), problem_json.size(), SQLITE_TRANSIENT),
|
|
db_ptr_.get());
|
|
|
|
int rc;
|
|
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
|
|
|
if(rc == SQLITE_ROW)
|
|
{
|
|
return std::make_tuple(reinterpret_cast<const char*>(sqlite3_column_text(raw_stmt, 0)),
|
|
PerformanceResult{sqlite3_column_double(raw_stmt, 1),
|
|
sqlite3_column_double(raw_stmt, 2),
|
|
sqlite3_column_double(raw_stmt, 3)});
|
|
}
|
|
else if(rc != SQLITE_DONE)
|
|
{
|
|
throw std::runtime_error(sqlite3_errmsg(db_ptr_.get()));
|
|
}
|
|
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
|
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
|
|
|
return std::make_tuple("", PerformanceResult{-1.0f, -1.0f, -1.0f});
|
|
}
|
|
|
|
void insert_cache(const std::string& rocm_version,
|
|
const std::string& device_name,
|
|
const std::vector<KernelInstance>& kernen_instnaces)
|
|
{
|
|
exec_direct("BEGIN TRANSACTION");
|
|
try
|
|
{
|
|
constexpr const char* sql = R"sql(
|
|
INSERT INTO gemm
|
|
(rocm_version, device_name,
|
|
problem, instance_name,
|
|
latency, tflops, bandwidth)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
|
)sql";
|
|
|
|
StmtWrapper stmt(db_ptr_.get(), sql);
|
|
sqlite3_stmt* raw_stmt = stmt;
|
|
|
|
for(const auto& item : kernen_instnaces)
|
|
{
|
|
int idx = 1;
|
|
CHECK_SQLITE3(
|
|
sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT),
|
|
db_ptr_.get());
|
|
CHECK_SQLITE3(
|
|
sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
|
|
db_ptr_.get());
|
|
std::ostringstream oss;
|
|
oss << item.problem_;
|
|
auto problem_json = oss.str();
|
|
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt,
|
|
idx++,
|
|
problem_json.c_str(),
|
|
problem_json.size(),
|
|
SQLITE_TRANSIENT),
|
|
db_ptr_.get());
|
|
CHECK_SQLITE3(
|
|
sqlite3_bind_text(raw_stmt, idx++, item.name_.c_str(), -1, SQLITE_TRANSIENT),
|
|
db_ptr_.get());
|
|
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.latency_),
|
|
db_ptr_.get());
|
|
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.tflops_),
|
|
db_ptr_.get());
|
|
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.bandwidth_),
|
|
db_ptr_.get());
|
|
|
|
int rc;
|
|
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
|
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
|
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
|
}
|
|
exec_direct("COMMIT");
|
|
}
|
|
catch(...)
|
|
{
|
|
exec_direct("ROLLBACK");
|
|
throw;
|
|
}
|
|
}
|
|
|
|
private:
|
|
void exec_direct(const char* sql)
|
|
{
|
|
CHECK_SQLITE3(sqlite3_exec(db_ptr_.get(), sql, nullptr, nullptr, nullptr), db_ptr_.get());
|
|
}
|
|
|
|
std::unique_ptr<sqlite3, decltype(&sqlite3_close)> db_ptr_;
|
|
};
|