diff --git a/common/common.cpp b/common/common.cpp index 2bb83b3e..c1e94323 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1478,6 +1478,16 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } + if (arg == "--sql-save-file") { + CHECK_ARG + params.sql_save_file = argv[i]; + return true; + } + if (arg == "--sqlite-zstd-ext-file") { + CHECK_ARG + params.sqlite_zstd_ext_file = argv[i]; + return true; + } if (arg == "--chat-template") { CHECK_ARG if (!llama_chat_verify_template(nullptr, argv[i], false)) { diff --git a/common/common.h b/common/common.h index 6859c16a..1bf0f235 100644 --- a/common/common.h +++ b/common/common.h @@ -250,6 +250,8 @@ struct gpt_params { bool log_json = false; std::string slot_save_path; + std::string sql_save_file; + std::string sqlite_zstd_ext_file; float slot_prompt_similarity = 0.5f; diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index 20ddc5c5..9bc8017c 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -1,6 +1,7 @@ set(TARGET llama-server) option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON) option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF) +option(LLAMA_SERVER_SQLITE3 "Build SQlite3 support for the server" OFF) include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) @@ -44,7 +45,7 @@ if (MSVC) ) endif() # target_link_libraries(${TARGET} PRIVATE "/STACK:104857600") -target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) +target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) if (LLAMA_SERVER_SSL) @@ -53,6 +54,13 @@ if (LLAMA_SERVER_SSL) target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT) endif() +if (LLAMA_SERVER_SQLITE3) + find_package(SQLite3 REQUIRED) + target_link_libraries(${TARGET} PRIVATE SQLite::SQLite3) + target_include_directories(${TARGET} PUBLIC ./sqlite_modern_cpp/hdr) + target_compile_definitions(${TARGET} PRIVATE SQLITE3_MODERN_CPP_SUPPORT) +endif() + if (WIN32) TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32) endif() diff --git a/examples/server/public_mikupad/index.html b/examples/server/public_mikupad/index.html new file mode 100644 index 00000000..0a0e92f1 --- /dev/null +++ b/examples/server/public_mikupad/index.html @@ -0,0 +1,7951 @@ + + + +mikupad + + + + + diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 68edf16e..4f9a86cc 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -38,6 +38,19 @@ #include #include #include +#ifdef SQLITE3_MODERN_CPP_SUPPORT +#include + +struct DatabaseHandle { + sqlite::database db; + + DatabaseHandle(const std::string& path) : db(path) { + db << "CREATE TABLE IF NOT EXISTS sessions (key TEXT PRIMARY KEY, data TEXT)"; + db << "CREATE TABLE IF NOT EXISTS templates (key TEXT PRIMARY KEY, data TEXT)"; + db << "CREATE TABLE IF NOT EXISTS names (key TEXT PRIMARY KEY, data TEXT)"; + } +}; +#endif using json = nlohmann::ordered_json; @@ -3441,7 +3454,32 @@ int main(int argc, char ** argv) { // Necessary similarity of prompt for slot selection ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - +#ifdef SQLITE3_MODERN_CPP_SUPPORT + auto db_handle = std::make_shared(params.sql_save_file); + bool sqlite_extension_loaded = false; + if (!params.sqlite_zstd_ext_file.empty()) { + auto* conn = db_handle->db.connection().get(); + sqlite3_enable_load_extension(conn, 1); + char* errmsg = nullptr; + const int rc = sqlite3_load_extension( + conn, + params.sqlite_zstd_ext_file.c_str(), + nullptr, + &errmsg + ); + if(rc != SQLITE_OK) { + const std::string err = errmsg ? errmsg : "Unknown extension error"; + sqlite3_free(errmsg); + LOG_WARNING("Failed to load extension", {{"err", err}}); + } + else { + sqlite_extension_loaded = true; + } + sqlite3_enable_load_extension(conn, 0); + } +#else + auto db_handle = false; +#endif // load the model if (!ctx_server.load_model(params)) { state.store(SERVER_STATE_ERROR); @@ -3828,6 +3866,7 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params.n_parallel }, { "chat_template", ctx_server.chat_templates.template_default->source() }, + { "n_ctx", ctx_server.n_ctx } }; if (ctx_server.params.use_jinja && ctx_server.chat_templates.template_tool_use) { data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); @@ -4249,9 +4288,28 @@ int main(int argc, char ** argv) { std::vector tokens(n_token_count); file.read(reinterpret_cast(tokens.data()), tokens.size() * sizeof(llama_token)); + //C++17 is not modern enough to have a nice and portable way to get the mtime of a file + //so the following seems to be needed + auto ftime = fs::last_write_time(entry.path()); + auto system_time = std::chrono::time_point_cast( + ftime - fs::file_time_type::clock::now() + std::chrono::system_clock::now() + ); + std::time_t c_time = std::chrono::system_clock::to_time_t(system_time); + std::tm tm_struct; + #if defined(_WIN32) + localtime_s(&tm_struct, &c_time); + #else + localtime_r(&c_time, &tm_struct); + #endif + std::ostringstream oss; + oss << std::put_time(&tm_struct, "%Y-%m-%d %H:%M:%S"); + auto str_time = oss.str(); + + response.push_back({ {"filename", entry.path().filename().string()}, {"filesize", entry.file_size()}, + {"mtime", str_time}, {"token_count", n_token_count}, {"prompt", tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend())} }); @@ -4263,12 +4321,292 @@ int main(int argc, char ** argv) { res.set_content(response.dump(), "application/json; charset=utf-8"); }; + const auto list_slot_prompts = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + json response = json::array(); + for (server_slot & slot : ctx_server.slots) { + response.push_back({ + {"slot_id", slot.id}, + {"token_count", slot.cache_tokens.size()}, + {"prompt", tokens_to_str(ctx_server.ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cend())} + }); + } + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + + + const auto delete_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + json response; + namespace fs = std::filesystem; + + try { + const json body = json::parse(req.body); + const std::string filename_str = body.at("filename"); + + // prevent directory traversal attacks + if (filename_str.find("..") != std::string::npos || filename_str.find('/') != std::string::npos || filename_str.find('\\') != std::string::npos) { + res.status = 400; + response = {{"error", "Invalid filename format."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + const fs::path file_to_delete = fs::path(params.slot_save_path) / fs::path(filename_str); + + if (!fs::exists(file_to_delete) || !fs::is_regular_file(file_to_delete)) { + res.status = 404; + response = {{"error", "File not found."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + if (fs::remove(file_to_delete)) { + response = { + {"status", "deleted"}, + {"filename", filename_str} + }; + } else { + res.status = 500; + response = {{"error", "Failed to delete the file."}}; + } + } catch (const json::parse_error& e) { + res.status = 400; + response = {{"error", "Invalid JSON request body."}}; + } catch (const json::out_of_range& e) { + res.status = 400; + response = {{"error", "Missing 'filename' key in request body."}}; + } catch (const std::exception& e) { + res.status = 500; + response = {{"error", e.what()}}; + } + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + + const auto rename_saved_prompt = [&ctx_server, ¶ms](const httplib::Request& req, httplib::Response& res)-> void { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + json response; + namespace fs = std::filesystem; + + try { + const json body = json::parse(req.body); + const std::string old_filename_str = body.at("old_filename"); + const std::string new_filename_str = body.at("new_filename"); + + if (old_filename_str.find("..") != std::string::npos || old_filename_str.find_first_of("/\\") != std::string::npos || + new_filename_str.find("..") != std::string::npos || new_filename_str.find_first_of("/\\") != std::string::npos) { + res.status = 400; + response = {{"error", "Invalid filename format."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + const fs::path old_path = fs::path(params.slot_save_path) / old_filename_str; + const fs::path new_path = fs::path(params.slot_save_path) / new_filename_str; + + if (!fs::exists(old_path) || !fs::is_regular_file(old_path)) { + res.status = 404; + response = {{"error", "Source file not found."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + if (fs::exists(new_path)) { + res.status = 409; + response = {{"error", "Destination filename already exists."}}; + res.set_content(response.dump(), "application/json; charset=utf-8"); + return; + } + + std::error_code ec; + fs::rename(old_path, new_path, ec); + + if (ec) { + res.status = 500; + response = {{"error", "Failed to rename file: " + ec.message()}}; + } else { + response = { + {"status", "renamed"}, + {"old_filename", old_filename_str}, + {"new_filename", new_filename_str} + }; + } + + } catch (const json::parse_error& e) { + res.status = 400; + response = {{"error", "Invalid JSON request body."}}; + } catch (const json::out_of_range& e) { + res.status = 400; + response = {{"error", "Missing 'old_filename' or 'new_filename' in request body."}}; + } catch (const std::exception& e) { + res.status = 500; + response = {{"error", e.what()}}; + } + + res.set_content(response.dump(), "application/json; charset=utf-8"); + }; + auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(content), len, mime_type); return false; }; }; +#ifdef SQLITE3_MODERN_CPP_SUPPORT + const auto handle_version = [¶ms, sqlite_extension_loaded](const httplib::Request&, httplib::Response& res) { + res.set_content( + json{{"version", 4}, + {"features", {{"sql", !params.sql_save_file.empty()}, {"zstd_compression", sqlite_extension_loaded}}}}.dump(), + "application/json" + ); + }; +#else + const auto handle_version = [](const httplib::Request&, httplib::Response& res)-> void { + res.set_content( + json{{"version", 4}, + {"features", {{"sql", false}, {"zstd_compression", false}}}}.dump(), + "application/json" + ); + }; +#endif + +#ifdef SQLITE3_MODERN_CPP_SUPPORT + auto db_handler = [db_handle](auto func) { + return [func, db_handle](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", "*"); + try { + const json body = !req.body.empty() ? json::parse(req.body) : json::object(); + func(*db_handle, body, req, res); + } catch(const std::exception& e) { + res.status = 500; + res.set_content( + json{{"ok", false}, {"message", e.what()}}.dump(), + "application/json" + ); + } + }; + }; +#else + auto db_handler = [db_handle](auto func) { + return [func, db_handle](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", "*"); + res.status = 500; + res.set_content( + json{{"ok", false}, {"message", "Sqlite3 support was not enabled. Recompile with '-DLLAMA_SERVER_SQLITE3=ON'"}}.dump(), + "application/json" + ); + }; + }; +#endif + + const auto normalize_store_name = [](const std::string& storeName) { + if(storeName.empty()) return std::string("sessions"); + + std::string normalized; + normalized.reserve(storeName.size()); + + for(char c : storeName) { + if(std::isalpha(static_cast(c))) { + normalized.push_back(std::tolower(static_cast(c))); + } + } + + return normalized.empty() ? "sessions" : normalized; + }; + + const auto get_key_string = [](const json& j) { + return j.is_string() ? j.get() : j.dump(); + }; + + + const auto handle_load = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { + std::string data; + const std::string store = normalize_store_name(body["storeName"]); + db.db << "SELECT data FROM " + store + " WHERE key = ?" << get_key_string(body["key"]) >> data; + if(data.empty()) { + res.status = 404; + res.set_content(json{{"ok", false}, {"message", "Key not found"}}.dump(), "application/json"); + } else { + json response{{"ok", true}}; + response["result"] = (store == "names") ? json(data) : json::parse(data); + res.set_content(response.dump(), "application/json"); + } + }); + + const auto handle_save = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { + const std::string store = normalize_store_name(body["storeName"]); + const std::string data = (store == "names") ? body["data"].get() : body["data"].dump(); + db.db << "INSERT OR REPLACE INTO " + store + " (key, data) VALUES (?, ?)" << get_key_string(body["key"]) << data; + res.set_content(json{{"ok", true}, {"result", "Data saved successfully"}}.dump(), "application/json"); + }); + + const auto handle_rename = db_handler([get_key_string](auto& db, const json& body, auto&, auto& res) { + db.db << "UPDATE names SET data = ? WHERE key = ?" + << body["newName"].get() + << get_key_string(body["key"]); + res.set_content(json{{"ok", true}, {"result", "Session renamed successfully"}}.dump(), "application/json"); + }); + + const auto handle_all = db_handler([normalize_store_name](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "SELECT key, data FROM " + normalize_store_name(body["storeName"]) >> + [&](const std::string& key, const std::string& data) { + result[key] = json::parse(data); + }; + res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json"); + }); + + const auto handle_sessions = db_handler([](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "SELECT key, data FROM names" >> [&](const std::string& key, const std::string& data) { + result[key] = data; + }; + res.set_content(json{{"ok", true}, {"result", result}}.dump(), "application/json"); + }); + + const auto handle_delete = db_handler([normalize_store_name, get_key_string](auto& db, const json& body, auto&, auto& res) { + db.db << "DELETE FROM " + normalize_store_name(body["storeName"]) + " WHERE key = ?" + << get_key_string(body["key"]); + res.set_content(json{{"ok", true}, {"result", "Session deleted successfully"}}.dump(), "application/json"); + }); + + const auto handle_vacuum = db_handler([](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "VACUUM"; + res.set_content(json{"ok", true}.dump(), "application/json"); + }); + + const auto handle_zstd_get_configs = db_handler([](auto& db, const json& body, auto&, auto& res) { + json result = json::object(); + db.db << "SELECT id, config FROM _zstd_configs" >> [&](const std::string id, const std::string& config) { + result[id] = config; + }; + res.set_content(json{{"ok", true}, {"configs", result}}.dump(), "application/json"); + }); + + const auto handle_zstd_maintenance = db_handler([](auto& db, const json& body, auto&, auto& res) { + std::string data; + if (body["duration"].is_null()) { + db.db << "select zstd_incremental_maintenance(?, ?)" << nullptr << body["db_load"].get() >> data; + } + else { + db.db << "select zstd_incremental_maintenance(?, ?)" << body["duration"].get() << body["db_load"].get() >> data; + } + json response{{"ok", true}}; + response["result"] = json::parse(data); + res.set_content(response.dump(), "application/json"); + }); + + const auto handle_zstd_enable = db_handler([](auto& db, const json& body, auto&, auto& res) { + db.db << "select zstd_enable_transparent('{\"table\": \"" + body["table"].get() + "\",\"column\": \"" + body["column"].get() + "\", \"compression_level\": " + std::to_string(body["compression_level"].get()) + ", \"dict_chooser\": \"''a''\", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get()) + "}')"; + res.set_content(json{"ok", true}.dump(), "application/json"); + }); + + const auto handle_zstd_config_update = db_handler([](auto& db, const json& body, auto&, auto& res) { + std::string patch_json = "{\"compression_level\": " + std::to_string(body["compression_level"].get()) + ", \"train_dict_samples_ratio\": " + std::to_string(body["train_dict_samples_ratio"].get()) + "}"; + db.db << "update _zstd_configs set config = json_patch(config, '" + patch_json + "')"; + res.set_content(json{{"ok", true}}.dump(), "application/json"); + }); // // Router @@ -4328,12 +4666,36 @@ int main(int argc, char ** argv) { svr->Post("/lora-adapters", handle_lora_adapters_apply); // Save & load slots svr->Get ("/slots", handle_slots); + svr->Get ("/slots/list", list_slot_prompts); if (!params.slot_save_path.empty()) { // these endpoints rely on slot_save_path existing svr->Post("/slots/:id_slot", handle_slots_action); svr->Get ("/list", list_saved_prompts); - } + svr->Post("/delete_prompt", delete_saved_prompt); + svr->Post("/rename_prompt", rename_saved_prompt); + } + svr->Get ("/version", handle_version); + if (!params.sql_save_file.empty()) { + // these endpoints rely on sql_save_file existing + svr->Post("/load", handle_load); + svr->Post("/save", handle_save); + svr->Post("/rename", handle_rename); + svr->Post("/all", handle_all); + svr->Post("/sessions", handle_sessions); + svr->Get ("/sessions", handle_sessions); + svr->Post("/delete", handle_delete); + //VACUUM is there for the extension but does not require the extension + svr->Get ("/vacuum", handle_vacuum); +#ifdef SQLITE3_MODERN_CPP_SUPPORT + if (sqlite_extension_loaded) { + svr->Get ("/zstd_get_configs", handle_zstd_get_configs); + svr->Post("/zstd_incremental_maintenance", handle_zstd_maintenance); + svr->Post("/zstd_enable_transparent", handle_zstd_enable); + svr->Post("/zstd_update_transparent", handle_zstd_config_update); + } +#endif + } // // Start the server // diff --git a/examples/server/sqlite_modern_cpp/License.txt b/examples/server/sqlite_modern_cpp/License.txt new file mode 100644 index 00000000..595b1d63 --- /dev/null +++ b/examples/server/sqlite_modern_cpp/License.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 aminroosta + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp.h b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp.h new file mode 100644 index 00000000..09665d38 --- /dev/null +++ b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp.h @@ -0,0 +1,682 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define MODERN_SQLITE_VERSION 3002008 + +#include + +#include "sqlite_modern_cpp/type_wrapper.h" +#include "sqlite_modern_cpp/errors.h" +#include "sqlite_modern_cpp/utility/function_traits.h" +#include "sqlite_modern_cpp/utility/uncaught_exceptions.h" +#include "sqlite_modern_cpp/utility/utf16_utf8.h" + +namespace sqlite { + + class database; + class database_binder; + + template class binder; + + typedef std::shared_ptr connection_type; + + template + struct index_binding_helper { + index_binding_helper(const index_binding_helper &) = delete; +#if __cplusplus < 201703 || _MSVC_LANG <= 201703 + index_binding_helper(index_binding_helper &&) = default; +#endif + typename std::conditional::type index; + T value; + }; + + template + auto named_parameter(const char *name, T &&arg) { + return index_binding_helper{name, std::forward(arg)}; + } + template + auto indexed_parameter(int index, T &&arg) { + return index_binding_helper{index, std::forward(arg)}; + } + + class row_iterator; + class database_binder { + + public: + // database_binder is not copyable + database_binder() = delete; + database_binder(const database_binder& other) = delete; + database_binder& operator=(const database_binder&) = delete; + + database_binder(database_binder&& other) : + _db(std::move(other._db)), + _stmt(std::move(other._stmt)), + _inx(other._inx), execution_started(other.execution_started) { } + + void execute(); + + std::string sql() { +#if SQLITE_VERSION_NUMBER >= 3014000 + auto sqlite_deleter = [](void *ptr) {sqlite3_free(ptr);}; + std::unique_ptr str(sqlite3_expanded_sql(_stmt.get()), sqlite_deleter); + return str ? str.get() : original_sql(); +#else + return original_sql(); +#endif + } + + std::string original_sql() { + return sqlite3_sql(_stmt.get()); + } + + void used(bool state) { + if(!state) { + // We may have to reset first if we haven't done so already: + _next_index(); + --_inx; + } + execution_started = state; + } + bool used() const { return execution_started; } + row_iterator begin(); + row_iterator end(); + + private: + std::shared_ptr _db; + std::unique_ptr _stmt; + utility::UncaughtExceptionDetector _has_uncaught_exception; + + int _inx; + + bool execution_started = false; + + int _next_index() { + if(execution_started && !_inx) { + sqlite3_reset(_stmt.get()); + sqlite3_clear_bindings(_stmt.get()); + } + return ++_inx; + } + + sqlite3_stmt* _prepare(u16str_ref sql) { + return _prepare(utility::utf16_to_utf8(sql)); + } + + sqlite3_stmt* _prepare(str_ref sql) { + int hresult; + sqlite3_stmt* tmp = nullptr; + const char *remaining; + hresult = sqlite3_prepare_v2(_db.get(), sql.data(), sql.length(), &tmp, &remaining); + if(hresult != SQLITE_OK) errors::throw_sqlite_error(hresult, sql, sqlite3_errmsg(_db.get())); + if(!std::all_of(remaining, sql.data() + sql.size(), [](char ch) {return std::isspace(ch);})) + throw errors::more_statements("Multiple semicolon separated statements are unsupported", sql); + return tmp; + } + + template friend database_binder& operator<<(database_binder& db, T&&); + template friend database_binder& operator<<(database_binder& db, index_binding_helper); + template friend database_binder& operator<<(database_binder& db, index_binding_helper); + friend void operator++(database_binder& db, int); + + public: + + database_binder(std::shared_ptr db, u16str_ref sql): + _db(db), + _stmt(_prepare(sql), sqlite3_finalize), + _inx(0) { + } + + database_binder(std::shared_ptr db, str_ref sql): + _db(db), + _stmt(_prepare(sql), sqlite3_finalize), + _inx(0) { + } + + ~database_binder() noexcept(false) { + /* Will be executed if no >>op is found, but not if an exception + is in mid flight */ + if(!used() && !_has_uncaught_exception && _stmt) { + execute(); + } + } + + friend class row_iterator; + }; + + class row_iterator { + public: + class value_type { + public: + value_type(database_binder *_binder): _binder(_binder) {}; + template + typename std::enable_if::value, value_type &>::type operator >>(T &result) { + result = get_col_from_db(_binder->_stmt.get(), next_index++, result_type()); + return *this; + } + template + value_type &operator >>(std::tuple& values) { + values = handle_tuple::type...>>(std::index_sequence_for()); + next_index += sizeof...(Types); + return *this; + } + template + value_type &operator >>(std::tuple&& values) { + return *this >> values; + } + template + operator std::tuple() { + std::tuple value; + *this >> value; + return value; + } + explicit operator bool() { + return sqlite3_column_count(_binder->_stmt.get()) >= next_index; + } + private: + template + Tuple handle_tuple(std::index_sequence) { + return Tuple( + get_col_from_db( + _binder->_stmt.get(), + next_index + Index, + result_type::type>())...); + } + database_binder *_binder; + int next_index = 0; + }; + using difference_type = std::ptrdiff_t; + using pointer = value_type*; + using reference = value_type&; + using iterator_category = std::input_iterator_tag; + + row_iterator() = default; + explicit row_iterator(database_binder &binder): _binder(&binder) { + _binder->_next_index(); + _binder->_inx = 0; + _binder->used(true); + ++*this; + } + + reference operator*() const { return value;} + pointer operator->() const { return std::addressof(**this); } + row_iterator &operator++() { + switch(int result = sqlite3_step(_binder->_stmt.get())) { + case SQLITE_ROW: + value = {_binder}; + break; + case SQLITE_DONE: + _binder = nullptr; + break; + default: + exceptions::throw_sqlite_error(result, _binder->sql(), sqlite3_errmsg(_binder->_db.get())); + } + return *this; + } + + friend inline bool operator ==(const row_iterator &a, const row_iterator &b) { + return a._binder == b._binder; + } + friend inline bool operator !=(const row_iterator &a, const row_iterator &b) { + return !(a==b); + } + + private: + database_binder *_binder = nullptr; + mutable value_type value{_binder}; // mutable, because `changing` the value is just reading it + }; + + inline row_iterator database_binder::begin() { + return row_iterator(*this); + } + + inline row_iterator database_binder::end() { + return row_iterator(); + } + + namespace detail { + template + void _extract_single_value(database_binder &binder, Callback call_back) { + auto iter = binder.begin(); + if(iter == binder.end()) + throw errors::no_rows("no rows to extract: exactly 1 row expected", binder.sql(), SQLITE_DONE); + + call_back(*iter); + + if(++iter != binder.end()) + throw errors::more_rows("not all rows extracted", binder.sql(), SQLITE_ROW); + } + } + inline void database_binder::execute() { + for(auto &&row : *this) + (void)row; + } + namespace detail { + template using void_t = void; + template + struct sqlite_direct_result : std::false_type {}; + template + struct sqlite_direct_result< + T, + void_t() >> std::declval())> + > : std::true_type {}; + } + template + inline typename std::enable_if::value>::type operator>>(database_binder &binder, Result&& value) { + detail::_extract_single_value(binder, [&value] (row_iterator::value_type &row) { + row >> std::forward(value); + }); + } + + template + inline typename std::enable_if::value>::type operator>>(database_binder &db_binder, Function&& func) { + using traits = utility::function_traits; + + for(auto &&row : db_binder) { + binder::run(row, func); + } + } + + template + inline decltype(auto) operator>>(database_binder &&binder, Result&& value) { + return binder >> std::forward(value); + } + + namespace sql_function_binder { + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + } + + enum class OpenFlags { + READONLY = SQLITE_OPEN_READONLY, + READWRITE = SQLITE_OPEN_READWRITE, + CREATE = SQLITE_OPEN_CREATE, + NOMUTEX = SQLITE_OPEN_NOMUTEX, + FULLMUTEX = SQLITE_OPEN_FULLMUTEX, + SHAREDCACHE = SQLITE_OPEN_SHAREDCACHE, + PRIVATECACH = SQLITE_OPEN_PRIVATECACHE, + URI = SQLITE_OPEN_URI + }; + inline OpenFlags operator|(const OpenFlags& a, const OpenFlags& b) { + return static_cast(static_cast(a) | static_cast(b)); + } + enum class Encoding { + ANY = SQLITE_ANY, + UTF8 = SQLITE_UTF8, + UTF16 = SQLITE_UTF16 + }; + struct sqlite_config { + OpenFlags flags = OpenFlags::READWRITE | OpenFlags::CREATE; + const char *zVfs = nullptr; + Encoding encoding = Encoding::ANY; + }; + + class database { + protected: + std::shared_ptr _db; + + public: + database(const std::string &db_name, const sqlite_config &config = {}): _db(nullptr) { + sqlite3* tmp = nullptr; + auto ret = sqlite3_open_v2(db_name.data(), &tmp, static_cast(config.flags), config.zVfs); + _db = std::shared_ptr(tmp, [=](sqlite3* ptr) { sqlite3_close_v2(ptr); }); // this will close the connection eventually when no longer needed. + if(ret != SQLITE_OK) errors::throw_sqlite_error(_db ? sqlite3_extended_errcode(_db.get()) : ret, {}, sqlite3_errmsg(_db.get())); + sqlite3_extended_result_codes(_db.get(), true); + if(config.encoding == Encoding::UTF16) + *this << R"(PRAGMA encoding = "UTF-16";)"; + } + + database(const std::u16string &db_name, const sqlite_config &config = {}): database(utility::utf16_to_utf8(db_name), config) { + if (config.encoding == Encoding::ANY) + *this << R"(PRAGMA encoding = "UTF-16";)"; + } + + database(std::shared_ptr db): + _db(db) {} + + database_binder operator<<(str_ref sql) { + return database_binder(_db, sql); + } + + database_binder operator<<(u16str_ref sql) { + return database_binder(_db, sql); + } + + connection_type connection() const { return _db; } + + sqlite3_int64 last_insert_rowid() const { + return sqlite3_last_insert_rowid(_db.get()); + } + + int rows_modified() const { + return sqlite3_changes(_db.get()); + } + + template + void define(const std::string &name, Function&& func) { + typedef utility::function_traits traits; + + auto funcPtr = new auto(std::forward(func)); + if(int result = sqlite3_create_function_v2( + _db.get(), name.data(), traits::arity, SQLITE_UTF8, funcPtr, + sql_function_binder::scalar::type>, + nullptr, nullptr, [](void* ptr){ + delete static_cast(ptr); + })) + errors::throw_sqlite_error(result, {}, sqlite3_errmsg(_db.get())); + } + + template + void define(const std::string &name, StepFunction&& step, FinalFunction&& final) { + typedef utility::function_traits traits; + using ContextType = typename std::remove_reference>::type; + + auto funcPtr = new auto(std::make_pair(std::forward(step), std::forward(final))); + if(int result = sqlite3_create_function_v2( + _db.get(), name.c_str(), traits::arity - 1, SQLITE_UTF8, funcPtr, nullptr, + sql_function_binder::step::type>, + sql_function_binder::final::type>, + [](void* ptr){ + delete static_cast(ptr); + })) + errors::throw_sqlite_error(result, {}, sqlite3_errmsg(_db.get())); + } + + }; + + template + class binder { + private: + template < + typename Function, + std::size_t Index + > + using nth_argument_type = typename utility::function_traits< + Function + >::template argument; + + public: + // `Boundary` needs to be defaulted to `Count` so that the `run` function + // template is not implicitly instantiated on class template instantiation. + // Look up section 14.7.1 _Implicit instantiation_ of the ISO C++14 Standard + // and the [dicussion](https://github.com/aminroosta/sqlite_modern_cpp/issues/8) + // on Github. + + template< + typename Function, + typename... Values, + std::size_t Boundary = Count + > + static typename std::enable_if<(sizeof...(Values) < Boundary), void>::type run( + row_iterator::value_type& row, + Function&& function, + Values&&... values + ) { + typename std::decay>::type value; + row >> value; + run(row, function, std::forward(values)..., std::move(value)); + } + + template< + typename Function, + typename... Values, + std::size_t Boundary = Count + > + static typename std::enable_if<(sizeof...(Values) == Boundary), void>::type run( + row_iterator::value_type&, + Function&& function, + Values&&... values + ) { + function(std::move(values)...); + } + }; + + // Some ppl are lazy so we have a operator for proper prep. statemant handling. + void inline operator++(database_binder& db, int) { db.execute(); } + + template database_binder &operator<<(database_binder& db, index_binding_helper val) { + db._next_index(); --db._inx; + int result = bind_col_in_db(db._stmt.get(), val.index, std::forward(val.value)); + if(result != SQLITE_OK) + exceptions::throw_sqlite_error(result, db.sql(), sqlite3_errmsg(db._db.get())); + return db; + } + + template database_binder &operator<<(database_binder& db, index_binding_helper val) { + db._next_index(); --db._inx; + int index = sqlite3_bind_parameter_index(db._stmt.get(), val.index); + if(!index) + throw errors::unknown_binding("The given binding name is not valid for this statement", db.sql()); + int result = bind_col_in_db(db._stmt.get(), index, std::forward(val.value)); + if(result != SQLITE_OK) + exceptions::throw_sqlite_error(result, db.sql(), sqlite3_errmsg(db._db.get())); + return db; + } + + template database_binder &operator<<(database_binder& db, T&& val) { + int result = bind_col_in_db(db._stmt.get(), db._next_index(), std::forward(val)); + if(result != SQLITE_OK) + exceptions::throw_sqlite_error(result, db.sql(), sqlite3_errmsg(db._db.get())); + return db; + } + // Convert the rValue binder to a reference and call first op<<, its needed for the call that creates the binder (be carefull of recursion here!) + template database_binder operator << (database_binder&& db, const T& val) { db << val; return std::move(db); } + template database_binder operator << (database_binder&& db, index_binding_helper val) { db << index_binding_helper{val.index, std::forward(val.value)}; return std::move(db); } + + namespace sql_function_binder { + template + struct AggregateCtxt { + T obj; + bool constructed = true; + }; + + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ) { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + if(!ctxt) return; + try { + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + step(db, count, vals, ctxt->obj); + return; + } catch(const sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(const std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + if(ctxt && ctxt->constructed) + ctxt->~AggregateCtxt(); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + using arg_type = typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits< + typename Functions::first_type + >::template argument + >::type + >::type; + + step( + db, + count, + vals, + std::forward(values)..., + get_val_from_db(vals[sizeof...(Values) - 1], result_type())); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + static_cast(sqlite3_user_data(db))->first(std::forward(values)...); + } + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db) { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + try { + if(!ctxt) return; + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + store_result_in_db(db, + static_cast(sqlite3_user_data(db))->second(ctxt->obj)); + } catch(const sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(const std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + if(ctxt && ctxt->constructed) + ctxt->~AggregateCtxt(); + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + using arg_type = typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits::template argument + >::type + >::type; + + scalar( + db, + count, + vals, + std::forward(values)..., + get_val_from_db(vals[sizeof...(Values)], result_type())); + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + try { + store_result_in_db(db, + (*static_cast(sqlite3_user_data(db)))(std::forward(values)...)); + } catch(const sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(const std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + } + } +} + diff --git a/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/errors.h b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/errors.h new file mode 100644 index 00000000..14501dd7 --- /dev/null +++ b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/errors.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include + +#include + +namespace sqlite { + + class sqlite_exception: public std::runtime_error { + public: + sqlite_exception(const char* msg, str_ref sql, int code = -1): runtime_error(msg), code(code), sql(sql) {} + sqlite_exception(int code, str_ref sql, const char *msg = nullptr): runtime_error(msg ? msg : sqlite3_errstr(code)), code(code), sql(sql) {} + int get_code() const {return code & 0xFF;} + int get_extended_code() const {return code;} + std::string get_sql() const {return sql;} + const char *errstr() const {return code == -1 ? "Unknown error" : sqlite3_errstr(code);} + private: + int code; + std::string sql; + }; + + namespace errors { + //One more or less trivial derived error class for each SQLITE error. + //Note the following are not errors so have no classes: + //SQLITE_OK, SQLITE_NOTICE, SQLITE_WARNING, SQLITE_ROW, SQLITE_DONE + // + //Note these names are exact matches to the names of the SQLITE error codes. +#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ + class name: public sqlite_exception { using sqlite_exception::sqlite_exception; };\ + derived +#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \ + class base ## _ ## sub: public base { using base::base; }; +#include "lists/error_codes.h" +#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED +#undef SQLITE_MODERN_CPP_ERROR_CODE + + //Some additional errors are here for the C++ interface + class more_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + class no_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + class more_statements: public sqlite_exception { using sqlite_exception::sqlite_exception; }; // Prepared statements can only contain one statement + class invalid_utf16: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + class unknown_binding: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + + static void throw_sqlite_error(const int& error_code, str_ref sql = "", const char *errmsg = nullptr) { + switch(error_code & 0xFF) { +#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ + case SQLITE_ ## NAME: switch(error_code) { \ + derived \ + case SQLITE_ ## NAME: \ + default: throw name(error_code, sql); \ + } + +#if SQLITE_VERSION_NUMBER < 3010000 +#define SQLITE_IOERR_VNODE (SQLITE_IOERR | (27<<8)) +#define SQLITE_IOERR_AUTH (SQLITE_IOERR | (28<<8)) +#define SQLITE_AUTH_USER (SQLITE_AUTH | (1<<8)) +#endif + +#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \ + case SQLITE_ ## BASE ## _ ## SUB: throw base ## _ ## sub(error_code, sql, errmsg); +#include "lists/error_codes.h" +#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED +#undef SQLITE_MODERN_CPP_ERROR_CODE + default: throw sqlite_exception(error_code, sql, errmsg); + } + } + } + namespace exceptions = errors; +} diff --git a/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/lists/error_codes.h b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/lists/error_codes.h new file mode 100644 index 00000000..d8804f5b --- /dev/null +++ b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/lists/error_codes.h @@ -0,0 +1,88 @@ +SQLITE_MODERN_CPP_ERROR_CODE(ERROR,error,) +SQLITE_MODERN_CPP_ERROR_CODE(INTERNAL,internal,) +SQLITE_MODERN_CPP_ERROR_CODE(PERM,perm,) +SQLITE_MODERN_CPP_ERROR_CODE(ABORT,abort, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(ABORT,ROLLBACK,abort,rollback) +) +SQLITE_MODERN_CPP_ERROR_CODE(BUSY,busy, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BUSY,RECOVERY,busy,recovery) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BUSY,SNAPSHOT,busy,snapshot) +) +SQLITE_MODERN_CPP_ERROR_CODE(LOCKED,locked, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(LOCKED,SHAREDCACHE,locked,sharedcache) +) +SQLITE_MODERN_CPP_ERROR_CODE(NOMEM,nomem,) +SQLITE_MODERN_CPP_ERROR_CODE(READONLY,readonly,) +SQLITE_MODERN_CPP_ERROR_CODE(INTERRUPT,interrupt,) +SQLITE_MODERN_CPP_ERROR_CODE(IOERR,ioerr, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,READ,ioerr,read) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHORT_READ,ioerr,short_read) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,WRITE,ioerr,write) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,FSYNC,ioerr,fsync) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DIR_FSYNC,ioerr,dir_fsync) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,TRUNCATE,ioerr,truncate) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,FSTAT,ioerr,fstat) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,UNLOCK,ioerr,unlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,RDLOCK,ioerr,rdlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DELETE,ioerr,delete) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,BLOCKED,ioerr,blocked) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,NOMEM,ioerr,nomem) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,ACCESS,ioerr,access) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CHECKRESERVEDLOCK,ioerr,checkreservedlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,LOCK,ioerr,lock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CLOSE,ioerr,close) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DIR_CLOSE,ioerr,dir_close) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMOPEN,ioerr,shmopen) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMSIZE,ioerr,shmsize) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMLOCK,ioerr,shmlock) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SHMMAP,ioerr,shmmap) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,SEEK,ioerr,seek) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,DELETE_NOENT,ioerr,delete_noent) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,MMAP,ioerr,mmap) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,GETTEMPPATH,ioerr,gettemppath) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,CONVPATH,ioerr,convpath) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,VNODE,ioerr,vnode) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(IOERR,AUTH,ioerr,auth) +) +SQLITE_MODERN_CPP_ERROR_CODE(CORRUPT,corrupt, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CORRUPT,VTAB,corrupt,vtab) +) +SQLITE_MODERN_CPP_ERROR_CODE(NOTFOUND,notfound,) +SQLITE_MODERN_CPP_ERROR_CODE(FULL,full,) +SQLITE_MODERN_CPP_ERROR_CODE(CANTOPEN,cantopen, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,NOTEMPDIR,cantopen,notempdir) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,ISDIR,cantopen,isdir) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,FULLPATH,cantopen,fullpath) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CANTOPEN,CONVPATH,cantopen,convpath) +) +SQLITE_MODERN_CPP_ERROR_CODE(PROTOCOL,protocol,) +SQLITE_MODERN_CPP_ERROR_CODE(EMPTY,empty,) +SQLITE_MODERN_CPP_ERROR_CODE(SCHEMA,schema,) +SQLITE_MODERN_CPP_ERROR_CODE(TOOBIG,toobig,) +SQLITE_MODERN_CPP_ERROR_CODE(CONSTRAINT,constraint, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,CHECK,constraint,check) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,COMMITHOOK,constraint,commithook) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,FOREIGNKEY,constraint,foreignkey) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,FUNCTION,constraint,function) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,NOTNULL,constraint,notnull) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,PRIMARYKEY,constraint,primarykey) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,TRIGGER,constraint,trigger) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,UNIQUE,constraint,unique) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,VTAB,constraint,vtab) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(CONSTRAINT,ROWID,constraint,rowid) +) +SQLITE_MODERN_CPP_ERROR_CODE(MISMATCH,mismatch,) +SQLITE_MODERN_CPP_ERROR_CODE(MISUSE,misuse,) +SQLITE_MODERN_CPP_ERROR_CODE(NOLFS,nolfs,) +SQLITE_MODERN_CPP_ERROR_CODE(AUTH,auth, +) +SQLITE_MODERN_CPP_ERROR_CODE(FORMAT,format,) +SQLITE_MODERN_CPP_ERROR_CODE(RANGE,range,) +SQLITE_MODERN_CPP_ERROR_CODE(NOTADB,notadb,) +SQLITE_MODERN_CPP_ERROR_CODE(NOTICE,notice, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(NOTICE,RECOVER_WAL,notice,recover_wal) + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(NOTICE,RECOVER_ROLLBACK,notice,recover_rollback) +) +SQLITE_MODERN_CPP_ERROR_CODE(WARNING,warning, + SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(WARNING,AUTOINDEX,warning,autoindex) +) diff --git a/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/log.h b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/log.h new file mode 100644 index 00000000..5abe2933 --- /dev/null +++ b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/log.h @@ -0,0 +1,101 @@ +#include "errors.h" + +#include + +#include +#include +#include + +namespace sqlite { + namespace detail { + template + using void_t = void; + template + struct is_callable : std::false_type {}; + template + struct is_callable()(std::declval()...))>> : std::true_type {}; + template + class FunctorOverload: public Functor, public FunctorOverload { + public: + template + FunctorOverload(Functor1 &&functor, Remaining &&... remaining): + Functor(std::forward(functor)), + FunctorOverload(std::forward(remaining)...) {} + using Functor::operator(); + using FunctorOverload::operator(); + }; + template + class FunctorOverload: public Functor { + public: + template + FunctorOverload(Functor1 &&functor): + Functor(std::forward(functor)) {} + using Functor::operator(); + }; + template + class WrapIntoFunctor: public Functor { + public: + template + WrapIntoFunctor(Functor1 &&functor): + Functor(std::forward(functor)) {} + using Functor::operator(); + }; + template + class WrapIntoFunctor { + ReturnType(*ptr)(Arguments...); + public: + WrapIntoFunctor(ReturnType(*ptr)(Arguments...)): ptr(ptr) {} + ReturnType operator()(Arguments... arguments) { return (*ptr)(std::forward(arguments)...); } + }; + inline void store_error_log_data_pointer(std::shared_ptr ptr) { + static std::shared_ptr stored; + stored = std::move(ptr); + } + template + std::shared_ptr::type> make_shared_inferred(T &&t) { + return std::make_shared::type>(std::forward(t)); + } + } + template + typename std::enable_if::value>::type + error_log(Handler &&handler); + template + typename std::enable_if::value>::type + error_log(Handler &&handler); + template + typename std::enable_if=2>::type + error_log(Handler &&...handler) { + return error_log(detail::FunctorOverload::type>...>(std::forward(handler)...)); + } + template + typename std::enable_if::value>::type + error_log(Handler &&handler) { + return error_log(std::forward(handler), [](const sqlite_exception&) {}); + } + template + typename std::enable_if::value>::type + error_log(Handler &&handler) { + auto ptr = detail::make_shared_inferred([handler = std::forward(handler)](int error_code, const char *errstr) mutable { + switch(error_code & 0xFF) { +#define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ + case SQLITE_ ## NAME: switch(error_code) { \ + derived \ + default: handler(errors::name(errstr, "", error_code)); \ + };break; +#define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \ + case SQLITE_ ## BASE ## _ ## SUB: \ + handler(errors::base ## _ ## sub(errstr, "", error_code)); \ + break; +#include "lists/error_codes.h" +#undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED +#undef SQLITE_MODERN_CPP_ERROR_CODE + default: handler(sqlite_exception(errstr, "", error_code)); \ + } + }); + + sqlite3_config(SQLITE_CONFIG_LOG, static_cast([](void *functor, int error_code, const char *errstr) { + (*static_cast(functor))(error_code, errstr); + }), ptr.get()); + detail::store_error_log_data_pointer(std::move(ptr)); + } +} diff --git a/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/sqlcipher.h b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/sqlcipher.h new file mode 100644 index 00000000..da0f0189 --- /dev/null +++ b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/sqlcipher.h @@ -0,0 +1,44 @@ +#pragma once + +#ifndef SQLITE_HAS_CODEC +#define SQLITE_HAS_CODEC +#endif + +#include "../sqlite_modern_cpp.h" + +namespace sqlite { + struct sqlcipher_config : public sqlite_config { + std::string key; + }; + + class sqlcipher_database : public database { + public: + sqlcipher_database(std::string db, const sqlcipher_config &config): database(db, config) { + set_key(config.key); + } + + sqlcipher_database(std::u16string db, const sqlcipher_config &config): database(db, config) { + set_key(config.key); + } + + void set_key(const std::string &key) { + if(auto ret = sqlite3_key(_db.get(), key.data(), key.size())) + errors::throw_sqlite_error(ret); + } + + void set_key(const std::string &key, const std::string &db_name) { + if(auto ret = sqlite3_key_v2(_db.get(), db_name.c_str(), key.data(), key.size())) + errors::throw_sqlite_error(ret); + } + + void rekey(const std::string &new_key) { + if(auto ret = sqlite3_rekey(_db.get(), new_key.data(), new_key.size())) + errors::throw_sqlite_error(ret); + } + + void rekey(const std::string &new_key, const std::string &db_name) { + if(auto ret = sqlite3_rekey_v2(_db.get(), db_name.c_str(), new_key.data(), new_key.size())) + errors::throw_sqlite_error(ret); + } + }; +} diff --git a/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/type_wrapper.h b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/type_wrapper.h new file mode 100644 index 00000000..7b12a36b --- /dev/null +++ b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/type_wrapper.h @@ -0,0 +1,433 @@ +#pragma once + +#include +#include +#include +#include +#ifdef __has_include +#if (__cplusplus >= 201703 || _MSVC_LANG >= 201703) && __has_include() +#define MODERN_SQLITE_STRINGVIEW_SUPPORT +#endif +#endif +#ifdef __has_include +#if (__cplusplus > 201402 || _MSVC_LANG > 201402) && __has_include() +#define MODERN_SQLITE_STD_OPTIONAL_SUPPORT +#elif __has_include() && __apple_build_version__ < 11000000 +#define MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT +#endif +#endif + +#ifdef __has_include +#if (__cplusplus > 201402 || _MSVC_LANG > 201402) && __has_include() +#define MODERN_SQLITE_STD_VARIANT_SUPPORT +#endif +#endif + +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT +#include +#endif + +#ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT +#include +#define MODERN_SQLITE_STD_OPTIONAL_SUPPORT +#endif + +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT +#include +#endif +#ifdef MODERN_SQLITE_STRINGVIEW_SUPPORT +#include +namespace sqlite +{ + typedef const std::string_view str_ref; + typedef const std::u16string_view u16str_ref; +} +#else +namespace sqlite +{ + typedef const std::string& str_ref; + typedef const std::u16string& u16str_ref; +} +#endif +#include +#include "errors.h" + +namespace sqlite { + template + struct has_sqlite_type : std::false_type {}; + + template + using is_sqlite_value = std::integral_constant::value + || has_sqlite_type::value + || has_sqlite_type::value + || has_sqlite_type::value + || has_sqlite_type::value + >; + + template + struct has_sqlite_type : has_sqlite_type {}; + template + struct has_sqlite_type : has_sqlite_type {}; + template + struct has_sqlite_type : has_sqlite_type {}; + + template + struct result_type { + using type = T; + constexpr result_type() = default; + template::value>> + constexpr result_type(result_type) { } + }; + + // int + template<> + struct has_sqlite_type : std::true_type {}; + + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const int& val) { + return sqlite3_bind_int(stmt, inx, val); + } + inline void store_result_in_db(sqlite3_context* db, const int& val) { + sqlite3_result_int(db, val); + } + inline int get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : + sqlite3_column_int(stmt, inx); + } + inline int get_val_from_db(sqlite3_value *value, result_type) { + return sqlite3_value_type(value) == SQLITE_NULL ? 0 : + sqlite3_value_int(value); + } + + // sqlite_int64 + template<> + struct has_sqlite_type : std::true_type {}; + + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const sqlite_int64& val) { + return sqlite3_bind_int64(stmt, inx, val); + } + inline void store_result_in_db(sqlite3_context* db, const sqlite_int64& val) { + sqlite3_result_int64(db, val); + } + inline sqlite_int64 get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : + sqlite3_column_int64(stmt, inx); + } + inline sqlite3_int64 get_val_from_db(sqlite3_value *value, result_type) { + return sqlite3_value_type(value) == SQLITE_NULL ? 0 : + sqlite3_value_int64(value); + } + + // float + template<> + struct has_sqlite_type : std::true_type {}; + + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const float& val) { + return sqlite3_bind_double(stmt, inx, double(val)); + } + inline void store_result_in_db(sqlite3_context* db, const float& val) { + sqlite3_result_double(db, val); + } + inline float get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : + sqlite3_column_double(stmt, inx); + } + inline float get_val_from_db(sqlite3_value *value, result_type) { + return sqlite3_value_type(value) == SQLITE_NULL ? 0 : + sqlite3_value_double(value); + } + + // double + template<> + struct has_sqlite_type : std::true_type {}; + + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const double& val) { + return sqlite3_bind_double(stmt, inx, val); + } + inline void store_result_in_db(sqlite3_context* db, const double& val) { + sqlite3_result_double(db, val); + } + inline double get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : + sqlite3_column_double(stmt, inx); + } + inline double get_val_from_db(sqlite3_value *value, result_type) { + return sqlite3_value_type(value) == SQLITE_NULL ? 0 : + sqlite3_value_double(value); + } + + /* for nullptr support */ + template<> + struct has_sqlite_type : std::true_type {}; + + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, std::nullptr_t) { + return sqlite3_bind_null(stmt, inx); + } + inline void store_result_in_db(sqlite3_context* db, std::nullptr_t) { + sqlite3_result_null(db); + } + +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT + template<> + struct has_sqlite_type : std::true_type {}; + + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, std::monostate) { + return sqlite3_bind_null(stmt, inx); + } + inline void store_result_in_db(sqlite3_context* db, std::monostate) { + sqlite3_result_null(db); + } + inline std::monostate get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return std::monostate(); + } + inline std::monostate get_val_from_db(sqlite3_value *value, result_type) { + return std::monostate(); + } +#endif + + // str_ref + template<> + struct has_sqlite_type : std::true_type {}; + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, str_ref val) { + return sqlite3_bind_text(stmt, inx, val.data(), val.length(), SQLITE_TRANSIENT); + } + + // Convert char* to string_view to trigger op<<(..., const str_ref ) + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const char(&STR)[N]) { + return sqlite3_bind_text(stmt, inx, &STR[0], N-1, SQLITE_TRANSIENT); + } + + inline std::string get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + if ( sqlite3_column_type(stmt, inx) == SQLITE_NULL ) { + return std::string(); + } + char const * ptr = reinterpret_cast(sqlite3_column_text(stmt, inx)); + // call sqlite3_column_text explicitely before sqlite3_column_bytes: it may convert the value to text + return std::string(ptr, sqlite3_column_bytes(stmt, inx)); + } + inline std::string get_val_from_db(sqlite3_value *value, result_type) { + if ( sqlite3_value_type(value) == SQLITE_NULL ) { + return std::string(); + } + char const * ptr = reinterpret_cast(sqlite3_value_text(value)); + // call sqlite3_column_text explicitely before sqlite3_column_bytes: it may convert the value to text + return std::string(ptr, sqlite3_value_bytes(value)); + } + + inline void store_result_in_db(sqlite3_context* db, str_ref val) { + sqlite3_result_text(db, val.data(), val.length(), SQLITE_TRANSIENT); + } + // u16str_ref + template<> + struct has_sqlite_type : std::true_type {}; + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, u16str_ref val) { + return sqlite3_bind_text16(stmt, inx, val.data(), sizeof(char16_t) * val.length(), SQLITE_TRANSIENT); + } + + // Convert char* to string_view to trigger op<<(..., const str_ref ) + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const char16_t(&STR)[N]) { + return sqlite3_bind_text16(stmt, inx, &STR[0], sizeof(char16_t) * (N-1), SQLITE_TRANSIENT); + } + + inline std::u16string get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + if ( sqlite3_column_type(stmt, inx) == SQLITE_NULL ) { + return std::u16string(); + } + char16_t const * ptr = reinterpret_cast(sqlite3_column_text16(stmt, inx)); + // call sqlite3_column_text16 explicitely before sqlite3_column_bytes16: it may convert the value to text + return std::u16string(ptr, sqlite3_column_bytes16(stmt, inx)); + } + inline std::u16string get_val_from_db(sqlite3_value *value, result_type) { + if ( sqlite3_value_type(value) == SQLITE_NULL ) { + return std::u16string(); + } + char16_t const * ptr = reinterpret_cast(sqlite3_value_text16(value)); + return std::u16string(ptr, sqlite3_value_bytes16(value)); + } + + inline void store_result_in_db(sqlite3_context* db, u16str_ref val) { + sqlite3_result_text16(db, val.data(), sizeof(char16_t) * val.length(), SQLITE_TRANSIENT); + } + + // Other integer types + template + struct has_sqlite_type::value>::type> : std::true_type {}; + + template::value>::type> + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const Integral& val) { + return bind_col_in_db(stmt, inx, static_cast(val)); + } + template::type>> + inline void store_result_in_db(sqlite3_context* db, const Integral& val) { + store_result_in_db(db, static_cast(val)); + } + template::value>::type> + inline Integral get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return get_col_from_db(stmt, inx, result_type()); + } + template::value>::type> + inline Integral get_val_from_db(sqlite3_value *value, result_type) { + return get_val_from_db(value, result_type()); + } + + // vector + template + struct has_sqlite_type, SQLITE_BLOB, void> : std::true_type {}; + + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const std::vector& vec) { + void const* buf = reinterpret_cast(vec.data()); + int bytes = vec.size() * sizeof(T); + return sqlite3_bind_blob(stmt, inx, buf, bytes, SQLITE_TRANSIENT); + } + template inline void store_result_in_db(sqlite3_context* db, const std::vector& vec) { + void const* buf = reinterpret_cast(vec.data()); + int bytes = vec.size() * sizeof(T); + sqlite3_result_blob(db, buf, bytes, SQLITE_TRANSIENT); + } + template inline std::vector get_col_from_db(sqlite3_stmt* stmt, int inx, result_type>) { + if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { + return {}; + } + T const* buf = reinterpret_cast(sqlite3_column_blob(stmt, inx)); + int bytes = sqlite3_column_bytes(stmt, inx); + return std::vector(buf, buf + bytes/sizeof(T)); + } + template inline std::vector get_val_from_db(sqlite3_value *value, result_type>) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + return {}; + } + T const* buf = reinterpret_cast(sqlite3_value_blob(value)); + int bytes = sqlite3_value_bytes(value); + return std::vector(buf, buf + bytes/sizeof(T)); + } + + /* for unique_ptr support */ + template + struct has_sqlite_type, Type, void> : has_sqlite_type {}; + template + struct has_sqlite_type, SQLITE_NULL, void> : std::true_type {}; + + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const std::unique_ptr& val) { + return val ? bind_col_in_db(stmt, inx, *val) : bind_col_in_db(stmt, inx, nullptr); + } + template inline std::unique_ptr get_col_from_db(sqlite3_stmt* stmt, int inx, result_type>) { + if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { + return nullptr; + } + return std::make_unique(get_col_from_db(stmt, inx, result_type())); + } + template inline std::unique_ptr get_val_from_db(sqlite3_value *value, result_type>) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + return nullptr; + } + return std::make_unique(get_val_from_db(value, result_type())); + } + + // std::optional support for NULL values +#ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT +#ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + template + using optional = std::experimental::optional; +#else + template + using optional = std::optional; +#endif + + template + struct has_sqlite_type, Type, void> : has_sqlite_type {}; + template + struct has_sqlite_type, SQLITE_NULL, void> : std::true_type {}; + + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const optional& val) { + return val ? bind_col_in_db(stmt, inx, *val) : bind_col_in_db(stmt, inx, nullptr); + } + template inline void store_result_in_db(sqlite3_context* db, const optional& val) { + if(val) + store_result_in_db(db, *val); + else + sqlite3_result_null(db); + } + + template inline optional get_col_from_db(sqlite3_stmt* stmt, int inx, result_type>) { + #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { + return std::experimental::nullopt; + } + return std::experimental::make_optional(get_col_from_db(stmt, inx, result_type())); + #else + if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { + return std::nullopt; + } + return std::make_optional(get_col_from_db(stmt, inx, result_type())); + #endif + } + template inline optional get_val_from_db(sqlite3_value *value, result_type>) { + #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + if(sqlite3_value_type(value) == SQLITE_NULL) { + return std::experimental::nullopt; + } + return std::experimental::make_optional(get_val_from_db(value, result_type())); + #else + if(sqlite3_value_type(value) == SQLITE_NULL) { + return std::nullopt; + } + return std::make_optional(get_val_from_db(value, result_type())); + #endif + } +#endif + +#ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT + namespace detail { + template + struct tag_trait : U { using tag = T; }; + } + + template + struct has_sqlite_type, Type, void> : std::disjunction>...> {}; + + namespace detail { + template, Type>> + inline std::variant variant_select_type(Callback &&callback) { + if constexpr(first_compatible::value) + return callback(result_type()); + else + throw errors::mismatch("The value is unsupported by this variant.", "", SQLITE_MISMATCH); + } + template inline decltype(auto) variant_select(int type, Callback &&callback) { + switch(type) { + case SQLITE_NULL: + return variant_select_type(std::forward(callback)); + case SQLITE_INTEGER: + return variant_select_type(std::forward(callback)); + case SQLITE_FLOAT: + return variant_select_type(std::forward(callback)); + case SQLITE_TEXT: + return variant_select_type(std::forward(callback)); + case SQLITE_BLOB: + return variant_select_type(std::forward(callback)); + } +#ifdef _MSC_VER + __assume(false); +#else + __builtin_unreachable(); +#endif + } + } + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const std::variant& val) { + return std::visit([&](auto &&opt) {return bind_col_in_db(stmt, inx, std::forward(opt));}, val); + } + template inline void store_result_in_db(sqlite3_context* db, const std::variant& val) { + std::visit([&](auto &&opt) {store_result_in_db(db, std::forward(opt));}, val); + } + template inline std::variant get_col_from_db(sqlite3_stmt* stmt, int inx, result_type>) { + return detail::variant_select(sqlite3_column_type(stmt, inx), [&](auto v) { + return std::variant(std::in_place_type, get_col_from_db(stmt, inx, v)); + }); + } + template inline std::variant get_val_from_db(sqlite3_value *value, result_type>) { + return detail::variant_select(sqlite3_value_type(value), [&](auto v) { + return std::variant(std::in_place_type, get_val_from_db(value, v)); + }); + } +#endif +} diff --git a/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/utility/function_traits.h b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/utility/function_traits.h new file mode 100644 index 00000000..f629aa09 --- /dev/null +++ b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/utility/function_traits.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include + +namespace sqlite { + namespace utility { + + template struct function_traits; + + template + struct function_traits : public function_traits< + decltype(&std::remove_reference::type::operator()) + > { }; + + template < + typename ClassType, + typename ReturnType, + typename... Arguments + > + struct function_traits< + ReturnType(ClassType::*)(Arguments...) const + > : function_traits { }; + + /* support the non-const operator () + * this will work with user defined functors */ + template < + typename ClassType, + typename ReturnType, + typename... Arguments + > + struct function_traits< + ReturnType(ClassType::*)(Arguments...) + > : function_traits { }; + + template < + typename ReturnType, + typename... Arguments + > + struct function_traits< + ReturnType(*)(Arguments...) + > { + typedef ReturnType result_type; + + using argument_tuple = std::tuple; + template + using argument = typename std::tuple_element< + Index, + argument_tuple + >::type; + + static const std::size_t arity = sizeof...(Arguments); + }; + + } +} diff --git a/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/utility/uncaught_exceptions.h b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/utility/uncaught_exceptions.h new file mode 100644 index 00000000..65997e00 --- /dev/null +++ b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/utility/uncaught_exceptions.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include + +// Consider that std::uncaught_exceptions is available if explicitly indicated +// by the standard library, if compiler advertises full C++17 support or, as a +// special case, for MSVS 2015+ (which doesn't define __cplusplus correctly by +// default as of 2017.7 version and couldn't do it at all until it). +#ifndef MODERN_SQLITE_UNCAUGHT_EXCEPTIONS_SUPPORT +#ifdef __cpp_lib_uncaught_exceptions +#define MODERN_SQLITE_UNCAUGHT_EXCEPTIONS_SUPPORT +#elif __cplusplus >= 201703L +#define MODERN_SQLITE_UNCAUGHT_EXCEPTIONS_SUPPORT +#elif defined(_MSC_VER) && _MSC_VER >= 1900 +#define MODERN_SQLITE_UNCAUGHT_EXCEPTIONS_SUPPORT +#endif +#endif + +namespace sqlite { + namespace utility { +#ifdef MODERN_SQLITE_UNCAUGHT_EXCEPTIONS_SUPPORT + class UncaughtExceptionDetector { + public: + operator bool() { + return count != std::uncaught_exceptions(); + } + private: + int count = std::uncaught_exceptions(); + }; +#else + class UncaughtExceptionDetector { + public: + operator bool() { + return std::uncaught_exception(); + } + }; +#endif + } +} diff --git a/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/utility/utf16_utf8.h b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/utility/utf16_utf8.h new file mode 100644 index 00000000..340a26eb --- /dev/null +++ b/examples/server/sqlite_modern_cpp/hdr/sqlite_modern_cpp/utility/utf16_utf8.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include + +#include "../errors.h" + +namespace sqlite { + namespace utility { + inline std::string utf16_to_utf8(u16str_ref input) { + struct : std::codecvt { + } codecvt; + std::mbstate_t state{}; + std::string result((std::max)(input.size() * 3 / 2, std::size_t(4)), '\0'); + const char16_t *remaining_input = input.data(); + std::size_t produced_output = 0; + while(true) { + char *used_output; + switch(codecvt.out(state, remaining_input, input.data() + input.size(), + remaining_input, &result[produced_output], + &result[result.size() - 1] + 1, used_output)) { + case std::codecvt_base::ok: + result.resize(used_output - result.data()); + return result; + case std::codecvt_base::noconv: + // This should be unreachable + case std::codecvt_base::error: + throw errors::invalid_utf16("Invalid UTF-16 input", ""); + case std::codecvt_base::partial: + if(used_output == result.data() + produced_output) + throw errors::invalid_utf16("Unexpected end of input", ""); + produced_output = used_output - result.data(); + result.resize( + result.size() + + (std::max)((input.data() + input.size() - remaining_input) * 3 / 2, + std::ptrdiff_t(4))); + } + } + } + } // namespace utility +} // namespace sqlite