From 8fd4774f942d4e7ac8427bf29d593c76cf591b36 Mon Sep 17 00:00:00 2001 From: Saood Karim Date: Fri, 27 Jun 2025 18:36:10 -0500 Subject: [PATCH] Remove hardcoded extension and add error handling to extension loading --- common/common.cpp | 5 +++++ common/common.h | 1 + examples/server/server.cpp | 22 +++++++++++++++++----- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index f8d644bd..b5496e27 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1410,6 +1410,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.sql_save_file = argv[i]; return true; } + if (arg == "--sqlite-zstd-ext-file") { + CHECK_ARG + params.sqlite_zstd_ext_file = argv[i]; + return true; + } if (arg == "--chat-template") { CHECK_ARG if (!llama_chat_verify_template(argv[i])) { diff --git a/common/common.h b/common/common.h index 7c4b787d..4a465250 100644 --- a/common/common.h +++ b/common/common.h @@ -243,6 +243,7 @@ struct gpt_params { std::string slot_save_path; std::string sql_save_file; + std::string sqlite_zstd_ext_file; float slot_prompt_similarity = 0.5f; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index caf9ced1..140b367e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -37,10 +37,6 @@ struct DatabaseHandle { sqlite::database db; DatabaseHandle(const std::string& path) : db(path) { - - sqlite3_enable_load_extension(db.connection().get(), 1); - db << "SELECT load_extension('/home/saood06/mikupadStuff/libsqlite_zstd.so')"; - db << "CREATE TABLE IF NOT EXISTS sessions (key TEXT PRIMARY KEY, data TEXT)"; db << "CREATE TABLE IF NOT EXISTS templates (key TEXT PRIMARY KEY, data TEXT)"; db << "CREATE TABLE IF NOT EXISTS names (key TEXT PRIMARY KEY, data TEXT)"; @@ -2937,7 +2933,23 @@ int main(int argc, char ** argv) { ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; auto db_handle = std::make_shared(params.sql_save_file); - + if (!params.sqlite_zstd_ext_file.empty()) { + auto* conn = db_handle->db.connection().get(); + sqlite3_enable_load_extension(conn, 1); + char* errmsg = nullptr; + const int rc = sqlite3_load_extension( + conn, + params.sqlite_zstd_ext_file.c_str(), + nullptr, + &errmsg + ); + if(rc != SQLITE_OK) { + const std::string err = errmsg ? errmsg : "Unknown extension error"; + sqlite3_free(errmsg); + LOG_WARNING("Failed to load extension", {{"err", err}}); + } + sqlite3_enable_load_extension(conn, 0); + } // load the model if (!ctx_server.load_model(params)) { state.store(SERVER_STATE_ERROR);