diff --git a/src/storage/backend/sqlite/SQLite.cc b/src/storage/backend/sqlite/SQLite.cc index 8c593a69b8..79b9328981 100644 --- a/src/storage/backend/sqlite/SQLite.cc +++ b/src/storage/backend/sqlite/SQLite.cc @@ -45,16 +45,33 @@ ErrorResult SQLite::DoOpen(RecordValPtr options) { create.append("key_str text primary key, value_str text not null);"); char* errorMsg = nullptr; - int res = sqlite3_exec(db, create.c_str(), NULL, NULL, &errorMsg); - if ( res != SQLITE_OK ) { + if ( int res = sqlite3_exec(db, create.c_str(), NULL, NULL, &errorMsg); res != SQLITE_OK ) { std::string err = util::fmt("Error executing table creation statement: %s", errorMsg); Error(err.c_str()); sqlite3_free(errorMsg); - sqlite3_close(db); - db = nullptr; + Close(); return err; } + static std::map statements = + {{"put", util::fmt("insert into %s (key_str, value_str) values(?, ?)", table_name.c_str())}, + {"put_update", util::fmt("insert into %s (key_str, value_str) values(?, ?) ON CONFLICT(key_str) " + "DO UPDATE SET value_str=?", + table_name.c_str())}, + {"get", util::fmt("select value_str from %s where key_str=?", table_name.c_str())}, + {"erase", util::fmt("delete from %s where key_str=?", table_name.c_str())}}; + + for ( const auto& [key, stmt] : statements ) { + sqlite3_stmt* ps; + if ( auto prep_res = checkError(sqlite3_prepare_v2(db, stmt.c_str(), stmt.size(), &ps, NULL)); + prep_res.has_value() ) { + Close(); + return prep_res; + } + + prepared_stmts.insert({key, ps}); + } + return std::nullopt; } @@ -63,6 +80,12 @@ ErrorResult SQLite::DoOpen(RecordValPtr options) { */ void SQLite::Close() { if ( db ) { + for ( const auto& [k, stmt] : prepared_stmts ) { + sqlite3_finalize(stmt); + } + + prepared_stmts.clear(); + if ( int res = sqlite3_close_v2(db); res != SQLITE_OK ) Error("Sqlite could not close connection"); @@ -80,27 +103,41 @@ ErrorResult SQLite::DoPut(ValPtr key, ValPtr value, bool overwrite, double expir auto json_key = key->ToJSON(); auto json_value = value->ToJSON(); - std::string stmt = "INSERT INTO "; - stmt.append(table_name); - stmt.append("(key_str, value_str) VALUES('"); - stmt.append(json_key->ToStdStringView()); - stmt.append("', '"); - stmt.append(json_value->ToStdStringView()); + sqlite3_stmt* stmt; if ( ! overwrite ) - stmt.append("');"); - else { - // if overwriting, add an UPSERT conflict resolution block - stmt.append("') ON CONFLICT(key_str) DO UPDATE SET value_str='"); - stmt.append(json_value->ToStdStringView()); - stmt.append("';"); + stmt = prepared_stmts["put"]; + else + stmt = prepared_stmts["put_update"]; + + auto key_str = json_key->ToStdStringView(); + if ( auto res = checkError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); + res.has_value() ) { + sqlite3_reset(stmt); + return res; } - char* errorMsg = nullptr; - int res = sqlite3_exec(db, stmt.c_str(), NULL, NULL, &errorMsg); - if ( res != SQLITE_OK ) { - return errorMsg; + auto value_str = json_value->ToStdStringView(); + if ( auto res = checkError(sqlite3_bind_text(stmt, 2, value_str.data(), value_str.size(), SQLITE_STATIC)); + res.has_value() ) { + sqlite3_reset(stmt); + return res; } + if ( overwrite ) { + if ( auto res = checkError(sqlite3_bind_text(stmt, 3, value_str.data(), value_str.size(), SQLITE_STATIC)); + res.has_value() ) { + sqlite3_reset(stmt); + return res; + } + } + + if ( auto res = checkError(sqlite3_step(stmt)); res.has_value() ) { + sqlite3_reset(stmt); + return res; + } + + sqlite3_reset(stmt); + return std::nullopt; } @@ -112,22 +149,21 @@ ValResult SQLite::DoGet(ValPtr key, ValResultCallback* cb) { return zeek::unexpected("Database was not open"); auto json_key = key->ToJSON(); + auto stmt = prepared_stmts["get"]; - std::string stmt = "SELECT value_str from " + table_name + " where key_str = '"; - stmt.append(json_key->ToStdStringView()); - stmt.append("';"); + auto key_str = json_key->ToStdStringView(); + if ( auto res = checkError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); + res.has_value() ) { + sqlite3_reset(stmt); + return nonstd::unexpected(res.value()); + } - char* errorMsg = nullptr; - sqlite3_stmt* st; - auto res = checkError(sqlite3_prepare_v2(db, stmt.c_str(), static_cast(stmt.size() + 1), &st, NULL)); - if ( res.has_value() ) - return zeek::unexpected(util::fmt("Failed to prepare select statement: %s", res.value().c_str())); - - int errorcode = sqlite3_step(st); + int errorcode = sqlite3_step(stmt); if ( errorcode == SQLITE_ROW ) { // Column 1 is the value - const char* text = (const char*)sqlite3_column_text(st, 0); + const char* text = (const char*)sqlite3_column_text(stmt, 0); auto val = zeek::detail::ValFromJSON(text, val_type, Func::nil); + sqlite3_reset(stmt); if ( std::holds_alternative(val) ) { ValPtr val_v = std::get(val); return val_v; @@ -148,15 +184,17 @@ ErrorResult SQLite::DoErase(ValPtr key, ErrorResultCallback* cb) { return "Database was not open"; auto json_key = key->ToJSON(); + auto stmt = prepared_stmts["erase"]; - std::string stmt = "DELETE from " + table_name + " where key_str = \'"; - stmt.append(json_key->ToStdStringView()); - stmt.append("\'"); + auto key_str = json_key->ToStdStringView(); + if ( auto res = checkError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); + res.has_value() ) { + sqlite3_reset(stmt); + return res; + } - char* errorMsg = nullptr; - int res = sqlite3_exec(db, stmt.c_str(), NULL, NULL, &errorMsg); - if ( res != SQLITE_OK ) { - return errorMsg; + if ( auto res = checkError(sqlite3_step(stmt)); res.has_value() ) { + return res; } return std::nullopt; @@ -165,9 +203,7 @@ ErrorResult SQLite::DoErase(ValPtr key, ErrorResultCallback* cb) { // returns true in case of error ErrorResult SQLite::checkError(int code) { if ( code != SQLITE_OK && code != SQLITE_DONE ) { - std::string msg = util::fmt("SQLite call failed: %s", sqlite3_errmsg(db)); - Error(msg.c_str()); - return msg; + return util::fmt("SQLite call failed: %s", sqlite3_errmsg(db)); } return std::nullopt; diff --git a/src/storage/backend/sqlite/SQLite.h b/src/storage/backend/sqlite/SQLite.h index 4aca4e1439..7899d7108b 100644 --- a/src/storage/backend/sqlite/SQLite.h +++ b/src/storage/backend/sqlite/SQLite.h @@ -54,6 +54,8 @@ private: ErrorResult checkError(int code); sqlite3* db = nullptr; + std::unordered_map prepared_stmts; + std::string full_path; std::string table_name; }; diff --git a/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-error-handling/.stderr b/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-error-handling/.stderr index 6ed2d3c5d8..4cb5a5b2d9 100644 --- a/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-error-handling/.stderr +++ b/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-error-handling/.stderr @@ -1,3 +1,2 @@ ### BTest baseline data generated by btest-diff. Do not edit. Use "btest -U/-u" to update. Requires BTest >= 0.63. -error in : SQLite call failed: unable to open database file () error in <...>/sqlite-error-handling.zeek, line 20: Failed to open backend SQLITE: SQLite call failed: unable to open database file (Storage::open_backend(Storage::SQLITE, to_any_coerce opts, to_any_coerce str, to_any_coerce str))