#pragma once #include #include #include #include #include #include #define MODERN_SQLITE_VERSION 3002008 #include #include "sqlite_modern_cpp/type_wrapper.h" #include "sqlite_modern_cpp/errors.h" #include "sqlite_modern_cpp/utility/function_traits.h" #include "sqlite_modern_cpp/utility/uncaught_exceptions.h" #include "sqlite_modern_cpp/utility/utf16_utf8.h" namespace sqlite { class database; class database_binder; template class binder; typedef std::shared_ptr connection_type; template struct index_binding_helper { index_binding_helper(const index_binding_helper &) = delete; #if __cplusplus < 201703 || _MSVC_LANG <= 201703 index_binding_helper(index_binding_helper &&) = default; #endif typename std::conditional::type index; T value; }; template auto named_parameter(const char *name, T &&arg) { return index_binding_helper{name, std::forward(arg)}; } template auto indexed_parameter(int index, T &&arg) { return index_binding_helper{index, std::forward(arg)}; } class row_iterator; class database_binder { public: // database_binder is not copyable database_binder() = delete; database_binder(const database_binder& other) = delete; database_binder& operator=(const database_binder&) = delete; database_binder(database_binder&& other) : _db(std::move(other._db)), _stmt(std::move(other._stmt)), _inx(other._inx), execution_started(other.execution_started) { } void execute(); std::string sql() { #if SQLITE_VERSION_NUMBER >= 3014000 auto sqlite_deleter = [](void *ptr) {sqlite3_free(ptr);}; std::unique_ptr str(sqlite3_expanded_sql(_stmt.get()), sqlite_deleter); return str ? str.get() : original_sql(); #else return original_sql(); #endif } std::string original_sql() { return sqlite3_sql(_stmt.get()); } void used(bool state) { if(!state) { // We may have to reset first if we haven't done so already: _next_index(); --_inx; } execution_started = state; } bool used() const { return execution_started; } row_iterator begin(); row_iterator end(); private: std::shared_ptr _db; std::unique_ptr _stmt; utility::UncaughtExceptionDetector _has_uncaught_exception; int _inx; bool execution_started = false; int _next_index() { if(execution_started && !_inx) { sqlite3_reset(_stmt.get()); sqlite3_clear_bindings(_stmt.get()); } return ++_inx; } sqlite3_stmt* _prepare(u16str_ref sql) { return _prepare(utility::utf16_to_utf8(sql)); } sqlite3_stmt* _prepare(str_ref sql) { int hresult; sqlite3_stmt* tmp = nullptr; const char *remaining; hresult = sqlite3_prepare_v2(_db.get(), sql.data(), sql.length(), &tmp, &remaining); if(hresult != SQLITE_OK) errors::throw_sqlite_error(hresult, sql, sqlite3_errmsg(_db.get())); if(!std::all_of(remaining, sql.data() + sql.size(), [](char ch) {return std::isspace(ch);})) throw errors::more_statements("Multiple semicolon separated statements are unsupported", sql); return tmp; } template friend database_binder& operator<<(database_binder& db, T&&); template friend database_binder& operator<<(database_binder& db, index_binding_helper); template friend database_binder& operator<<(database_binder& db, index_binding_helper); friend void operator++(database_binder& db, int); public: database_binder(std::shared_ptr db, u16str_ref sql): _db(db), _stmt(_prepare(sql), sqlite3_finalize), _inx(0) { } database_binder(std::shared_ptr db, str_ref sql): _db(db), _stmt(_prepare(sql), sqlite3_finalize), _inx(0) { } ~database_binder() noexcept(false) { /* Will be executed if no >>op is found, but not if an exception is in mid flight */ if(!used() && !_has_uncaught_exception && _stmt) { execute(); } } friend class row_iterator; }; class row_iterator { public: class value_type { public: value_type(database_binder *_binder): _binder(_binder) {}; template typename std::enable_if::value, value_type &>::type operator >>(T &result) { result = get_col_from_db(_binder->_stmt.get(), next_index++, result_type()); return *this; } template value_type &operator >>(std::tuple& values) { values = handle_tuple::type...>>(std::index_sequence_for()); next_index += sizeof...(Types); return *this; } template value_type &operator >>(std::tuple&& values) { return *this >> values; } template operator std::tuple() { std::tuple value; *this >> value; return value; } explicit operator bool() { return sqlite3_column_count(_binder->_stmt.get()) >= next_index; } private: template Tuple handle_tuple(std::index_sequence) { return Tuple( get_col_from_db( _binder->_stmt.get(), next_index + Index, result_type::type>())...); } database_binder *_binder; int next_index = 0; }; using difference_type = std::ptrdiff_t; using pointer = value_type*; using reference = value_type&; using iterator_category = std::input_iterator_tag; row_iterator() = default; explicit row_iterator(database_binder &binder): _binder(&binder) { _binder->_next_index(); _binder->_inx = 0; _binder->used(true); ++*this; } reference operator*() const { return value;} pointer operator->() const { return std::addressof(**this); } row_iterator &operator++() { switch(int result = sqlite3_step(_binder->_stmt.get())) { case SQLITE_ROW: value = {_binder}; break; case SQLITE_DONE: _binder = nullptr; break; default: exceptions::throw_sqlite_error(result, _binder->sql(), sqlite3_errmsg(_binder->_db.get())); } return *this; } friend inline bool operator ==(const row_iterator &a, const row_iterator &b) { return a._binder == b._binder; } friend inline bool operator !=(const row_iterator &a, const row_iterator &b) { return !(a==b); } private: database_binder *_binder = nullptr; mutable value_type value{_binder}; // mutable, because `changing` the value is just reading it }; inline row_iterator database_binder::begin() { return row_iterator(*this); } inline row_iterator database_binder::end() { return row_iterator(); } namespace detail { template void _extract_single_value(database_binder &binder, Callback call_back) { auto iter = binder.begin(); if(iter == binder.end()) throw errors::no_rows("no rows to extract: exactly 1 row expected", binder.sql(), SQLITE_DONE); call_back(*iter); if(++iter != binder.end()) throw errors::more_rows("not all rows extracted", binder.sql(), SQLITE_ROW); } } inline void database_binder::execute() { for(auto &&row : *this) (void)row; } namespace detail { template using void_t = void; template struct sqlite_direct_result : std::false_type {}; template struct sqlite_direct_result< T, void_t() >> std::declval())> > : std::true_type {}; } template inline typename std::enable_if::value>::type operator>>(database_binder &binder, Result&& value) { detail::_extract_single_value(binder, [&value] (row_iterator::value_type &row) { row >> std::forward(value); }); } template inline typename std::enable_if::value>::type operator>>(database_binder &db_binder, Function&& func) { using traits = utility::function_traits; for(auto &&row : db_binder) { binder::run(row, func); } } template inline decltype(auto) operator>>(database_binder &&binder, Result&& value) { return binder >> std::forward(value); } namespace sql_function_binder { template< typename ContextType, std::size_t Count, typename Functions > inline void step( sqlite3_context* db, int count, sqlite3_value** vals ); template< std::size_t Count, typename Functions, typename... Values > inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( sqlite3_context* db, int count, sqlite3_value** vals, Values&&... values ); template< std::size_t Count, typename Functions, typename... Values > inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( sqlite3_context* db, int, sqlite3_value**, Values&&... values ); template< typename ContextType, typename Functions > inline void final(sqlite3_context* db); template< std::size_t Count, typename Function, typename... Values > inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( sqlite3_context* db, int count, sqlite3_value** vals, Values&&... values ); template< std::size_t Count, typename Function, typename... Values > inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( sqlite3_context* db, int, sqlite3_value**, Values&&... values ); } enum class OpenFlags { READONLY = SQLITE_OPEN_READONLY, READWRITE = SQLITE_OPEN_READWRITE, CREATE = SQLITE_OPEN_CREATE, NOMUTEX = SQLITE_OPEN_NOMUTEX, FULLMUTEX = SQLITE_OPEN_FULLMUTEX, SHAREDCACHE = SQLITE_OPEN_SHAREDCACHE, PRIVATECACH = SQLITE_OPEN_PRIVATECACHE, URI = SQLITE_OPEN_URI }; inline OpenFlags operator|(const OpenFlags& a, const OpenFlags& b) { return static_cast(static_cast(a) | static_cast(b)); } enum class Encoding { ANY = SQLITE_ANY, UTF8 = SQLITE_UTF8, UTF16 = SQLITE_UTF16 }; struct sqlite_config { OpenFlags flags = OpenFlags::READWRITE | OpenFlags::CREATE; const char *zVfs = nullptr; Encoding encoding = Encoding::ANY; }; class database { protected: std::shared_ptr _db; public: database(const std::string &db_name, const sqlite_config &config = {}): _db(nullptr) { sqlite3* tmp = nullptr; auto ret = sqlite3_open_v2(db_name.data(), &tmp, static_cast(config.flags), config.zVfs); _db = std::shared_ptr(tmp, [=](sqlite3* ptr) { sqlite3_close_v2(ptr); }); // this will close the connection eventually when no longer needed. if(ret != SQLITE_OK) errors::throw_sqlite_error(_db ? sqlite3_extended_errcode(_db.get()) : ret, {}, sqlite3_errmsg(_db.get())); sqlite3_extended_result_codes(_db.get(), true); if(config.encoding == Encoding::UTF16) *this << R"(PRAGMA encoding = "UTF-16";)"; } database(const std::u16string &db_name, const sqlite_config &config = {}): database(utility::utf16_to_utf8(db_name), config) { if (config.encoding == Encoding::ANY) *this << R"(PRAGMA encoding = "UTF-16";)"; } database(std::shared_ptr db): _db(db) {} database_binder operator<<(str_ref sql) { return database_binder(_db, sql); } database_binder operator<<(u16str_ref sql) { return database_binder(_db, sql); } connection_type connection() const { return _db; } sqlite3_int64 last_insert_rowid() const { return sqlite3_last_insert_rowid(_db.get()); } int rows_modified() const { return sqlite3_changes(_db.get()); } template void define(const std::string &name, Function&& func) { typedef utility::function_traits traits; auto funcPtr = new auto(std::forward(func)); if(int result = sqlite3_create_function_v2( _db.get(), name.data(), traits::arity, SQLITE_UTF8, funcPtr, sql_function_binder::scalar::type>, nullptr, nullptr, [](void* ptr){ delete static_cast(ptr); })) errors::throw_sqlite_error(result, {}, sqlite3_errmsg(_db.get())); } template void define(const std::string &name, StepFunction&& step, FinalFunction&& final) { typedef utility::function_traits traits; using ContextType = typename std::remove_reference>::type; auto funcPtr = new auto(std::make_pair(std::forward(step), std::forward(final))); if(int result = sqlite3_create_function_v2( _db.get(), name.c_str(), traits::arity - 1, SQLITE_UTF8, funcPtr, nullptr, sql_function_binder::step::type>, sql_function_binder::final::type>, [](void* ptr){ delete static_cast(ptr); })) errors::throw_sqlite_error(result, {}, sqlite3_errmsg(_db.get())); } }; template class binder { private: template < typename Function, std::size_t Index > using nth_argument_type = typename utility::function_traits< Function >::template argument; public: // `Boundary` needs to be defaulted to `Count` so that the `run` function // template is not implicitly instantiated on class template instantiation. // Look up section 14.7.1 _Implicit instantiation_ of the ISO C++14 Standard // and the [dicussion](https://github.com/aminroosta/sqlite_modern_cpp/issues/8) // on Github. template< typename Function, typename... Values, std::size_t Boundary = Count > static typename std::enable_if<(sizeof...(Values) < Boundary), void>::type run( row_iterator::value_type& row, Function&& function, Values&&... values ) { typename std::decay>::type value; row >> value; run(row, function, std::forward(values)..., std::move(value)); } template< typename Function, typename... Values, std::size_t Boundary = Count > static typename std::enable_if<(sizeof...(Values) == Boundary), void>::type run( row_iterator::value_type&, Function&& function, Values&&... values ) { function(std::move(values)...); } }; // Some ppl are lazy so we have a operator for proper prep. statemant handling. void inline operator++(database_binder& db, int) { db.execute(); } template database_binder &operator<<(database_binder& db, index_binding_helper val) { db._next_index(); --db._inx; int result = bind_col_in_db(db._stmt.get(), val.index, std::forward(val.value)); if(result != SQLITE_OK) exceptions::throw_sqlite_error(result, db.sql(), sqlite3_errmsg(db._db.get())); return db; } template database_binder &operator<<(database_binder& db, index_binding_helper val) { db._next_index(); --db._inx; int index = sqlite3_bind_parameter_index(db._stmt.get(), val.index); if(!index) throw errors::unknown_binding("The given binding name is not valid for this statement", db.sql()); int result = bind_col_in_db(db._stmt.get(), index, std::forward(val.value)); if(result != SQLITE_OK) exceptions::throw_sqlite_error(result, db.sql(), sqlite3_errmsg(db._db.get())); return db; } template database_binder &operator<<(database_binder& db, T&& val) { int result = bind_col_in_db(db._stmt.get(), db._next_index(), std::forward(val)); if(result != SQLITE_OK) exceptions::throw_sqlite_error(result, db.sql(), sqlite3_errmsg(db._db.get())); return db; } // Convert the rValue binder to a reference and call first op<<, its needed for the call that creates the binder (be carefull of recursion here!) template database_binder operator << (database_binder&& db, const T& val) { db << val; return std::move(db); } template database_binder operator << (database_binder&& db, index_binding_helper val) { db << index_binding_helper{val.index, std::forward(val.value)}; return std::move(db); } namespace sql_function_binder { template struct AggregateCtxt { T obj; bool constructed = true; }; template< typename ContextType, std::size_t Count, typename Functions > inline void step( sqlite3_context* db, int count, sqlite3_value** vals ) { auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); if(!ctxt) return; try { if(!ctxt->constructed) new(ctxt) AggregateCtxt(); step(db, count, vals, ctxt->obj); return; } catch(const sqlite_exception &e) { sqlite3_result_error_code(db, e.get_code()); sqlite3_result_error(db, e.what(), -1); } catch(const std::exception &e) { sqlite3_result_error(db, e.what(), -1); } catch(...) { sqlite3_result_error(db, "Unknown error", -1); } if(ctxt && ctxt->constructed) ctxt->~AggregateCtxt(); } template< std::size_t Count, typename Functions, typename... Values > inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( sqlite3_context* db, int count, sqlite3_value** vals, Values&&... values ) { using arg_type = typename std::remove_cv< typename std::remove_reference< typename utility::function_traits< typename Functions::first_type >::template argument >::type >::type; step( db, count, vals, std::forward(values)..., get_val_from_db(vals[sizeof...(Values) - 1], result_type())); } template< std::size_t Count, typename Functions, typename... Values > inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( sqlite3_context* db, int, sqlite3_value**, Values&&... values ) { static_cast(sqlite3_user_data(db))->first(std::forward(values)...); } template< typename ContextType, typename Functions > inline void final(sqlite3_context* db) { auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); try { if(!ctxt) return; if(!ctxt->constructed) new(ctxt) AggregateCtxt(); store_result_in_db(db, static_cast(sqlite3_user_data(db))->second(ctxt->obj)); } catch(const sqlite_exception &e) { sqlite3_result_error_code(db, e.get_code()); sqlite3_result_error(db, e.what(), -1); } catch(const std::exception &e) { sqlite3_result_error(db, e.what(), -1); } catch(...) { sqlite3_result_error(db, "Unknown error", -1); } if(ctxt && ctxt->constructed) ctxt->~AggregateCtxt(); } template< std::size_t Count, typename Function, typename... Values > inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( sqlite3_context* db, int count, sqlite3_value** vals, Values&&... values ) { using arg_type = typename std::remove_cv< typename std::remove_reference< typename utility::function_traits::template argument >::type >::type; scalar( db, count, vals, std::forward(values)..., get_val_from_db(vals[sizeof...(Values)], result_type())); } template< std::size_t Count, typename Function, typename... Values > inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( sqlite3_context* db, int, sqlite3_value**, Values&&... values ) { try { store_result_in_db(db, (*static_cast(sqlite3_user_data(db)))(std::forward(values)...)); } catch(const sqlite_exception &e) { sqlite3_result_error_code(db, e.get_code()); sqlite3_result_error(db, e.what(), -1); } catch(const std::exception &e) { sqlite3_result_error(db, e.what(), -1); } catch(...) { sqlite3_result_error(db, "Unknown error", -1); } } } }