From d3d32843b570d7322e781d21150bef2af0aedc75 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Thu, 1 May 2025 11:05:27 +0000 Subject: [PATCH] only support profile --- tile_engine/ops/gemm/CMakeLists.txt | 40 ++- tile_engine/ops/gemm/README.md | 3 +- tile_engine/ops/gemm/benchmark_gemm.hpp | 231 ++------------ tile_engine/ops/gemm/gemm_host_api.cpp | 18 +- tile_engine/ops/gemm/gemm_host_api.hpp | 10 +- tile_engine/ops/gemm/gemm_instance_builder.py | 62 ++-- tile_engine/ops/gemm/profile_cache.hpp | 294 ------------------ 7 files changed, 81 insertions(+), 577 deletions(-) delete mode 100644 tile_engine/ops/gemm/profile_cache.hpp diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index 16870437a8..44d66d17e0 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,36 +1,34 @@ # generate a list of kernels, but not actually emit files at config stage -# execute_process( -# COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py -# --working_path ${CMAKE_CURRENT_BINARY_DIR} -# --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json -# --list_blobs -# RESULT_VARIABLE ret -# ) +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + --list_blobs + RESULT_VARIABLE ret +) -# if(ret AND NOT ret EQUAL 0) -# message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}") -# endif() +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}") +endif() -# file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS) -# add_custom_command( -# OUTPUT ${GEMM_CODEGEN_BLOBS} -# COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py -# --working_path ${CMAKE_CURRENT_BINARY_DIR} -# --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json -# --gen_blobs -# DEPENDS ${GEMM_CODEGEN_BLOBS} -# ) +add_custom_command( + OUTPUT ${GEMM_CODEGEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + --gen_blobs + DEPENDS ${GEMM_CODEGEN_BLOBS} +) set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm") message("adding example ${EXECUTABLE_GEMM_INSTANCE}") # use build as include directory -find_package(SQLite3 REQUIRED) include_directories(${CMAKE_CURRENT_BINARY_DIR}) add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp) target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_link_libraries(${EXECUTABLE_GEMM_INSTANCE} SQLite::SQLite3) target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS}) set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS) \ No newline at end of file diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index 495232f19b..4addfa0d5b 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -28,10 +28,11 @@ make tile_engine_gemm -j -stride_c Tensor C stride (default:0) -split_k SplitK value (default:1) -v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2) + -metric The metric value of kernel performance - latency: 0, tflops: 1, bandwidth: 2 (default:0) -warmup Number of iterations before benchmark the kernel (default:50) -repeat Number of iterations to benchmark the kernel (default:100) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) - -init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0) + -init Value for initializing tensor - random: 0, linear: 1, constant(1): 2 (default:0) -pipeline possible values are: compv3, compv4, mem (default:compv3) -scheduler possible values are: intrawave, interwave (default:intrawave) -epilogue possible values are: cshuffle, default (default:cshuffle) diff --git a/tile_engine/ops/gemm/benchmark_gemm.hpp b/tile_engine/ops/gemm/benchmark_gemm.hpp index b6ab76b9e2..72fca244f6 100644 --- a/tile_engine/ops/gemm/benchmark_gemm.hpp +++ b/tile_engine/ops/gemm/benchmark_gemm.hpp @@ -8,17 +8,15 @@ #include #include +#include "ck/version.h" #include "ck/host_utility/device_prop.hpp" -#include "profile_cache.hpp" -class Executor +class Profiler { public: - ~Executor() { kernel_instances_.clear(); } - - static Executor& instance(bool enable_profile_cache = true, bool flush_profile_cache = false) + static Profiler& instance() { - static Executor instance{enable_profile_cache, flush_profile_cache}; + static Profiler instance; return instance; } @@ -61,108 +59,39 @@ class Executor KernelInstance kernel_instance{environment_, description, problem, {-1.0f, -1.0f, -1.0f}}; - auto launch_kernel = [&] { - float avg_time = Kernel::launch(args, s); - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + float avg_time = Kernel::launch(args, s); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - std::size_t flop = std::size_t(2) * args.M * args.N * args.K; - std::size_t num_byte = sizeof(ADataType) * args.M * args.K + - sizeof(BDataType) * args.N * args.K + - sizeof(CDataType) * args.M * args.N; - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_byte / 1.E6 / avg_time; + std::size_t flop = std::size_t(2) * args.M * args.N * args.K; + std::size_t num_byte = sizeof(ADataType) * args.M * args.K + + sizeof(BDataType) * args.N * args.K + + sizeof(CDataType) * args.M * args.N; + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_byte / 1.E6 / avg_time; - kernel_instance.perf_result.latency = avg_time; - kernel_instance.perf_result.tflops = tflops; - kernel_instance.perf_result.bandwidth = gb_per_sec; + kernel_instance.perf_result.latency = avg_time; + kernel_instance.perf_result.tflops = tflops; + kernel_instance.perf_result.bandwidth = gb_per_sec; - std::cout << kernel_instance << std::endl; + std::cout << kernel_instance << std::endl; - bool verified_correct = - !verify || compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result); + bool verified_correct = + !verify || compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result); - if(verified_correct) - { - kernel_instances_.emplace_back(kernel_instance); - } - else - { - std::cout << "Verification failed, skip kernel: " << description << std::endl; - } - - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); - }; - - if(enable_profile_cache_) + if(verified_correct) { - if(!cache_db_->check_if_record(kernel_instance)) - { - - launch_kernel(); - cache_db_->insert_batch({kernel_instance}); - } - else - { - auto perf_result = cache_db_->query_performance_result(kernel_instance); - kernel_instance.perf_result.latency = perf_result.latency; - kernel_instance.perf_result.tflops = perf_result.tflops; - kernel_instance.perf_result.bandwidth = perf_result.bandwidth; - std::cout << "Skip this kernel for " << description - << ", Because it has already been recorded in the cache database" - << std::endl; - kernel_instances_.emplace_back(kernel_instance); - } + kernel_instances_.emplace_back(kernel_instance); } else { - launch_kernel(); + std::cout << "Verification failed, skip kernel: " << description << std::endl; } + + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); } - void export_perf_to_csv(const std::vector& instances, - const std::string& filename) - { - - std::ostringstream buffer; - - buffer << "ROCmVersion,CommitID,DeviceName,KernelName,SplitK,M,N,K," - << "StrideA,StrideB,StrideC,ADataType,BDataType,AccDataType," - << "CDataType,ALayout,BLayout,CLayout,Latency(ms),TFLOPS,Bandwidth\n"; - - for(const auto& instance : instances) - { - const auto& env = instance.env; - const auto& p = instance.problem; - const auto& perf = instance.perf_result; - - std::string sanitized_name = instance.name; - std::replace(sanitized_name.begin(), sanitized_name.end(), '\"', '\''); - - buffer << env.rocm_version << "," << env.commit_id << "," << env.device_name << "," - << "\"" << sanitized_name << "\"," << p.split_k << "," << p.m << "," << p.n - << "," << p.k << "," << p.stride_a << "," << p.stride_b << "," << p.stride_c - << "," << p.dtype_a << "," << p.dtype_b << "," << p.dtype_acc << "," << p.dtype_c - << "," << p.layout_a << "," << p.layout_b << "," << p.layout_c << "," - << std::fixed << std::setprecision(6) << perf.latency << "," << std::scientific - << perf.tflops << "," << std::fixed << perf.bandwidth << "\n"; - } - - std::ofstream csv_file(filename, std::ios::trunc); - if(!csv_file) - { - throw std::runtime_error("Failed to open CSV file: " + filename); - } - csv_file << buffer.str(); - csv_file.close(); - - if(csv_file.fail()) - { - throw std::runtime_error("Incomplete write to CSV file: " + filename); - } - } - - KernelInstance select_best_instance(Metric metric, const std::string& csv_path = "") + KernelInstance select_best_instance(Metric metric) { if(kernel_instances_.empty()) throw std::runtime_error("Empty instances"); @@ -174,126 +103,28 @@ class Executor b.perf_result, a.perf_result, metric); }); + std::cout << "**********************************" << std::endl; std::cout << "According to given metrics: " << get_metric_name(metric) << "\n" << "The best kernel instance is: " << kernel_instance << std::endl; - - if(!csv_path.empty()) - { - try - { - export_perf_to_csv(kernel_instances_, csv_path); - } - catch(const std::exception& e) - { - std::cerr << "CSV export failed: " << e.what() << std::endl; - } - } + std::cout << "**********************************" << std::endl; return kernel_instance; } private: - Executor(bool enable_profile_cache = true, bool flush_profile_cache = false) - : enable_profile_cache_(enable_profile_cache), flush_profile_cache_(flush_profile_cache) + Profiler() { environment_ = Environment{ get_rocm_version(), - "89f", ck::get_device_name(), }; - std::cout << "Init gemm bechmark on device: " << environment_.device_name << std::endl; - - initialize_profile_cache(); } + ~Profiler() { kernel_instances_.clear(); } - void initialize_profile_cache() - { - // Init cache if enable profile cache - if(enable_profile_cache_) - { - // get profile cache path - std::filesystem::path cache_db_prefix_path = - std::filesystem::current_path() / ".tile_engine"; - if(!create_cache_directory(cache_db_prefix_path)) - { - std::cerr << "Error: Failed to create cache directory" << std::endl; - return; - } - std::filesystem::path cache_db_path = - cache_db_prefix_path / ("tile_engine_" + environment_.device_name + ".db"); - - // remove cache if flush_profile_cache - handle_cache_flush(cache_db_path); - - // load profile cache - initialize_cache_db(cache_db_path); - } - else - { - std::cout << "Executor disable profile cache! " << std::endl; - } - } - - bool create_cache_directory(const std::filesystem::path& cache_db_prefix_path) - { - std::error_code ec; - bool created = std::filesystem::create_directories(cache_db_prefix_path, ec); - - if(ec) - { - std::cerr << "Error creating directory " << cache_db_prefix_path << ": " << ec.message() - << std::endl; - return false; - } - - if(created) - { - std::cout << "Created cache directory: " << cache_db_prefix_path << std::endl; - } - else - { - std::cout << "Using existing cache directory: " << cache_db_prefix_path << std::endl; - } - return true; - } - - void handle_cache_flush(const std::filesystem::path& cache_db_path) const - { - if(flush_profile_cache_ && std::filesystem::exists(cache_db_path)) - { - std::error_code ec; - if(std::filesystem::remove(cache_db_path, ec)) - { - std::cout << "Successfully flushed cache: " << cache_db_path << std::endl; - } - else - { - std::cerr << "Error flushing cache: " << ec.message() << std::endl; - } - } - } - - void initialize_cache_db(const std::filesystem::path& path) - { - try - { - cache_db_ = std::make_unique(path); - std::cout << "Loaded profile cache from " << path << std::endl; - } - catch(const std::exception& e) - { - std::cerr << "Failed to initialize profile cache: " << e.what() << std::endl; - } - } - - Executor(const Executor&) = delete; - Executor& operator=(const Executor&) = delete; + Profiler(const Profiler&) = delete; + Profiler& operator=(const Profiler&) = delete; Environment environment_; - bool enable_profile_cache_; - bool flush_profile_cache_; - - std::unique_ptr cache_db_; std::vector kernel_instances_; }; diff --git a/tile_engine/ops/gemm/gemm_host_api.cpp b/tile_engine/ops/gemm/gemm_host_api.cpp index cae03c303e..969a0ab86f 100644 --- a/tile_engine/ops/gemm/gemm_host_api.cpp +++ b/tile_engine/ops/gemm/gemm_host_api.cpp @@ -11,22 +11,12 @@ void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_dev_result, int verify, int metric, - bool enable_profile_cache, - bool flush_profile_cache, KernelTraits& trait, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { - return GemmDispatcher::dispatch(c_m_n_dev_buf, - c_m_n_host_result, - c_m_n_dev_result, - verify, - metric, - enable_profile_cache, - flush_profile_cache, - trait, - args, - s); + return GemmDispatcher::dispatch( + c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, metric, trait, args, s); } template #include #include +#include "gemm_common.hpp" +#include "gemm_instances.hpp" +#include "gemm_host_api.hpp" +#include "benchmark_gemm.hpp" + struct GemmDispatcher { static auto& get_kernel_map() { // Use a static local variable - static std::unordered_map& c_m_n_host_result, - ck_tile::HostTensor& c_m_n_dev_result, - int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map; + static std::unordered_map&, + ck_tile::HostTensor&, + int, + ck_tile::GemmHostArgs&, + const ck_tile::stream_config&)>> + kernel_map; return kernel_map; } @@ -513,7 +519,8 @@ struct GemmDispatcher { for group in self.all_kernels: - content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf, + content += f""" kernel_map["{group}"] = [](Profiler& profiler, + ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, int verify, ck_tile::GemmHostArgs& args, @@ -526,46 +533,29 @@ struct GemmDispatcher { ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): continue content += f""" - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);""" + profiler.benchmark_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);""" content += f""" }};\n""" content += """ } - template - static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::HostTensor& c_m_n_host_result, - ck_tile::HostTensor& c_m_n_dev_result, - int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) - { - float avg_time = Kernel::launch(args, s); - std::string description = Kernel::get_name(); - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - - std::size_t flop = std::size_t(2) * args.M * args.N * args.K; - std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N; - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_byte / 1.E6 / avg_time; - - std::cout << "Performance for " << description << " : " << avg_time << " ms, " - << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; - - if(verify) - compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result); - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); - } - static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, - int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, + int verify, + int metric, + const KernelTraits& trait, + ck_tile::GemmHostArgs& gemm_args, const ck_tile::stream_config& s) { init(); const std::string key = assemble_key(trait); auto& kernel_map = get_kernel_map(); + auto& profiler = Profiler::instance(); if(auto it = kernel_map.find(key); it != kernel_map.end()) { - return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s); + it->second( + profiler, c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, s); + profiler.select_best_instance(static_cast(metric)); + return; } throw std::runtime_error("No suitable kernel found: " + key); } diff --git a/tile_engine/ops/gemm/profile_cache.hpp b/tile_engine/ops/gemm/profile_cache.hpp deleted file mode 100644 index ddb2d899e7..0000000000 --- a/tile_engine/ops/gemm/profile_cache.hpp +++ /dev/null @@ -1,294 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "gemm_host_api.hpp" - -#define CHECK_SQLITE3(expr, db) \ - do \ - { \ - int result_code = (expr); \ - if(result_code != SQLITE_OK) \ - { \ - const char* err = sqlite3_errmsg(db); \ - throw std::runtime_error("SQLite error[" + std::to_string(result_code) + \ - "]: " + (err ? err : "unknown error") + " at " + \ - std::string(__FILE__) + ":" + std::to_string(__LINE__)); \ - } \ - } while(0) - -#define CHECK_SQLITE3_RC(expr, db, rc) \ - do \ - { \ - rc = (expr); \ - if(rc != SQLITE_OK && rc != SQLITE_ROW && rc != SQLITE_DONE) \ - { \ - const char* err = sqlite3_errmsg(db); \ - throw std::runtime_error("SQLite error[" + std::to_string(rc) + \ - "]: " + (err ? err : "unknown error") + " at " + \ - std::string(__FILE__) + ":" + std::to_string(__LINE__)); \ - } \ - } while(0) - -class StmtWrapper -{ - public: - explicit StmtWrapper(sqlite3* db, const char* sql) - : stmt_( - [db, sql] { - sqlite3_stmt* stmt = nullptr; - CHECK_SQLITE3(sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr), db); - return stmt; - }(), - &sqlite3_finalize) - { - } - - operator sqlite3_stmt*() const { return stmt_.get(); } - - private: - std::unique_ptr stmt_; -}; - -class ProfileCacheDB -{ - public: - explicit ProfileCacheDB(const std::filesystem::path& path) - : db_ptr_( - [path] { - sqlite3* raw_db_ptr = nullptr; - CHECK_SQLITE3(sqlite3_open_v2(path.string().c_str(), - &raw_db_ptr, - SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, - nullptr), - raw_db_ptr); - return raw_db_ptr; - }(), - &sqlite3_close) - { - - try - { - exec_direct("PRAGMA journal_mode = WAL"); - exec_direct("PRAGMA synchronous = NORMAL"); - exec_direct("PRAGMA foreign_keys = ON"); - - constexpr const char* schema = R"sql( - CREATE TABLE IF NOT EXISTS gemm ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - rocm_version TEXT NOT NULL CHECK(length(rocm_version) > 0), - commit_id TEXT NOT NULL CHECK(length(commit_id) > 0), - device_name TEXT NOT NULL CHECK(length(device_name) > 0), - instance_name TEXT NOT NULL CHECK(length(instance_name) > 0), - problem TEXT NOT NULL CHECK(json_valid(problem)), - latency REAL CHECK(latency > 0), - tflops REAL CHECK(tflops > 0), - bandwidth REAL CHECK(bandwidth > 0) - ); - CREATE INDEX IF NOT EXISTS idx_latency ON gemm(latency); - CREATE INDEX IF NOT EXISTS idx_tflops_desc ON gemm(tflops DESC); - CREATE INDEX IF NOT EXISTS idx_bandwidth_desc ON gemm(bandwidth DESC); - )sql"; - - exec_direct(schema); - } - catch(...) - { - throw; - } - } - - bool check_if_record(const KernelInstance& kernel_instance) - { - constexpr const char* sql = R"sql( - SELECT 1 FROM gemm - WHERE rocm_version=? AND commit_id=? AND device_name=? - AND instance_name=? AND problem=? - LIMIT 1)sql"; - - StmtWrapper stmt(db_ptr_.get(), sql); - sqlite3_stmt* raw_stmt = stmt; - int idx = 1; - CHECK_SQLITE3( - sqlite3_bind_text( - raw_stmt, idx++, kernel_instance.env.rocm_version.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3( - sqlite3_bind_text( - raw_stmt, idx++, kernel_instance.env.commit_id.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3( - sqlite3_bind_text( - raw_stmt, idx++, kernel_instance.env.device_name.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3( - sqlite3_bind_text(raw_stmt, idx++, kernel_instance.name.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, - idx++, - kernel_instance.problem.serialize().c_str(), - kernel_instance.problem.serialize().size(), - SQLITE_TRANSIENT), - db_ptr_.get()); - - int rc; - CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc); - CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get()); - CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get()); - return (rc == SQLITE_ROW); - } - - PerformanceResult query_performance_result(const KernelInstance& kernel_instance) - { - constexpr const char* sql = R"sql( - SELECT latency, tflops, bandwidth FROM gemm - WHERE rocm_version=? AND commit_id=? AND device_name=? - AND instance_name=? AND problem=? - LIMIT 1)sql"; - - StmtWrapper stmt(db_ptr_.get(), sql); - sqlite3_stmt* raw_stmt = stmt; - - int idx = 1; - CHECK_SQLITE3( - sqlite3_bind_text( - raw_stmt, idx++, kernel_instance.env.rocm_version.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3( - sqlite3_bind_text( - raw_stmt, idx++, kernel_instance.env.commit_id.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3( - sqlite3_bind_text( - raw_stmt, idx++, kernel_instance.env.device_name.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3( - sqlite3_bind_text(raw_stmt, idx++, kernel_instance.name.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3(sqlite3_bind_text(stmt, - idx++, - kernel_instance.problem.serialize().c_str(), - kernel_instance.problem.serialize().size(), - SQLITE_TRANSIENT), - db_ptr_.get()); - - int rc; - CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc); - - if(rc == SQLITE_ROW) - { - return {sqlite3_column_double(raw_stmt, 0), - sqlite3_column_double(raw_stmt, 1), - sqlite3_column_double(raw_stmt, 2)}; - } - else if(rc != SQLITE_DONE) - { - throw std::runtime_error(sqlite3_errmsg(db_ptr_.get())); - } - CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get()); - CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get()); - - return {-1.0f, -1.0f, -1.0f}; - } - - void insert_batch(const std::vector& data) - { - exec_direct("BEGIN TRANSACTION"); - try - { - constexpr const char* sql = R"sql( - INSERT INTO gemm - (rocm_version, commit_id, device_name, - instance_name, problem, - latency, tflops, bandwidth) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8))sql"; - - StmtWrapper stmt(db_ptr_.get(), sql); - sqlite3_stmt* raw_stmt = stmt; - - for(const auto& item : data) - { - int idx = 1; - CHECK_SQLITE3( - sqlite3_bind_text( - raw_stmt, idx++, item.env.rocm_version.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3( - sqlite3_bind_text( - raw_stmt, idx++, item.env.commit_id.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3( - sqlite3_bind_text( - raw_stmt, idx++, item.env.device_name.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3( - sqlite3_bind_text(raw_stmt, idx++, item.name.c_str(), -1, SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3(sqlite3_bind_text(raw_stmt, - idx++, - item.problem.serialize().c_str(), - item.problem.serialize().size(), - SQLITE_TRANSIENT), - db_ptr_.get()); - CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result.latency), - db_ptr_.get()); - CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result.tflops), - db_ptr_.get()); - CHECK_SQLITE3(sqlite3_bind_double(raw_stmt, idx++, item.perf_result.bandwidth), - db_ptr_.get()); - - int rc; - CHECK_SQLITE3_RC(sqlite3_step(raw_stmt), db_ptr_.get(), rc); - CHECK_SQLITE3(sqlite3_reset(raw_stmt), db_ptr_.get()); - CHECK_SQLITE3(sqlite3_clear_bindings(raw_stmt), db_ptr_.get()); - } - exec_direct("COMMIT"); - } - catch(...) - { - exec_direct("ROLLBACK"); - throw; - } - } - - // std::vector query_top(Metric metric, int limit = 5) - // { - // if(limit <= 0) - // throw invalid_argument("Limit must be positive"); - - // const char* order = nullptr; - // switch(metric) - // { - // case LATENCY: order = "latency ASC"; break; - // case TFLOPS: order = "tflops DESC"; break; - // case BANDWIDTH: order = "bandwidth DESC"; break; - // default: throw invalid_argument("Invalid metric"); - // } - - // string sql = "SELECT name, latency, tflops, bandwidth FROM kernels " - // "ORDER BY " + - // string(order) + " LIMIT ?"; - - // StmtWrapper stmt(db, sql.c_str()); - // sqlite3_bind_int(stmt, 1, limit); - - // std::vector results; - // while(sqlite3_step(stmt) == SQLITE_ROW) - // { - // results.emplace_back(reinterpret_cast(sqlite3_column_text(stmt, 0)), - // PerformanceResult{sqlite3_column_double(stmt, 1), - // sqlite3_column_double(stmt, 2), - // sqlite3_column_double(stmt, 3)}); - // } - // return results; - // } - - private: - void exec_direct(const char* sql) - { - CHECK_SQLITE3(sqlite3_exec(db_ptr_.get(), sql, nullptr, nullptr, nullptr), db_ptr_.get()); - } - - std::unique_ptr db_ptr_; -};