diff --git a/src/storage/backend/sqlite/SQLite.cc b/src/storage/backend/sqlite/SQLite.cc index e1d5929d17..18f96870ab 100644 --- a/src/storage/backend/sqlite/SQLite.cc +++ b/src/storage/backend/sqlite/SQLite.cc @@ -78,18 +78,19 @@ OperationResult SQLite::DoOpen(RecordValPtr options, OpenResultCallback* cb) { } } - static std::map statements = - {{"put", util::fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?)", table_name.c_str())}, - {"put_update", - util::fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?) ON CONFLICT(key_str) " - "DO UPDATE SET value_str=?", - table_name.c_str())}, - {"get", util::fmt("select value_str from %s where key_str=?", table_name.c_str())}, - {"erase", util::fmt("delete from %s where key_str=?", table_name.c_str())}, - {"expire", util::fmt("delete from %s where expire_time > 0 and expire_time != 0 and expire_time <= ?", - table_name.c_str())}}; + static std::array statements = + {util::fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?)", table_name.c_str()), + util::fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?) ON CONFLICT(key_str) " + "DO UPDATE SET value_str=?", + table_name.c_str()), + util::fmt("select value_str from %s where key_str=?", table_name.c_str()), + util::fmt("delete from %s where key_str=?", table_name.c_str()), + util::fmt("delete from %s where expire_time > 0 and expire_time != 0 and expire_time <= ?", + table_name.c_str())}; - for ( const auto& [key, stmt] : statements ) { + std::array stmt_ptrs; + int i = 0; + for ( const auto& stmt : statements ) { sqlite3_stmt* ps; if ( auto prep_res = CheckError(sqlite3_prepare_v2(db, stmt.c_str(), stmt.size(), &ps, NULL)); prep_res.code != ReturnCode::SUCCESS ) { @@ -97,9 +98,15 @@ OperationResult SQLite::DoOpen(RecordValPtr options, OpenResultCallback* cb) { return prep_res; } - prepared_stmts.insert({key, ps}); + stmt_ptrs[i++] = unique_stmt_ptr(ps, [](sqlite3_stmt* stmt) { sqlite3_finalize(stmt); }); } + put_stmt = std::move(stmt_ptrs[0]); + put_update_stmt = std::move(stmt_ptrs[1]); + get_stmt = std::move(stmt_ptrs[2]); + erase_stmt = std::move(stmt_ptrs[3]); + expire_stmt = std::move(stmt_ptrs[4]); + sqlite3_busy_timeout(db, 5000); return {ReturnCode::SUCCESS}; @@ -112,11 +119,11 @@ OperationResult SQLite::DoClose(OperationResultCallback* cb) { OperationResult op_res{ReturnCode::SUCCESS}; if ( db ) { - for ( const auto& [k, stmt] : prepared_stmts ) { - sqlite3_finalize(stmt); - } - - prepared_stmts.clear(); + put_stmt.reset(); + put_update_stmt.reset(); + get_stmt.reset(); + erase_stmt.reset(); + expire_stmt.reset(); char* errmsg; if ( int res = sqlite3_exec(db, "pragma optimize", NULL, NULL, &errmsg); res != SQLITE_OK ) { @@ -149,9 +156,9 @@ OperationResult SQLite::DoPut(ValPtr key, ValPtr value, bool overwrite, double e sqlite3_stmt* stmt; if ( ! overwrite ) - stmt = prepared_stmts["put"]; + stmt = put_stmt.get(); else - stmt = prepared_stmts["put_update"]; + stmt = put_update_stmt.get(); auto key_str = json_key->ToStdStringView(); if ( auto res = CheckError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); @@ -191,7 +198,7 @@ OperationResult SQLite::DoGet(ValPtr key, OperationResultCallback* cb) { return {ReturnCode::NOT_CONNECTED}; auto json_key = key->ToJSON(); - auto stmt = prepared_stmts["get"]; + auto stmt = get_stmt.get(); auto key_str = json_key->ToStdStringView(); if ( auto res = CheckError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); @@ -211,7 +218,7 @@ OperationResult SQLite::DoErase(ValPtr key, OperationResultCallback* cb) { return {ReturnCode::NOT_CONNECTED}; auto json_key = key->ToJSON(); - auto stmt = prepared_stmts["erase"]; + auto stmt = erase_stmt.get(); auto key_str = json_key->ToStdStringView(); if ( auto res = CheckError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); @@ -228,7 +235,7 @@ OperationResult SQLite::DoErase(ValPtr key, OperationResultCallback* cb) { * derived classes. */ void SQLite::DoExpire() { - auto stmt = prepared_stmts["expire"]; + auto stmt = expire_stmt.get(); if ( auto res = CheckError(sqlite3_bind_double(stmt, 1, run_state::network_time)); res.code != ReturnCode::SUCCESS ) { diff --git a/src/storage/backend/sqlite/SQLite.h b/src/storage/backend/sqlite/SQLite.h index d4929bfa3e..bbef996215 100644 --- a/src/storage/backend/sqlite/SQLite.h +++ b/src/storage/backend/sqlite/SQLite.h @@ -46,7 +46,14 @@ private: OperationResult Step(sqlite3_stmt* stmt, bool parse_value = false); sqlite3* db = nullptr; - std::unordered_map prepared_stmts; + + using stmt_deleter = std::function; + using unique_stmt_ptr = std::unique_ptr; + unique_stmt_ptr put_stmt; + unique_stmt_ptr put_update_stmt; + unique_stmt_ptr get_stmt; + unique_stmt_ptr erase_stmt; + unique_stmt_ptr expire_stmt; std::string full_path; std::string table_name;