diff --git a/src/storage/backend/sqlite/SQLite.cc b/src/storage/backend/sqlite/SQLite.cc index b97ff07331..4deb565fd6 100644 --- a/src/storage/backend/sqlite/SQLite.cc +++ b/src/storage/backend/sqlite/SQLite.cc @@ -17,11 +17,6 @@ using namespace std::chrono_literals; -namespace { -// Helper to check whether a database-stored `expire_time` is considered expired. -bool is_expired(double expire_time) { return expire_time != 0 && expire_time < zeek::run_state::network_time; } -} // namespace - namespace zeek::storage::backend::sqlite { OperationResult SQLite::RunPragma(std::string_view name, std::optional value) { @@ -196,7 +191,10 @@ OperationResult SQLite::DoOpen(OpenResultCallback* cb, RecordValPtr options) { static std::array, 8> statements = {std::make_pair(util::fmt(put_base_cmd.c_str(), table_name.c_str()), db), std::make_pair(util::fmt(put_cmd.c_str(), table_name.c_str()), db), - std::make_pair(util::fmt("select value_str, expire_time from %s where key_str=?", table_name.c_str()), db), + std::make_pair(util::fmt("select value_str, expire_time from %s where key_str=? and ((expire_time > ?) OR " + "(expire_time IS NOT NULL AND expire_time == 0.0))", + table_name.c_str()), + db), std::make_pair(util::fmt("delete from %s where key_str=?", table_name.c_str()), db), std::make_pair( @@ -297,10 +295,10 @@ OperationResult SQLite::DoPut(ResultCallback* cb, ValPtr key, ValPtr value, bool return {ReturnCode::SERIALIZATION_FAILED, "Failed to serialize key"}; unique_stmt_ptr stmt; - if ( ! overwrite ) - stmt = unique_stmt_ptr(put_stmt.get(), sqlite3_reset); - else + if ( overwrite ) stmt = unique_stmt_ptr(put_update_stmt.get(), sqlite3_reset); + else + stmt = unique_stmt_ptr(put_stmt.get(), sqlite3_reset); if ( auto res = CheckError(sqlite3_bind_blob(stmt.get(), 1, key_data->data(), key_data->size(), SQLITE_STATIC)); res.code != ReturnCode::SUCCESS ) { @@ -331,7 +329,7 @@ OperationResult SQLite::DoPut(ResultCallback* cb, ValPtr key, ValPtr value, bool } if ( ! overwrite ) - if ( auto res = CheckError(sqlite3_bind_double(stmt.get(), 6, util::current_time())); + if ( auto res = CheckError(sqlite3_bind_double(stmt.get(), 6, zeek::run_state::network_time)); res.code != ReturnCode::SUCCESS ) { return res; } @@ -365,6 +363,11 @@ OperationResult SQLite::DoGet(ResultCallback* cb, ValPtr key) { return res; } + if ( auto res = CheckError(sqlite3_bind_double(stmt.get(), 2, zeek::run_state::network_time)); + res.code != ReturnCode::SUCCESS ) { + return res; + } + return Step(stmt.get(), true); } @@ -514,19 +517,15 @@ OperationResult SQLite::Step(sqlite3_stmt* stmt, bool parse_value) { int step_status = sqlite3_step(stmt); if ( step_status == SQLITE_ROW ) { if ( parse_value ) { - if ( sqlite3_column_type(stmt, 1) != SQLITE_NULL && is_expired(sqlite3_column_double(stmt, 1)) ) - ret = {ReturnCode::KEY_NOT_FOUND, ""}; - else { - auto blob = static_cast(sqlite3_column_blob(stmt, 0)); - size_t blob_size = sqlite3_column_bytes(stmt, 0); + auto blob = static_cast(sqlite3_column_blob(stmt, 0)); + size_t blob_size = sqlite3_column_bytes(stmt, 0); - auto val = serializer->Unserialize({blob, blob_size}, val_type); + auto val = serializer->Unserialize({blob, blob_size}, val_type); - if ( val ) - ret = {ReturnCode::SUCCESS, "", val.value()}; - else - ret = {ReturnCode::OPERATION_FAILED, val.error()}; - } + if ( val ) + ret = {ReturnCode::SUCCESS, "", val.value()}; + else + ret = {ReturnCode::OPERATION_FAILED, val.error()}; } else { ret = {ReturnCode::OPERATION_FAILED, "sqlite3_step should not have returned a value"};