diff --git a/src/storage/backend/sqlite/SQLite.cc b/src/storage/backend/sqlite/SQLite.cc index ddac358994..50bfe52573 100644 --- a/src/storage/backend/sqlite/SQLite.cc +++ b/src/storage/backend/sqlite/SQLite.cc @@ -35,7 +35,7 @@ OperationResult SQLite::DoOpen(RecordValPtr options, OpenResultCallback* cb) { table_name = backend_options->GetField("table_name")->ToStdString(); if ( auto open_res = - checkError(sqlite3_open_v2(full_path.c_str(), &db, + CheckError(sqlite3_open_v2(full_path.c_str(), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, NULL)); open_res.code != ReturnCode::SUCCESS ) { sqlite3_close_v2(db); @@ -91,7 +91,7 @@ OperationResult SQLite::DoOpen(RecordValPtr options, OpenResultCallback* cb) { for ( const auto& [key, stmt] : statements ) { sqlite3_stmt* ps; - if ( auto prep_res = checkError(sqlite3_prepare_v2(db, stmt.c_str(), stmt.size(), &ps, NULL)); + if ( auto prep_res = CheckError(sqlite3_prepare_v2(db, stmt.c_str(), stmt.size(), &ps, NULL)); prep_res.code != ReturnCode::SUCCESS ) { Close(); return prep_res; @@ -100,6 +100,8 @@ OperationResult SQLite::DoOpen(RecordValPtr options, OpenResultCallback* cb) { prepared_stmts.insert({key, ps}); } + sqlite3_busy_timeout(db, 5000); + return {ReturnCode::SUCCESS}; } @@ -152,40 +154,33 @@ OperationResult SQLite::DoPut(ValPtr key, ValPtr value, bool overwrite, double e stmt = prepared_stmts["put_update"]; auto key_str = json_key->ToStdStringView(); - if ( auto res = checkError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); + if ( auto res = CheckError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); res.code != ReturnCode::SUCCESS ) { sqlite3_reset(stmt); return res; } auto value_str = json_value->ToStdStringView(); - if ( auto res = checkError(sqlite3_bind_text(stmt, 2, value_str.data(), value_str.size(), SQLITE_STATIC)); + if ( auto res = CheckError(sqlite3_bind_text(stmt, 2, value_str.data(), value_str.size(), SQLITE_STATIC)); res.code != ReturnCode::SUCCESS ) { sqlite3_reset(stmt); return res; } - if ( auto res = checkError(sqlite3_bind_double(stmt, 3, expiration_time)); res.code != ReturnCode::SUCCESS ) { + if ( auto res = CheckError(sqlite3_bind_double(stmt, 3, expiration_time)); res.code != ReturnCode::SUCCESS ) { sqlite3_reset(stmt); return res; } if ( overwrite ) { - if ( auto res = checkError(sqlite3_bind_text(stmt, 4, value_str.data(), value_str.size(), SQLITE_STATIC)); + if ( auto res = CheckError(sqlite3_bind_text(stmt, 4, value_str.data(), value_str.size(), SQLITE_STATIC)); res.code != ReturnCode::SUCCESS ) { sqlite3_reset(stmt); return res; } } - if ( auto res = checkError(sqlite3_step(stmt)); res.code != ReturnCode::SUCCESS ) { - sqlite3_reset(stmt); - return res; - } - - sqlite3_reset(stmt); - - return {ReturnCode::SUCCESS}; + return Step(stmt, false); } /** @@ -199,28 +194,13 @@ OperationResult SQLite::DoGet(ValPtr key, OperationResultCallback* cb) { auto stmt = prepared_stmts["get"]; auto key_str = json_key->ToStdStringView(); - if ( auto res = checkError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); + if ( auto res = CheckError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); res.code != ReturnCode::SUCCESS ) { sqlite3_reset(stmt); return res; } - int errorcode = sqlite3_step(stmt); - if ( errorcode == SQLITE_ROW ) { - // Column 1 is the value - const char* text = (const char*)sqlite3_column_text(stmt, 0); - auto val = zeek::detail::ValFromJSON(text, val_type, Func::nil); - sqlite3_reset(stmt); - if ( std::holds_alternative(val) ) { - ValPtr val_v = std::get(val); - return {ReturnCode::SUCCESS, "", val_v}; - } - else { - return {ReturnCode::OPERATION_FAILED, std::get(val)}; - } - } - - return {ReturnCode::KEY_NOT_FOUND}; + return Step(stmt, true); } /** @@ -234,17 +214,13 @@ OperationResult SQLite::DoErase(ValPtr key, OperationResultCallback* cb) { auto stmt = prepared_stmts["erase"]; auto key_str = json_key->ToStdStringView(); - if ( auto res = checkError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); + if ( auto res = CheckError(sqlite3_bind_text(stmt, 1, key_str.data(), key_str.size(), SQLITE_STATIC)); res.code != ReturnCode::SUCCESS ) { sqlite3_reset(stmt); return res; } - if ( auto res = checkError(sqlite3_step(stmt)); res.code != ReturnCode::SUCCESS ) { - return res; - } - - return {ReturnCode::SUCCESS}; + return Step(stmt, false); } /** @@ -254,19 +230,17 @@ OperationResult SQLite::DoErase(ValPtr key, OperationResultCallback* cb) { void SQLite::Expire() { auto stmt = prepared_stmts["expire"]; - if ( auto res = checkError(sqlite3_bind_double(stmt, 1, run_state::network_time)); + if ( auto res = CheckError(sqlite3_bind_double(stmt, 1, run_state::network_time)); res.code != ReturnCode::SUCCESS ) { sqlite3_reset(stmt); // TODO: do something with the error here? } - if ( auto res = checkError(sqlite3_step(stmt)); res.code != ReturnCode::SUCCESS ) { - // TODO: do something with the error here? - } + Step(stmt, false); } // returns true in case of error -OperationResult SQLite::checkError(int code) { +OperationResult SQLite::CheckError(int code) { if ( code != SQLITE_OK && code != SQLITE_DONE ) { return {ReturnCode::OPERATION_FAILED, util::fmt("SQLite call failed: %s", sqlite3_errmsg(db)), nullptr}; } @@ -274,4 +248,43 @@ OperationResult SQLite::checkError(int code) { return {ReturnCode::SUCCESS}; } +OperationResult SQLite::Step(sqlite3_stmt* stmt, bool parse_value) { + OperationResult ret; + + int step_status = sqlite3_step(stmt); + if ( step_status == SQLITE_ROW ) { + if ( parse_value ) { + // Column 1 is the value + const char* text = (const char*)sqlite3_column_text(stmt, 0); + auto val = zeek::detail::ValFromJSON(text, val_type, Func::nil); + sqlite3_reset(stmt); + if ( std::holds_alternative(val) ) { + ValPtr val_v = std::get(val); + ret = {ReturnCode::SUCCESS, "", val_v}; + } + else { + ret = {ReturnCode::OPERATION_FAILED, std::get(val)}; + } + } + else { + ret = {ReturnCode::OPERATION_FAILED, "sqlite3_step should not have returned a value"}; + } + } + else if ( step_status == SQLITE_DONE ) { + if ( parse_value ) + ret = {ReturnCode::KEY_NOT_FOUND}; + else + ret = {ReturnCode::SUCCESS}; + } + else if ( step_status == SQLITE_BUSY ) + // TODO: this could retry a number of times instead of just failing + ret = {ReturnCode::TIMEOUT}; + else + ret = {ReturnCode::OPERATION_FAILED}; + + sqlite3_reset(stmt); + + return ret; +} + } // namespace zeek::storage::backend::sqlite diff --git a/src/storage/backend/sqlite/SQLite.h b/src/storage/backend/sqlite/SQLite.h index 17e3e9716d..9f0e613de9 100644 --- a/src/storage/backend/sqlite/SQLite.h +++ b/src/storage/backend/sqlite/SQLite.h @@ -55,7 +55,8 @@ public: void Expire() override; private: - OperationResult checkError(int code); + OperationResult CheckError(int code); + OperationResult Step(sqlite3_stmt* stmt, bool parse_value = false); sqlite3* db = nullptr; std::unordered_map prepared_stmts;