diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index 72bf1aa8a4..a75fff639d 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,9 +1,11 @@ +find_package(SQLite3 REQUIRED) + # generate a list of kernels, but not actually emit files at config stage execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${CMAKE_CURRENT_BINARY_DIR} - # --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json + --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json --list_blobs RESULT_VARIABLE ret ) @@ -31,7 +33,7 @@ add_custom_command( OUTPUT ${GEMM_CODEGEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${CMAKE_CURRENT_BINARY_DIR} - # --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json + --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json --gen_blobs ) @@ -51,8 +53,8 @@ target_link_libraries(gemm_host_api INTERFACE gemm_template_instances) add_executable(${BENCHMARK_GEMM_EXECUTABLE} EXCLUDE_FROM_ALL benchmark_gemm.cpp) target_include_directories(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp) -target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api) +target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp profile_cache_db.hpp) +target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api SQLite::SQLite3) set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS) diff --git a/tile_engine/ops/gemm/benchmark_gemm.cpp b/tile_engine/ops/gemm/benchmark_gemm.cpp index fb56e524d2..380568db1e 100644 --- a/tile_engine/ops/gemm/benchmark_gemm.cpp +++ b/tile_engine/ops/gemm/benchmark_gemm.cpp @@ -27,6 +27,8 @@ void benchmark_gemm(const ck_tile::ArgParser& arg_parser) arg_parser.get_bool("structured_sparsity")}; Setting setting{ + arg_parser.get_bool("enable_profile_cache"), + arg_parser.get_bool("flush_profile_cache"), arg_parser.get_int("warmup"), arg_parser.get_int("repeat"), arg_parser.get_bool("timer"), diff --git a/tile_engine/ops/gemm/benchmark_gemm.hpp b/tile_engine/ops/gemm/benchmark_gemm.hpp index 292d67dad6..30bd6cfe7d 100644 --- a/tile_engine/ops/gemm/benchmark_gemm.hpp +++ b/tile_engine/ops/gemm/benchmark_gemm.hpp @@ -118,6 +118,8 @@ struct KernelInstance struct Setting { + bool enable_profile_cache_; + bool flush_profile_cache_; int n_warmup_; int n_repeat_; bool is_gpu_timer_; diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp index 8cbc3f26f6..280e228190 100755 --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -73,6 +73,14 @@ inline auto create_args(int argc, char* argv[]) .insert("stride_b", "0", "The stride value for tensor B. Default is 0.") .insert("stride_c", "0", "The stride value for tensor C Default is 0.") .insert("split_k", "1", "The split value for k dimension. Default is 1.") + .insert("enable_profile_cache", + "true", + "whether use profile cache or not when benchmark kernel, Possible values are true " + "or false. Default is true") + .insert("flush_profile_cache", + "false", + "whether flush profile cache or not when benchmark kernel. Possible values are " + "true or false. Default is false") .insert("verify", "2", "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index 9170952aa8..d548616f13 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -6,10 +6,12 @@ #include #include #include +#include #include "ck_tile/host/device_prop.hpp" #include "ck_tile/ops/gemm.hpp" #include "benchmark_gemm.hpp" +#include "profile_cache_db.hpp" class GemmProfiler { @@ -20,10 +22,45 @@ class GemmProfiler return instance; } + bool is_problem_record_cache(const GemmProblem& gemm_problem) + { + if(setting_.enable_profile_cache_) + { + if(!cache_db_->check_if_record_problem( + get_rocm_version(), ck_tile::get_device_name(), gemm_problem)) + { + return false; + } + else + { + auto [name, perf_result] = cache_db_->query_cache( + get_rocm_version(), ck_tile::get_device_name(), gemm_problem); + KernelInstance kernel_instance; + kernel_instance.problem_ = gemm_problem; + kernel_instance.name_ = name; + kernel_instance.perf_result_.latency_ = perf_result.latency_; + kernel_instance.perf_result_.tflops_ = perf_result.tflops_; + kernel_instance.perf_result_.bandwidth_ = perf_result.bandwidth_; + std::cout << "Skip this problem for " << gemm_problem + << ", Because it has already been recorded in the cache database" + << std::endl; + kernel_instances_.emplace_back(kernel_instance); + return true; + } + } + else + { + return false; + } + } + void benchmark(GemmProblem& gemm_problem, std::vector( ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables) { + if(is_problem_record_cache(gemm_problem)) + return; + const ALayout layout_a = ALayout{}; const BLayout layout_b = BLayout{}; const CLayout layout_c = CLayout{}; @@ -135,6 +172,12 @@ class GemmProfiler c_m_n_dev_result, kernel_run_result); } + + if(setting_.enable_profile_cache_) + { + cache_db_->insert_cache( + get_rocm_version(), ck_tile::get_device_name(), kernel_instances_); + } } void process_result(const GemmProblem& gemm_problem, @@ -247,14 +290,92 @@ class GemmProfiler return kernel_instance; } - GemmProfiler(const GemmProfiler&) = delete; + GemmProfiler(const GemmProfiler&) = delete; GemmProfiler& operator=(const GemmProfiler&) = delete; private: ~GemmProfiler() { kernel_instances_.clear(); } - GemmProfiler(Setting setting) : setting_(setting) {} + GemmProfiler(Setting setting) : setting_(setting) { initialize_profile_cache(); } + + void initialize_profile_cache() + { + if(setting_.enable_profile_cache_) + { + std::filesystem::path cache_db_prefix_path = + std::filesystem::current_path() / ".tile_engine"; + if(!create_cache_directory(cache_db_prefix_path)) + { + std::cerr << "Error: Failed to create cache directory" << std::endl; + return; + } + std::filesystem::path cache_db_path = + cache_db_prefix_path / ("tile_engine_" + ck_tile::get_device_name() + ".db"); + + handle_flush_cache(cache_db_path); + + initialize_cache_db(cache_db_path); + } + else + { + std::cout << "Gemm profiler disable profile cache! " << std::endl; + } + } + + bool create_cache_directory(const std::filesystem::path& cache_db_prefix_path) + { + std::error_code ec; + bool created = std::filesystem::create_directories(cache_db_prefix_path, ec); + + if(ec) + { + std::cerr << "Error creating directory " << cache_db_prefix_path << ": " << ec.message() + << std::endl; + return false; + } + + if(created) + + { + std::cout << "Created cache directory: " << cache_db_prefix_path << std::endl; + } + else + { + std::cout << "Using existing cache directory: " << cache_db_prefix_path << std::endl; + } + return true; + } + + void handle_flush_cache(const std::filesystem::path& cache_db_path) const + { + if(setting_.flush_profile_cache_ && std::filesystem::exists(cache_db_path)) + { + std::error_code ec; + if(std::filesystem::remove(cache_db_path, ec)) + + { + std::cout << "Successfully flushed cache: " << cache_db_path << std::endl; + } + else + { + std::cerr << "Error flushing cache: " << ec.message() << std::endl; + } + } + } + + void initialize_cache_db(const std::filesystem::path& path) + { + try + { + cache_db_ = std::make_unique(path); + std::cout << "Loaded profile cache from " << path << std::endl; + } + catch(const std::exception& e) + { + std::cerr << "Failed to initialize profile cache: " << e.what() << std::endl; + } + } Setting setting_; - + std::unique_ptr cache_db_; std::vector kernel_instances_; }; diff --git a/tile_engine/ops/gemm/profile_cache_db.hpp b/tile_engine/ops/gemm/profile_cache_db.hpp new file mode 100644 index 0000000000..dbf98defa0 --- /dev/null +++ b/tile_engine/ops/gemm/profile_cache_db.hpp @@ -0,0 +1,252 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "benchmark_gemm.hpp" + +#define CHECK_SQLITE3(expr, db) \ + do \ + { \ + int result_code = (expr); \ + if(result_code != SQLITE_OK) \ + { \ + const char* err = sqlite3_errmsg(db); \ + throw std::runtime_error("SQLite error[" + std::to_string(result_code) + \ + "]: " + (err ? err : "unknown error") + " at " + \ + std::string(__FILE__) + ":" + std::to_string(__LINE__)); \ + } \ + } while(0) + +#define CHECK_SQLITE3_RC(expr, db, rc) \ + do \ + { \ + rc = (expr); \ + if(rc != SQLITE_OK && rc != SQLITE_ROW && rc != SQLITE_DONE) \ + { \ + const char* err = sqlite3_errmsg(db); \ + throw std::runtime_error("SQLite error[" + std::to_string(rc) + \ + "]: " + (err ? err : "unknown error") + " at " + \ + std::string(__FILE__) + ":" + std::to_string(__LINE__)); \ + } \ + } while(0) + +class StmtWrapper +{ + public: + explicit StmtWrapper(sqlite3* db, const char* sql) + : stmt_( + [db, sql] { + sqlite3_stmt* stmt = nullptr; + CHECK_SQLITE3(sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr), db); + return stmt; + }(), + &sqlite3_finalize) + { + } + + operator sqlite3_stmt*() const { return stmt_.get(); } + + private: + std::unique_ptr stmt_; +}; + +class ProfileCacheDB +{ + public: + explicit ProfileCacheDB(const std::filesystem::path& path) + : db_ptr_( + [path] { + sqlite3* raw_db_ptr = nullptr; + CHECK_SQLITE3(sqlite3_open_v2(path.string().c_str(), + &raw_db_ptr, + SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, + nullptr), + raw_db_ptr); + return raw_db_ptr; + }(), + &sqlite3_close) + { + + try + { + exec_direct("PRAGMA journal_mode = WAL"); + exec_direct("PRAGMA synchronous = NORMAL"); + exec_direct("PRAGMA foreign_keys = ON"); + + constexpr const char* schema = R"sql( + CREATE TABLE IF NOT EXISTS gemm ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + rocm_version TEXT NOT NULL CHECK(length(rocm_version) > 0), + device_name TEXT NOT NULL CHECK(length(device_name) > 0), + problem TEXT NOT NULL CHECK(json_valid(problem)), + instance_name TEXT NOT NULL CHECK(length(instance_name) > 0), + latency REAL CHECK(latency > 0), + tflops REAL CHECK(tflops > 0), + bandwidth REAL CHECK(bandwidth > 0) + ); + CREATE INDEX IF NOT EXISTS idx_latency ON gemm(latency); + CREATE INDEX IF NOT EXISTS idx_tflops_desc ON gemm(tflops DESC); + CREATE INDEX IF NOT EXISTS idx_bandwidth_desc ON gemm(bandwidth DESC); + )sql"; + + exec_direct(schema); + } + catch(...) + { + throw; + } + } + + bool check_if_record_problem(const std::string& rocm_version, + const std::string& device_name, + const GemmProblem& gemm_problem) + { + constexpr const char* sql = R"sql( + SELECT 1 FROM gemm + WHERE rocm_version=? + AND device_name=? + AND problem=? + LIMIT 1 + )sql"; + + StmtWrapper stmt(db_ptr_.get(), sql); + sqlite3_stmt* raw_stmt = stmt; + int idx = 1; + CHECK_SQLITE3( + sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT), + db_ptr_.get()); + CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT), + db_ptr_.get()); + std::ostringstream oss; + oss << gemm_problem; + auto problem_json = oss.str(); + CHECK_SQLITE3( + sqlite3_bind_text( + raw_stmt, idx++, problem_json.c_str(), problem_json.size(), SQLITE_TRANSIENT), + db_ptr_.get()); + + int rc; + CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc); + CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get()); + CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get()); + return (rc == SQLITE_ROW); + } + + std::tuple query_cache(const std::string& rocm_version, + const std::string& device_name, + const GemmProblem& gemm_problem) + { + constexpr const char* sql = R"sql( + SELECT latency, tflops, bandwidth FROM gemm + WHERE rocm_version=? AND device_name=? + AND problem=? + LIMIT 1 + )sql"; + + StmtWrapper stmt(db_ptr_.get(), sql); + sqlite3_stmt* raw_stmt = stmt; + + int idx = 1; + CHECK_SQLITE3( + sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT), + db_ptr_.get()); + CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT), + db_ptr_.get()); + std::ostringstream oss; + oss << gemm_problem; + auto problem_json = oss.str(); + CHECK_SQLITE3(sqlite3_bind_text( + stmt, idx++, problem_json.c_str(), problem_json.size(), SQLITE_TRANSIENT), + db_ptr_.get()); + + int rc; + CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc); + + if(rc == SQLITE_ROW) + { + return std::make_tuple(reinterpret_cast(sqlite3_column_text(raw_stmt, 0)), + PerformanceResult{sqlite3_column_double(raw_stmt, 1), + sqlite3_column_double(raw_stmt, 2), + sqlite3_column_double(raw_stmt, 3)}); + } + else if(rc != SQLITE_DONE) + { + throw std::runtime_error(sqlite3_errmsg(db_ptr_.get())); + } + CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get()); + CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get()); + + return std::make_tuple("", PerformanceResult{-1.0f, -1.0f, -1.0f}); + } + + void insert_cache(const std::string& rocm_version, + const std::string& device_name, + const std::vector& kernen_instnaces) + { + exec_direct("BEGIN TRANSACTION"); + try + { + constexpr const char* sql = R"sql( + INSERT INTO gemm + (rocm_version, device_name, + problem, instance_name, + latency, tflops, bandwidth) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) + )sql"; + + StmtWrapper stmt(db_ptr_.get(), sql); + sqlite3_stmt* raw_stmt = stmt; + + for(const auto& item : kernen_instnaces) + { + int idx = 1; + CHECK_SQLITE3( + sqlite3_bind_text(raw_stmt, idx++, rocm_version.c_str(), -1, SQLITE_TRANSIENT), + db_ptr_.get()); + CHECK_SQLITE3( + sqlite3_bind_text(raw_stmt, idx++, device_name.c_str(), -1, SQLITE_TRANSIENT), + db_ptr_.get()); + std::ostringstream oss; + oss << item.problem_; + auto problem_json = oss.str(); + CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, + idx++, + problem_json.c_str(), + problem_json.size(), + SQLITE_TRANSIENT), + db_ptr_.get()); + CHECK_SQLITE3( + sqlite3_bind_text(raw_stmt, idx++, item.name_.c_str(), -1, SQLITE_TRANSIENT), + db_ptr_.get()); + CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.latency_), + db_ptr_.get()); + CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.tflops_), + db_ptr_.get()); + CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result_.bandwidth_), + db_ptr_.get()); + + int rc; + CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc); + CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get()); + CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get()); + } + exec_direct("COMMIT"); + } + catch(...) + { + exec_direct("ROLLBACK"); + throw; + } + } + + private: + void exec_direct(const char* sql) + { + CHECK_SQLITE3(sqlite3_exec(db_ptr_.get(), sql, nullptr, nullptr, nullptr), db_ptr_.get()); + } + + std::unique_ptr db_ptr_; +};