initial commit, but result=0 bug

This commit is contained in:
Yanxing-Shi
2025-05-22 10:11:05 +00:00
parent 40cd09a93d
commit ecf403a430
6 changed files with 394 additions and 7 deletions

View File

@@ -1,9 +1,11 @@
find_package(SQLite3 REQUIRED)
# Configure-time step: run the instance builder with --list_blobs to generate a
# list of kernels without actually emitting files. The exit status of the
# script is stored in `ret`.
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--list_blobs
RESULT_VARIABLE ret
)
@@ -31,7 +33,7 @@ add_custom_command(
OUTPUT ${GEMM_CODEGEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--gen_blobs
)
@@ -51,8 +53,8 @@ target_link_libraries(gemm_host_api INTERFACE gemm_template_instances)
# Benchmark executable built from benchmark_gemm.cpp. EXCLUDE_FROM_ALL keeps it
# out of the default "all" target, so it is only built when requested.
add_executable(${BENCHMARK_GEMM_EXECUTABLE} EXCLUDE_FROM_ALL benchmark_gemm.cpp)
target_include_directories(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp)
target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api)
# The profile cache header is listed as a source and SQLite::SQLite3 is linked
# alongside gemm_host_api.
target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp profile_cache_db.hpp)
target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api SQLite::SQLite3)
# set() with no value clears the variable in the current scope.
set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)

View File

@@ -27,6 +27,8 @@ void benchmark_gemm(const ck_tile::ArgParser& arg_parser)
arg_parser.get_bool("structured_sparsity")};
Setting setting{
arg_parser.get_bool("enable_profile_cache"),
arg_parser.get_bool("flush_profile_cache"),
arg_parser.get_int("warmup"),
arg_parser.get_int("repeat"),
arg_parser.get_bool("timer"),

View File

@@ -118,6 +118,8 @@ struct KernelInstance
struct Setting
{
bool enable_profile_cache_;
bool flush_profile_cache_;
int n_warmup_;
int n_repeat_;
bool is_gpu_timer_;

View File

@@ -73,6 +73,14 @@ inline auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
.insert("stride_c", "0", "The stride value for tensor C Default is 0.")
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
.insert("enable_profile_cache",
"true",
"whether use profile cache or not when benchmark kernel, Possible values are true "
"or false. Default is true")
.insert("flush_profile_cache",
"false",
"whether flush profile cache or not when benchmark kernel. Possible values are "
"true or false. Default is false")
.insert("verify",
"2",
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "

View File

@@ -6,10 +6,12 @@
#include <iostream>
#include <fstream>
#include <iomanip>
#include <filesystem>
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "benchmark_gemm.hpp"
#include "profile_cache_db.hpp"
class GemmProfiler
{
@@ -20,10 +22,45 @@ class GemmProfiler
return instance;
}
bool is_problem_record_cache(const GemmProblem& gemm_problem)
{
if(setting_.enable_profile_cache_)
{
if(!cache_db_->check_if_record_problem(
get_rocm_version(), ck_tile::get_device_name(), gemm_problem))
{
return false;
}
else
{
auto [name, perf_result] = cache_db_->query_cache(
get_rocm_version(), ck_tile::get_device_name(), gemm_problem);
KernelInstance kernel_instance;
kernel_instance.problem_ = gemm_problem;
kernel_instance.name_ = name;
kernel_instance.perf_result_.latency_ = perf_result.latency_;
kernel_instance.perf_result_.tflops_ = perf_result.tflops_;
kernel_instance.perf_result_.bandwidth_ = perf_result.bandwidth_;
std::cout << "Skip this problem for " << gemm_problem
<< ", Because it has already been recorded in the cache database"
<< std::endl;
kernel_instances_.emplace_back(kernel_instance);
return true;
}
}
else
{
return false;
}
}
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
{
if(is_problem_record_cache(gemm_problem))
return;
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
const CLayout layout_c = CLayout{};
@@ -135,6 +172,12 @@ class GemmProfiler
c_m_n_dev_result,
kernel_run_result);
}
if(setting_.enable_profile_cache_)
{
cache_db_->insert_cache(
get_rocm_version(), ck_tile::get_device_name(), kernel_instances_);
}
}
void process_result(const GemmProblem& gemm_problem,
@@ -247,14 +290,92 @@ class GemmProfiler
return kernel_instance;
}
GemmProfiler(const GemmProfiler&) = delete;
GemmProfiler(const GemmProfiler&) = delete;
GemmProfiler& operator=(const GemmProfiler&) = delete;
private:
~GemmProfiler() { kernel_instances_.clear(); }
GemmProfiler(Setting setting) : setting_(setting) {}
GemmProfiler(Setting setting) : setting_(setting) { initialize_profile_cache(); }
// Sets up the on-disk profile cache when it is enabled in the settings.
//
// The database lives at <current working directory>/.tile_engine/
// tile_engine_<device name>.db. If the directory cannot be created the cache
// is switched off for this profiler, so code that checks
// setting_.enable_profile_cache_ never runs without a database.
void initialize_profile_cache()
{
    if(!setting_.enable_profile_cache_)
    {
        std::cout << "Gemm profiler disable profile cache! " << std::endl;
        return;
    }

    const std::filesystem::path cache_db_prefix_path =
        std::filesystem::current_path() / ".tile_engine";
    if(!create_cache_directory(cache_db_prefix_path))
    {
        std::cerr << "Error: Failed to create cache directory" << std::endl;
        // No directory means no database: fall back to uncached benchmarking.
        setting_.enable_profile_cache_ = false;
        return;
    }

    const std::filesystem::path cache_db_path =
        cache_db_prefix_path / ("tile_engine_" + ck_tile::get_device_name() + ".db");

    // Flush first so a requested flush starts from an empty database.
    handle_flush_cache(cache_db_path);
    initialize_cache_db(cache_db_path);
}
// Ensures the cache directory (and any missing parents) exists.
// Returns true when the directory was created or was already present, and
// false when the filesystem reported an error. Progress is logged to stdout,
// failures to stderr.
bool create_cache_directory(const std::filesystem::path& cache_db_prefix_path)
{
    std::error_code failure;
    const bool newly_made = std::filesystem::create_directories(cache_db_prefix_path, failure);

    if(failure)
    {
        std::cerr << "Error creating directory " << cache_db_prefix_path << ": "
                  << failure.message() << std::endl;
        return false;
    }

    // create_directories reports false (without an error) when nothing had to be made.
    const char* const headline =
        newly_made ? "Created cache directory: " : "Using existing cache directory: ";
    std::cout << headline << cache_db_prefix_path << std::endl;
    return true;
}
// Deletes the cache database file at `cache_db_path` when the
// flush_profile_cache setting is on. A missing file is not an error; there is
// simply nothing to flush. Never throws: filesystem failures are reported on
// stderr.
void handle_flush_cache(const std::filesystem::path& cache_db_path) const
{
    if(!setting_.flush_profile_cache_)
    {
        return;
    }

    // remove() with an error_code returns false without setting the code when
    // the file does not exist, so no separate (throwing) exists() probe is
    // needed and there is no window between the check and the removal.
    std::error_code ec;
    const bool removed = std::filesystem::remove(cache_db_path, ec);
    if(ec)
    {
        std::cerr << "Error flushing cache: " << ec.message() << std::endl;
    }
    else if(removed)
    {
        std::cout << "Successfully flushed cache: " << cache_db_path << std::endl;
    }
}
void initialize_cache_db(const std::filesystem::path& path)
{
try
{
cache_db_ = std::make_unique<ProfileCacheDB>(path);
std::cout << "Loaded profile cache from " << path << std::endl;
}
catch(const std::exception& e)
{
std::cerr << "Failed to initialize profile cache: " << e.what() << std::endl;
}
}
Setting setting_;
std::unique_ptr<ProfileCacheDB> cache_db_;
std::vector<KernelInstance> kernel_instances_;
};

View File

@@ -0,0 +1,252 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <sqlite3.h>
#include <filesystem>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>
#include "benchmark_gemm.hpp"
// Evaluates `expr` once and throws std::runtime_error unless it yields
// SQLITE_OK. The message contains the numeric result code, the text from
// sqlite3_errmsg(db) (or "unknown error" when that is null) and the file/line
// where the macro was expanded.
// Not suitable for sqlite3_step(), whose success codes are SQLITE_ROW and
// SQLITE_DONE rather than SQLITE_OK.
#define CHECK_SQLITE3(expr, db) \
do \
{ \
int result_code = (expr); \
if(result_code != SQLITE_OK) \
{ \
const char* err = sqlite3_errmsg(db); \
throw std::runtime_error("SQLite error[" + std::to_string(result_code) + \
"]: " + (err ? err : "unknown error") + " at " + \
std::string(__FILE__) + ":" + std::to_string(__LINE__)); \
} \
} while(0)
// Evaluates `expr` once and stores its result in the caller-supplied lvalue
// `rc`. SQLITE_OK, SQLITE_ROW and SQLITE_DONE are accepted; any other code
// throws std::runtime_error with the same message layout as CHECK_SQLITE3.
// The caller inspects `rc` afterwards to tell a row from completion.
#define CHECK_SQLITE3_RC(expr, db, rc) \
do \
{ \
rc = (expr); \
if(rc != SQLITE_OK && rc != SQLITE_ROW && rc != SQLITE_DONE) \
{ \
const char* err = sqlite3_errmsg(db); \
throw std::runtime_error("SQLite error[" + std::to_string(rc) + \
"]: " + (err ? err : "unknown error") + " at " + \
std::string(__FILE__) + ":" + std::to_string(__LINE__)); \
} \
} while(0)
// Owning handle for a prepared SQLite statement.
// The statement is compiled in the constructor (which throws via
// CHECK_SQLITE3 if preparation fails) and passed to sqlite3_finalize when the
// wrapper is destroyed. The wrapper converts implicitly to sqlite3_stmt* so it
// can be handed directly to the sqlite3_* functions.
class StmtWrapper
{
    public:
    explicit StmtWrapper(sqlite3* db, const char* sql)
        : stmt_(compile(db, sql), &sqlite3_finalize)
    {
    }

    operator sqlite3_stmt*() const { return stmt_.get(); }

    private:
    // Compiles `sql` against `db` and returns the raw statement handle.
    static sqlite3_stmt* compile(sqlite3* db, const char* sql)
    {
        sqlite3_stmt* prepared = nullptr;
        CHECK_SQLITE3(sqlite3_prepare_v2(db, sql, -1, &prepared, nullptr), db);
        return prepared;
    }

    std::unique_ptr<sqlite3_stmt, decltype(&sqlite3_finalize)> stmt_;
};
class ProfileCacheDB
{
public:
explicit ProfileCacheDB(const std::filesystem::path& path)
: db_ptr_(
[path] {
sqlite3* raw_db_ptr = nullptr;
CHECK_SQLITE3(sqlite3_open_v2(path.string().c_str(),
&raw_db_ptr,
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
nullptr),
raw_db_ptr);
return raw_db_ptr;
}(),
&sqlite3_close)
{
try
{
exec_direct("PRAGMA journal_mode = WAL");
exec_direct("PRAGMA synchronous = NORMAL");
exec_direct("PRAGMA foreign_keys = ON");
constexpr const char* schema = R"sql(
CREATE TABLE IF NOT EXISTS gemm (
id INTEGER PRIMARY KEY AUTOINCREMENT,
rocm_version TEXT NOT NULL CHECK(length(rocm_version) > 0),
device_name TEXT NOT NULL CHECK(length(device_name) > 0),
problem TEXT NOT NULL CHECK(json_valid(problem)),
instance_name TEXT NOT NULL CHECK(length(instance_name) > 0),
latency REAL CHECK(latency > 0),
tflops REAL CHECK(tflops > 0),
bandwidth REAL CHECK(bandwidth > 0)
);
CREATE INDEX IF NOT EXISTS idx_latency ON gemm(latency);
CREATE INDEX IF NOT EXISTS idx_tflops_desc ON gemm(tflops DESC);
CREATE INDEX IF NOT EXISTS idx_bandwidth_desc ON gemm(bandwidth DESC);
)sql";
exec_direct(schema);
}
catch(...)
{
throw;
}
}
bool check_if_record_problem(const std::string& rocm_version,
const std::string& device_name,
const GemmProblem& gemm_problem)
{
constexpr const char* sql = R"sql(
SELECT 1 FROM gemm
WHERE rocm_version=?
AND device_name=?
AND problem=?
LIMIT 1
)sql";
StmtWrapper stmt(db_ptr_.get(), sql);
sqlite3_stmt* raw_stmt = stmt;
int idx = 1;
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
std::ostringstream oss;
oss << gemm_problem;
auto problem_json = oss.str();
CHECK_SQLITE3(
sqlite3_bind_text(
raw_stmt, idx++, problem_json.c_str(), problem_json.size(), SQLITE_TRANSIENT),
db_ptr_.get());
int rc;
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
return (rc == SQLITE_ROW);
}
std::tuple<std::string, PerformanceResult> query_cache(const std::string& rocm_version,
const std::string& device_name,
const GemmProblem& gemm_problem)
{
constexpr const char* sql = R"sql(
SELECT latency, tflops, bandwidth FROM gemm
WHERE rocm_version=? AND device_name=?
AND problem=?
LIMIT 1
)sql";
StmtWrapper stmt(db_ptr_.get(), sql);
sqlite3_stmt* raw_stmt = stmt;
int idx = 1;
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
std::ostringstream oss;
oss << gemm_problem;
auto problem_json = oss.str();
CHECK_SQLITE3(sqlite3_bind_text(
stmt, idx++, problem_json.c_str(), problem_json.size(), SQLITE_TRANSIENT),
db_ptr_.get());
int rc;
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
if(rc == SQLITE_ROW)
{
return std::make_tuple(reinterpret_cast<const char*>(sqlite3_column_text(raw_stmt, 0)),
PerformanceResult{sqlite3_column_double(raw_stmt, 1),
sqlite3_column_double(raw_stmt, 2),
sqlite3_column_double(raw_stmt, 3)});
}
else if(rc != SQLITE_DONE)
{
throw std::runtime_error(sqlite3_errmsg(db_ptr_.get()));
}
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
return std::make_tuple("", PerformanceResult{-1.0f, -1.0f, -1.0f});
}
void insert_cache(const std::string& rocm_version,
const std::string& device_name,
const std::vector<KernelInstance>& kernen_instnaces)
{
exec_direct("BEGIN TRANSACTION");
try
{
constexpr const char* sql = R"sql(
INSERT INTO gemm
(rocm_version, device_name,
problem, instance_name,
latency, tflops, bandwidth)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
)sql";
StmtWrapper stmt(db_ptr_.get(), sql);
sqlite3_stmt* raw_stmt = stmt;
for(const auto& item : kernen_instnaces)
{
int idx = 1;
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
std::ostringstream oss;
oss << item.problem_;
auto problem_json = oss.str();
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt,
idx++,
problem_json.c_str(),
problem_json.size(),
SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(
sqlite3_bind_text(raw_stmt, idx++, item.name_.c_str(), -1, SQLITE_TRANSIENT),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.latency_),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.tflops_),
db_ptr_.get());
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.bandwidth_),
db_ptr_.get());
int rc;
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
}
exec_direct("COMMIT");
}
catch(...)
{
exec_direct("ROLLBACK");
throw;
}
}
private:
void exec_direct(const char* sql)
{
CHECK_SQLITE3(sqlite3_exec(db_ptr_.get(), sql, nullptr, nullptr, nullptr), db_ptr_.get());
}
std::unique_ptr<sqlite3, decltype(&sqlite3_close)> db_ptr_;
};