diff --git a/src/storage/backend/sqlite/SQLite.cc b/src/storage/backend/sqlite/SQLite.cc index 4d7447b77a..7250e0328c 100644 --- a/src/storage/backend/sqlite/SQLite.cc +++ b/src/storage/backend/sqlite/SQLite.cc @@ -182,27 +182,38 @@ OperationResult SQLite::DoOpen(OpenResultCallback* cb, RecordValPtr options) { sqlite3_free(errorMsg); static std::array, 8> statements = - {std::make_pair(util::fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?)", + {// Normal put + std::make_pair(util::fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?) " + "ON CONFLICT(key_str) DO UPDATE SET value_str=?, expire_time=? " + "WHERE expire_time > 0.0 AND expire_time < ?", table_name.c_str()), db), - std::make_pair(util:: - fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?) ON CONFLICT(key_str) " - "DO UPDATE SET value_str=?, expire_time=?", - table_name.c_str()), + // Put with forced overwrite + std::make_pair(util::fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?) " + "ON CONFLICT(key_str) DO UPDATE SET value_str=?, expire_time=?", + table_name.c_str()), db), - std::make_pair(util::fmt("select value_str from %s where key_str=?", table_name.c_str()), db), + // Get + std::make_pair(util::fmt("select value_str, expire_time from %s where key_str=? and " + "(expire_time > ? OR expire_time == 0.0)", + table_name.c_str()), + db), + // Erase std::make_pair(util::fmt("delete from %s where key_str=?", table_name.c_str()), db), - + // Check for expired entries std::make_pair( util::fmt("select count(*) from %s where expire_time > 0 and expire_time != 0 and expire_time <= ?", table_name.c_str()), expire_db), + // Remove expired entries std::make_pair(util::fmt("delete from %s where expire_time > 0 and expire_time != 0 and expire_time <= ?", table_name.c_str()), expire_db), + // Get the last time expiry ran std::make_pair(util::fmt("select last_run from zeek_storage_expiry_runs where ukey = '%s'", table_name.c_str()), expire_db), + // Update the last time expiry ran std::make_pair(util::fmt("update zeek_storage_expiry_runs set last_run = ? where ukey = '%s'", table_name.c_str()), expire_db)}; @@ -291,10 +302,10 @@ OperationResult SQLite::DoPut(ResultCallback* cb, ValPtr key, ValPtr value, bool return {ReturnCode::SERIALIZATION_FAILED, "Failed to serialize key"}; unique_stmt_ptr stmt; - if ( ! overwrite ) - stmt = unique_stmt_ptr(put_stmt.get(), sqlite3_reset); - else + if ( overwrite ) stmt = unique_stmt_ptr(put_update_stmt.get(), sqlite3_reset); + else + stmt = unique_stmt_ptr(put_stmt.get(), sqlite3_reset); if ( auto res = CheckError(sqlite3_bind_blob(stmt.get(), 1, key_data->data(), key_data->size(), SQLITE_STATIC)); res.code != ReturnCode::SUCCESS ) { @@ -314,20 +325,31 @@ OperationResult SQLite::DoPut(ResultCallback* cb, ValPtr key, ValPtr value, bool return res; } - if ( overwrite ) { - if ( auto res = CheckError(sqlite3_bind_blob(stmt.get(), 4, val_data->data(), val_data->size(), SQLITE_STATIC)); - res.code != ReturnCode::SUCCESS ) { - return res; - } - - // This duplicates the above binding, but it's to overwrite the expiration time on the entry. - if ( auto res = CheckError(sqlite3_bind_double(stmt.get(), 5, expiration_time)); - res.code != ReturnCode::SUCCESS ) { - return res; - } + if ( auto res = CheckError(sqlite3_bind_blob(stmt.get(), 4, val_data->data(), val_data->size(), SQLITE_STATIC)); + res.code != ReturnCode::SUCCESS ) { + return res; } - return Step(stmt.get(), false); + // This duplicates the above binding, but it's to overwrite the expiration time on the entry. + if ( auto res = CheckError(sqlite3_bind_double(stmt.get(), 5, expiration_time)); res.code != ReturnCode::SUCCESS ) { + return res; + } + + if ( ! overwrite ) + if ( auto res = CheckError(sqlite3_bind_double(stmt.get(), 6, run_state::network_time)); + res.code != ReturnCode::SUCCESS ) { + return res; + } + + auto step_result = Step(stmt.get(), false); + if ( ! overwrite ) + if ( step_result.code == ReturnCode::SUCCESS ) { + int changed = sqlite3_changes(db); + if ( changed == 0 ) + step_result.code = ReturnCode::KEY_EXISTS; + } + + return step_result; } /** @@ -348,6 +370,11 @@ OperationResult SQLite::DoGet(ResultCallback* cb, ValPtr key) { return res; } + if ( auto res = CheckError(sqlite3_bind_double(stmt.get(), 2, run_state::network_time)); + res.code != ReturnCode::SUCCESS ) { + return res; + } + return Step(stmt.get(), true); } diff --git a/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-expiration-implicit/output b/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-expiration-implicit/output new file mode 100644 index 0000000000..829eded2c7 --- /dev/null +++ b/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-expiration-implicit/output @@ -0,0 +1,4 @@ +### BTest baseline data generated by btest-diff. Do not edit. Use "btest -U/-u" to update. Requires BTest >= 0.63. +BEFORE, [code=Storage::SUCCESS, error_str=, value=v] +AFTER, [code=Storage::KEY_NOT_FOUND, error_str=, value=] +OVERWRITE, [code=Storage::SUCCESS, error_str=, value=vv] diff --git a/testing/btest/scripts/base/frameworks/storage/sqlite-expiration-implicit.zeek b/testing/btest/scripts/base/frameworks/storage/sqlite-expiration-implicit.zeek new file mode 100644 index 0000000000..f6197ea9ad --- /dev/null +++ b/testing/btest/scripts/base/frameworks/storage/sqlite-expiration-implicit.zeek @@ -0,0 +1,61 @@ +# @TEST-EXEC: zeek %INPUT 2>&1 >output +# @TEST-EXEC: btest-diff output + +@load base/frameworks/storage/sync +@load policy/frameworks/storage/backend/sqlite + +# Manually control the clock. +redef allow_network_time_forward = F; + +redef Storage::expire_interval = 1secs; + +event zeek_init() + { + local opts = Storage::BackendOptions( + $serializer=Storage::STORAGE_SERIALIZER_JSON, + $sqlite=Storage::Backend::SQLite::Options( + $database_path="test.sqlite", + $table_name="testing")); + + local open_res = Storage::Sync::open_backend( + Storage::STORAGE_BACKEND_SQLITE, + opts, + string, + string); + local h = open_res$value; + + local key="k"; + local value="v"; + + # Expire entries well within `Storage::expire_interval` so `DoExpire` does not kick in. + local expire = Storage::expire_interval / 10; + local expire_time = network_time() + expire; + Storage::Sync::put( + h, + Storage::PutArgs( + $key=key, + $value=value, + $expire_time=expire)); + + # The entry we just put in exists. + local get = Storage::Sync::get(h, key); + print "BEFORE", get; + + # Advance time. + set_network_time(network_time() + expire * 2); + + # An expired value does not exist. + get = Storage::Sync::get(h, key); + print "AFTER", get; + + # Even though the entry still exists in the backend we can put a + # new value in its place without specifying overwrite. + Storage::Sync::put( + h, + Storage::PutArgs( + $key=key, + $value=value+value)); + + get = Storage::Sync::get(h, key); + print "OVERWRITE", get; + }