/***************************************************************************** * Media Library ***************************************************************************** * Copyright (C) 2015 Hugo Beauzée-Luyssen, Videolabs * * Authors: Hugo Beauzée-Luyssen * * This program is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as published by * the Free Software Foundation; either version 2.1 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public License * along with this program; if not, write to the Free Software Foundation, * Inc., 51 Franklin Street, Fifth Floor, Boston MA 02110-1301, USA. *****************************************************************************/ #ifndef SQLITETOOLS_H #define SQLITETOOLS_H #include #include #include #include #include #include #include #include #include "Types.h" #include "database/SqliteConnection.h" #include "logging/Logger.h" namespace sqlite { struct ForeignKey { constexpr explicit ForeignKey(unsigned int v) : value(v) {} unsigned int value; }; template using IsSameDecay = std::is_same::type, T>; template struct Traits; template struct Traits::type>::value && ! IsSameDecay::value >::type> { static constexpr int (*Bind)(sqlite3_stmt *, int, int) = &sqlite3_bind_int; static constexpr int (*Load)(sqlite3_stmt *, int) = &sqlite3_column_int; }; template struct Traits::value>::type> { static int Bind( sqlite3_stmt *stmt, int pos, ForeignKey fk) { if ( fk.value != 0 ) return Traits::Bind( stmt, pos, fk.value ); return sqlite3_bind_null( stmt, pos ); } }; template struct Traits::value>::type> { static int Bind(sqlite3_stmt* stmt, int pos, const std::string& value ) { return sqlite3_bind_text( stmt, pos, value.c_str(), -1, SQLITE_STATIC ); } static std::string Load( sqlite3_stmt* stmt, int pos ) { auto tmp = (const char*)sqlite3_column_text( stmt, pos ); if ( tmp != nullptr ) return std::string( tmp ); return std::string(); } }; template struct Traits::type >::value>::type> { static constexpr int (*Bind)(sqlite3_stmt *, int, double) = &sqlite3_bind_double; static constexpr double (*Load)(sqlite3_stmt *, int) = &sqlite3_column_double; }; template <> struct Traits { static int Bind(sqlite3_stmt* stmt, int idx, std::nullptr_t) { return sqlite3_bind_null( stmt, idx ); } }; template struct Traits::type >::value>::type> { using type_t = typename std::underlying_type::type>::type; static int Bind(sqlite3_stmt* stmt, int pos, T value ) { return sqlite3_bind_int( stmt, pos, static_cast( value ) ); } static T Load( sqlite3_stmt* stmt, int pos ) { return static_cast( sqlite3_column_int( stmt, pos ) ); } }; template struct Traits::value>::type> { static constexpr int (*Bind)(sqlite3_stmt *, int, sqlite_int64) = &sqlite3_bind_int64; static constexpr sqlite_int64 (*Load)(sqlite3_stmt *, int) = &sqlite3_column_int64; }; namespace errors { class ConstraintViolation : public std::exception { public: ConstraintViolation( const std::string& req, const std::string& err ) { m_reason = std::string( "Request <" ) + req + "> aborted due to " "constraint violation (" + err + ")"; } virtual const char* what() const noexcept override { return m_reason.c_str(); } private: std::string m_reason; }; class ColumnOutOfRange : public std::exception { public: ColumnOutOfRange( unsigned int idx, unsigned int nbColumns ) { m_reason = "Attempting to extract column at index " + std::to_string( idx ) + " from a request with " + std::to_string( nbColumns ) + "columns"; } virtual const char* what() const noexcept override { return m_reason.c_str(); } private: std::string m_reason; }; } class Row { public: Row( sqlite3_stmt* stmt ) : m_stmt( stmt ) , m_idx( 0 ) , m_nbColumns( sqlite3_column_count( stmt ) ) { } Row() : m_stmt( nullptr ) , m_idx( 0 ) { } /** * @brief operator >> Extracts the next column from this result row. */ template Row& operator>>(T& t) { if ( m_idx + 1 > m_nbColumns ) throw errors::ColumnOutOfRange( m_idx, m_nbColumns ); t = sqlite::Traits::Load( m_stmt, m_idx ); m_idx++; return *this; } /** * @brief Returns the value in column idx, but doesn't advance to the next column */ template T load(unsigned int idx) { if ( m_idx + 1 > m_nbColumns ) throw errors::ColumnOutOfRange( m_idx, m_nbColumns ); return sqlite::Traits::Load( m_stmt, idx ); } bool operator==(std::nullptr_t) { return m_stmt == nullptr; } bool operator!=(std::nullptr_t) { return m_stmt != nullptr; } private: sqlite3_stmt* m_stmt; unsigned int m_idx; unsigned int m_nbColumns; }; class Statement { public: Statement( DBConnection dbConnection, const std::string& req ) : m_stmt( nullptr, &sqlite3_finalize ) , m_dbConn( dbConnection ) , m_req( req ) , m_bindIdx( 0 ) { sqlite3_stmt* stmt; int res = sqlite3_prepare_v2( dbConnection->getConn(), req.c_str(), -1, &stmt, NULL ); if ( res != SQLITE_OK ) { throw std::runtime_error( std::string( "Failed to execute request: " ) + req + " " + sqlite3_errmsg( dbConnection->getConn() ) ); } m_stmt.reset( stmt ); } template void execute(Args&&... args) { m_bindIdx = 1; (void)std::initializer_list{ _bind( std::forward( args ) )... }; } Row row() { auto res = sqlite3_step( m_stmt.get() ); if ( res == SQLITE_ROW ) return Row( m_stmt.get() ); else if ( res == SQLITE_DONE ) return Row(); else { std::string errMsg = sqlite3_errmsg( m_dbConn->getConn() ); switch ( res ) { case SQLITE_CONSTRAINT: throw errors::ConstraintViolation( m_req, errMsg ); default: throw std::runtime_error( errMsg ); } } } private: template bool _bind( T&& value ) { auto res = Traits::Bind( m_stmt.get(), m_bindIdx, std::forward( value ) ); if ( res != SQLITE_OK ) throw std::runtime_error( "Failed to bind parameter" ); m_bindIdx++; return true; } private: std::unique_ptr m_stmt; DBConnection m_dbConn; std::string m_req; unsigned int m_bindIdx; }; class Tools { public: /** * Will fetch all records of type IMPL and return them as a shared_ptr to INTF * This WILL add all fetched records to the cache * * @param results A reference to the result vector. All existing elements will * be discarded. */ template static std::vector > fetchAll( DBConnection dbConnection, const std::string& req, Args&&... args ) { auto ctx = dbConnection->acquireContext(); auto chrono = std::chrono::steady_clock::now(); std::vector> results; auto stmt = Statement( dbConnection, req ); stmt.execute( std::forward( args )... ); Row sqliteRow; while ( ( sqliteRow = stmt.row() ) != nullptr ) { auto row = IMPL::load( dbConnection, sqliteRow ); results.push_back( row ); } auto duration = std::chrono::steady_clock::now() - chrono; LOG_DEBUG("Executed ", req, " in ", std::chrono::duration_cast( duration ).count(), "µs" ); return results; } template static std::shared_ptr fetchOne( DBConnection dbConnection, const std::string& req, Args&&... args ) { auto ctx = dbConnection->acquireContext(); auto chrono = std::chrono::steady_clock::now(); auto stmt = Statement( dbConnection, req ); stmt.execute( std::forward( args )... ); auto row = stmt.row(); if ( row == nullptr ) return nullptr; auto res = T::load( dbConnection, row ); auto duration = std::chrono::steady_clock::now() - chrono; LOG_DEBUG("Executed ", req, " in ", std::chrono::duration_cast( duration ).count(), "µs" ); return res; } template static bool executeRequest( DBConnection dbConnection, const std::string& req, Args&&... args ) { auto ctx = dbConnection->acquireContext(); return executeRequestLocked( dbConnection, req, std::forward( args )... ); } template static bool executeDelete( DBConnection dbConnection, const std::string& req, Args&&... args ) { auto ctx = dbConnection->acquireContext(); if ( executeRequestLocked( dbConnection, req, std::forward( args )... ) == false ) return false; return sqlite3_changes( dbConnection->getConn() ) > 0; } template static bool executeUpdate( DBConnection dbConnectionWeak, const std::string& req, Args&&... args ) { // The code would be exactly the same, do not freak out because it calls executeDelete :) return executeDelete( dbConnectionWeak, req, std::forward( args )... ); } /** * Inserts a record to the DB and return the newly created primary key. * Returns 0 (which is an invalid sqlite primary key) when insertion fails. */ template static unsigned int insert( DBConnection dbConnection, const std::string& req, Args&&... args ) { auto ctx = dbConnection->acquireContext(); if ( executeRequestLocked( dbConnection, req, std::forward( args )... ) == false ) return 0; return sqlite3_last_insert_rowid( dbConnection->getConn() ); } private: template static bool executeRequestLocked( DBConnection dbConnection, const std::string& req, Args&&... args ) { auto chrono = std::chrono::steady_clock::now(); auto stmt = Statement( dbConnection, req ); stmt.execute( std::forward( args )... ); while ( stmt.row() != nullptr ) ; auto duration = std::chrono::steady_clock::now() - chrono; LOG_DEBUG("Executed ", req, " in ", std::chrono::duration_cast( duration ).count(), "µs" ); return true; } // Let SqliteConnection access executeRequestLocked friend SqliteConnection; }; } #endif // SQLITETOOLS_H