mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
initial commit, but result=0 bug
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
|
||||
find_package(SQLite3 REQUIRED)
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
@@ -31,7 +33,7 @@ add_custom_command(
|
||||
OUTPUT ${GEMM_CODEGEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--gen_blobs
|
||||
)
|
||||
|
||||
@@ -51,8 +53,8 @@ target_link_libraries(gemm_host_api INTERFACE gemm_template_instances)
|
||||
|
||||
add_executable(${BENCHMARK_GEMM_EXECUTABLE} EXCLUDE_FROM_ALL benchmark_gemm.cpp)
|
||||
target_include_directories(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp)
|
||||
target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api)
|
||||
target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp profile_cache_db.hpp)
|
||||
target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api SQLite::SQLite3)
|
||||
|
||||
set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ void benchmark_gemm(const ck_tile::ArgParser& arg_parser)
|
||||
arg_parser.get_bool("structured_sparsity")};
|
||||
|
||||
Setting setting{
|
||||
arg_parser.get_bool("enable_profile_cache"),
|
||||
arg_parser.get_bool("flush_profile_cache"),
|
||||
arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_bool("timer"),
|
||||
|
||||
@@ -118,6 +118,8 @@ struct KernelInstance
|
||||
|
||||
struct Setting
|
||||
{
|
||||
bool enable_profile_cache_;
|
||||
bool flush_profile_cache_;
|
||||
int n_warmup_;
|
||||
int n_repeat_;
|
||||
bool is_gpu_timer_;
|
||||
|
||||
@@ -73,6 +73,14 @@ inline auto create_args(int argc, char* argv[])
|
||||
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
|
||||
.insert("stride_c", "0", "The stride value for tensor C Default is 0.")
|
||||
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
|
||||
.insert("enable_profile_cache",
|
||||
"true",
|
||||
"whether use profile cache or not when benchmark kernel, Possible values are true "
|
||||
"or false. Default is true")
|
||||
.insert("flush_profile_cache",
|
||||
"false",
|
||||
"whether flush profile cache or not when benchmark kernel. Possible values are "
|
||||
"true or false. Default is false")
|
||||
.insert("verify",
|
||||
"2",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
|
||||
|
||||
@@ -6,10 +6,12 @@
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <filesystem>
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "benchmark_gemm.hpp"
|
||||
#include "profile_cache_db.hpp"
|
||||
|
||||
class GemmProfiler
|
||||
{
|
||||
@@ -20,10 +22,45 @@ class GemmProfiler
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool is_problem_record_cache(const GemmProblem& gemm_problem)
|
||||
{
|
||||
if(setting_.enable_profile_cache_)
|
||||
{
|
||||
if(!cache_db_->check_if_record_problem(
|
||||
get_rocm_version(), ck_tile::get_device_name(), gemm_problem))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto [name, perf_result] = cache_db_->query_cache(
|
||||
get_rocm_version(), ck_tile::get_device_name(), gemm_problem);
|
||||
KernelInstance kernel_instance;
|
||||
kernel_instance.problem_ = gemm_problem;
|
||||
kernel_instance.name_ = name;
|
||||
kernel_instance.perf_result_.latency_ = perf_result.latency_;
|
||||
kernel_instance.perf_result_.tflops_ = perf_result.tflops_;
|
||||
kernel_instance.perf_result_.bandwidth_ = perf_result.bandwidth_;
|
||||
std::cout << "Skip this problem for " << gemm_problem
|
||||
<< ", Because it has already been recorded in the cache database"
|
||||
<< std::endl;
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
|
||||
{
|
||||
if(is_problem_record_cache(gemm_problem))
|
||||
return;
|
||||
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
const CLayout layout_c = CLayout{};
|
||||
@@ -135,6 +172,12 @@ class GemmProfiler
|
||||
c_m_n_dev_result,
|
||||
kernel_run_result);
|
||||
}
|
||||
|
||||
if(setting_.enable_profile_cache_)
|
||||
{
|
||||
cache_db_->insert_cache(
|
||||
get_rocm_version(), ck_tile::get_device_name(), kernel_instances_);
|
||||
}
|
||||
}
|
||||
|
||||
void process_result(const GemmProblem& gemm_problem,
|
||||
@@ -247,14 +290,92 @@ class GemmProfiler
|
||||
return kernel_instance;
|
||||
}
|
||||
|
||||
GemmProfiler(const GemmProfiler&) = delete;
|
||||
GemmProfiler(const GemmProfiler&) = delete;
|
||||
GemmProfiler& operator=(const GemmProfiler&) = delete;
|
||||
|
||||
private:
|
||||
~GemmProfiler() { kernel_instances_.clear(); }
|
||||
GemmProfiler(Setting setting) : setting_(setting) {}
|
||||
GemmProfiler(Setting setting) : setting_(setting) { initialize_profile_cache(); }
|
||||
|
||||
void initialize_profile_cache()
|
||||
{
|
||||
if(setting_.enable_profile_cache_)
|
||||
{
|
||||
std::filesystem::path cache_db_prefix_path =
|
||||
std::filesystem::current_path() / ".tile_engine";
|
||||
if(!create_cache_directory(cache_db_prefix_path))
|
||||
{
|
||||
std::cerr << "Error: Failed to create cache directory" << std::endl;
|
||||
return;
|
||||
}
|
||||
std::filesystem::path cache_db_path =
|
||||
cache_db_prefix_path / ("tile_engine_" + ck_tile::get_device_name() + ".db");
|
||||
|
||||
handle_flush_cache(cache_db_path);
|
||||
|
||||
initialize_cache_db(cache_db_path);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Gemm profiler disable profile cache! " << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensures that `cache_db_prefix_path` exists as a directory, creating any
/// missing parents, and logs which of the two cases applied.
///
/// @param cache_db_prefix_path Directory that will hold the cache database.
/// @return false if the filesystem reported an error (also logged to stderr),
///         true otherwise.
bool create_cache_directory(const std::filesystem::path& cache_db_prefix_path)
{
    std::error_code fs_error;
    const bool newly_created =
        std::filesystem::create_directories(cache_db_prefix_path, fs_error);

    if(fs_error)
    {
        std::cerr << "Error creating directory " << cache_db_prefix_path << ": "
                  << fs_error.message() << std::endl;
        return false;
    }

    // create_directories() yields false without an error when nothing had to
    // be created, i.e. the directory was already there.
    std::cout << (newly_created ? "Created cache directory: "
                                : "Using existing cache directory: ")
              << cache_db_prefix_path << std::endl;
    return true;
}
|
||||
|
||||
void handle_flush_cache(const std::filesystem::path& cache_db_path) const
|
||||
{
|
||||
if(setting_.flush_profile_cache_ && std::filesystem::exists(cache_db_path))
|
||||
{
|
||||
std::error_code ec;
|
||||
if(std::filesystem::remove(cache_db_path, ec))
|
||||
|
||||
{
|
||||
std::cout << "Successfully flushed cache: " << cache_db_path << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Error flushing cache: " << ec.message() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void initialize_cache_db(const std::filesystem::path& path)
|
||||
{
|
||||
try
|
||||
{
|
||||
cache_db_ = std::make_unique<ProfileCacheDB>(path);
|
||||
std::cout << "Loaded profile cache from " << path << std::endl;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Failed to initialize profile cache: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
Setting setting_;
|
||||
|
||||
std::unique_ptr<ProfileCacheDB> cache_db_;
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
|
||||
252
tile_engine/ops/gemm/profile_cache_db.hpp
Normal file
252
tile_engine/ops/gemm/profile_cache_db.hpp
Normal file
@@ -0,0 +1,252 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <sqlite3.h>
|
||||
#include <tuple>
|
||||
#include <sstream>
|
||||
|
||||
#include "benchmark_gemm.hpp"
|
||||
|
||||
// Evaluates `expr` once and throws std::runtime_error unless it yields
// SQLITE_OK. The message carries the numeric result code, the text returned by
// sqlite3_errmsg(db) (or "unknown error" when that is null) and the
// __FILE__:__LINE__ of the macro invocation.
#define CHECK_SQLITE3(expr, db)                                                       \
    do                                                                                \
    {                                                                                 \
        const int sqlite_status_ = (expr);                                            \
        if(sqlite_status_ != SQLITE_OK)                                               \
        {                                                                             \
            const char* sqlite_text_ = sqlite3_errmsg(db);                            \
            std::string sqlite_what_ = "SQLite error[";                               \
            sqlite_what_ += std::to_string(sqlite_status_);                           \
            sqlite_what_ += "]: ";                                                    \
            sqlite_what_ += (sqlite_text_ ? sqlite_text_ : "unknown error");          \
            sqlite_what_ += " at ";                                                   \
            sqlite_what_ += __FILE__;                                                 \
            sqlite_what_ += ":";                                                      \
            sqlite_what_ += std::to_string(__LINE__);                                 \
            throw std::runtime_error(sqlite_what_);                                   \
        }                                                                             \
    } while(0)
|
||||
|
||||
// Variant of CHECK_SQLITE3 for stepping statements: stores the result of `expr`
// in the caller-supplied lvalue `rc` and accepts SQLITE_OK, SQLITE_ROW and
// SQLITE_DONE. Any other code throws std::runtime_error with the same message
// layout (result code, sqlite3_errmsg(db) text, __FILE__:__LINE__).
#define CHECK_SQLITE3_RC(expr, db, rc)                                                \
    do                                                                                \
    {                                                                                 \
        rc = (expr);                                                                  \
        const bool sqlite_accepted_ =                                                 \
            (rc == SQLITE_OK) || (rc == SQLITE_ROW) || (rc == SQLITE_DONE);           \
        if(!sqlite_accepted_)                                                         \
        {                                                                             \
            const char* sqlite_text_ = sqlite3_errmsg(db);                            \
            std::string sqlite_what_ = "SQLite error[";                               \
            sqlite_what_ += std::to_string(rc);                                       \
            sqlite_what_ += "]: ";                                                    \
            sqlite_what_ += (sqlite_text_ ? sqlite_text_ : "unknown error");          \
            sqlite_what_ += " at ";                                                   \
            sqlite_what_ += __FILE__;                                                 \
            sqlite_what_ += ":";                                                      \
            sqlite_what_ += std::to_string(__LINE__);                                 \
            throw std::runtime_error(sqlite_what_);                                   \
        }                                                                             \
    } while(0)
|
||||
|
||||
class StmtWrapper
|
||||
{
|
||||
public:
|
||||
explicit StmtWrapper(sqlite3* db, const char* sql)
|
||||
: stmt_(
|
||||
[db, sql] {
|
||||
sqlite3_stmt* stmt = nullptr;
|
||||
CHECK_SQLITE3(sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr), db);
|
||||
return stmt;
|
||||
}(),
|
||||
&sqlite3_finalize)
|
||||
{
|
||||
}
|
||||
|
||||
operator sqlite3_stmt*() const { return stmt_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<sqlite3_stmt, decltype(&sqlite3_finalize)> stmt_;
|
||||
};
|
||||
|
||||
class ProfileCacheDB
|
||||
{
|
||||
public:
|
||||
explicit ProfileCacheDB(const std::filesystem::path& path)
|
||||
: db_ptr_(
|
||||
[path] {
|
||||
sqlite3* raw_db_ptr = nullptr;
|
||||
CHECK_SQLITE3(sqlite3_open_v2(path.string().c_str(),
|
||||
&raw_db_ptr,
|
||||
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
|
||||
nullptr),
|
||||
raw_db_ptr);
|
||||
return raw_db_ptr;
|
||||
}(),
|
||||
&sqlite3_close)
|
||||
{
|
||||
|
||||
try
|
||||
{
|
||||
exec_direct("PRAGMA journal_mode = WAL");
|
||||
exec_direct("PRAGMA synchronous = NORMAL");
|
||||
exec_direct("PRAGMA foreign_keys = ON");
|
||||
|
||||
constexpr const char* schema = R"sql(
|
||||
CREATE TABLE IF NOT EXISTS gemm (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
rocm_version TEXT NOT NULL CHECK(length(rocm_version) > 0),
|
||||
device_name TEXT NOT NULL CHECK(length(device_name) > 0),
|
||||
problem TEXT NOT NULL CHECK(json_valid(problem)),
|
||||
instance_name TEXT NOT NULL CHECK(length(instance_name) > 0),
|
||||
latency REAL CHECK(latency > 0),
|
||||
tflops REAL CHECK(tflops > 0),
|
||||
bandwidth REAL CHECK(bandwidth > 0)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_latency ON gemm(latency);
|
||||
CREATE INDEX IF NOT EXISTS idx_tflops_desc ON gemm(tflops DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_bandwidth_desc ON gemm(bandwidth DESC);
|
||||
)sql";
|
||||
|
||||
exec_direct(schema);
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
bool check_if_record_problem(const std::string& rocm_version,
|
||||
const std::string& device_name,
|
||||
const GemmProblem& gemm_problem)
|
||||
{
|
||||
constexpr const char* sql = R"sql(
|
||||
SELECT 1 FROM gemm
|
||||
WHERE rocm_version=?
|
||||
AND device_name=?
|
||||
AND problem=?
|
||||
LIMIT 1
|
||||
)sql";
|
||||
|
||||
StmtWrapper stmt(db_ptr_.get(), sql);
|
||||
sqlite3_stmt* raw_stmt = stmt;
|
||||
int idx = 1;
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
std::ostringstream oss;
|
||||
oss << gemm_problem;
|
||||
auto problem_json = oss.str();
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(
|
||||
raw_stmt, idx++, problem_json.c_str(), problem_json.size(), SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
|
||||
int rc;
|
||||
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
||||
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
||||
return (rc == SQLITE_ROW);
|
||||
}
|
||||
|
||||
std::tuple<std::string, PerformanceResult> query_cache(const std::string& rocm_version,
|
||||
const std::string& device_name,
|
||||
const GemmProblem& gemm_problem)
|
||||
{
|
||||
constexpr const char* sql = R"sql(
|
||||
SELECT latency, tflops, bandwidth FROM gemm
|
||||
WHERE rocm_version=? AND device_name=?
|
||||
AND problem=?
|
||||
LIMIT 1
|
||||
)sql";
|
||||
|
||||
StmtWrapper stmt(db_ptr_.get(), sql);
|
||||
sqlite3_stmt* raw_stmt = stmt;
|
||||
|
||||
int idx = 1;
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
std::ostringstream oss;
|
||||
oss << gemm_problem;
|
||||
auto problem_json = oss.str();
|
||||
CHECK_SQLITE3(sqlite3_bind_text(
|
||||
stmt, idx++, problem_json.c_str(), problem_json.size(), SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
|
||||
int rc;
|
||||
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
||||
|
||||
if(rc == SQLITE_ROW)
|
||||
{
|
||||
return std::make_tuple(reinterpret_cast<const char*>(sqlite3_column_text(raw_stmt, 0)),
|
||||
PerformanceResult{sqlite3_column_double(raw_stmt, 1),
|
||||
sqlite3_column_double(raw_stmt, 2),
|
||||
sqlite3_column_double(raw_stmt, 3)});
|
||||
}
|
||||
else if(rc != SQLITE_DONE)
|
||||
{
|
||||
throw std::runtime_error(sqlite3_errmsg(db_ptr_.get()));
|
||||
}
|
||||
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
||||
|
||||
return std::make_tuple("", PerformanceResult{-1.0f, -1.0f, -1.0f});
|
||||
}
|
||||
|
||||
void insert_cache(const std::string& rocm_version,
|
||||
const std::string& device_name,
|
||||
const std::vector<KernelInstance>& kernen_instnaces)
|
||||
{
|
||||
exec_direct("BEGIN TRANSACTION");
|
||||
try
|
||||
{
|
||||
constexpr const char* sql = R"sql(
|
||||
INSERT INTO gemm
|
||||
(rocm_version, device_name,
|
||||
problem, instance_name,
|
||||
latency, tflops, bandwidth)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
||||
)sql";
|
||||
|
||||
StmtWrapper stmt(db_ptr_.get(), sql);
|
||||
sqlite3_stmt* raw_stmt = stmt;
|
||||
|
||||
for(const auto& item : kernen_instnaces)
|
||||
{
|
||||
int idx = 1;
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
std::ostringstream oss;
|
||||
oss << item.problem_;
|
||||
auto problem_json = oss.str();
|
||||
CHECK_SQLITE3(sqlite3_bind_text(raw_stmt,
|
||||
idx++,
|
||||
problem_json.c_str(),
|
||||
problem_json.size(),
|
||||
SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(
|
||||
sqlite3_bind_text(raw_stmt, idx++, item.name_.c_str(), -1, SQLITE_TRANSIENT),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.latency_),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.tflops_),
|
||||
db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.bandwidth_),
|
||||
db_ptr_.get());
|
||||
|
||||
int rc;
|
||||
CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc);
|
||||
CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get());
|
||||
CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get());
|
||||
}
|
||||
exec_direct("COMMIT");
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
exec_direct("ROLLBACK");
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void exec_direct(const char* sql)
|
||||
{
|
||||
CHECK_SQLITE3(sqlite3_exec(db_ptr_.get(), sql, nullptr, nullptr, nullptr), db_ptr_.get());
|
||||
}
|
||||
|
||||
std::unique_ptr<sqlite3, decltype(&sqlite3_close)> db_ptr_;
|
||||
};
|
||||
Reference in New Issue
Block a user