diff --git a/src/storage/backend/sqlite/SQLite.cc b/src/storage/backend/sqlite/SQLite.cc index bb3f4b6f8a..7250e0328c 100644 --- a/src/storage/backend/sqlite/SQLite.cc +++ b/src/storage/backend/sqlite/SQLite.cc @@ -182,30 +182,38 @@ OperationResult SQLite::DoOpen(OpenResultCallback* cb, RecordValPtr options) { sqlite3_free(errorMsg); static std::array, 8> statements = - {std::make_pair(util::fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?)", + {// Normal put + std::make_pair(util::fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?) " + "ON CONFLICT(key_str) DO UPDATE SET value_str=?, expire_time=? " + "WHERE expire_time > 0.0 AND expire_time < ?", table_name.c_str()), db), - std::make_pair(util:: - fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?) ON CONFLICT(key_str) " - "DO UPDATE SET value_str=?, expire_time=?", - table_name.c_str()), + // Put with forced overwrite + std::make_pair(util::fmt("insert into %s (key_str, value_str, expire_time) values(?, ?, ?) " + "ON CONFLICT(key_str) DO UPDATE SET value_str=?, expire_time=?", + table_name.c_str()), db), + // Get std::make_pair(util::fmt("select value_str, expire_time from %s where key_str=? and " "(expire_time > ? OR expire_time == 0.0)", table_name.c_str()), db), + // Erase std::make_pair(util::fmt("delete from %s where key_str=?", table_name.c_str()), db), - + // Check for expired entries std::make_pair( util::fmt("select count(*) from %s where expire_time > 0 and expire_time != 0 and expire_time <= ?", table_name.c_str()), expire_db), + // Remove expired entries std::make_pair(util::fmt("delete from %s where expire_time > 0 and expire_time != 0 and expire_time <= ?", table_name.c_str()), expire_db), + // Get the last time expiry ran std::make_pair(util::fmt("select last_run from zeek_storage_expiry_runs where ukey = '%s'", table_name.c_str()), expire_db), + // Update the last time expiry ran std::make_pair(util::fmt("update zeek_storage_expiry_runs set last_run = ? where ukey = '%s'", table_name.c_str()), expire_db)}; @@ -294,10 +302,10 @@ OperationResult SQLite::DoPut(ResultCallback* cb, ValPtr key, ValPtr value, bool return {ReturnCode::SERIALIZATION_FAILED, "Failed to serialize key"}; unique_stmt_ptr stmt; - if ( ! overwrite ) - stmt = unique_stmt_ptr(put_stmt.get(), sqlite3_reset); - else + if ( overwrite ) stmt = unique_stmt_ptr(put_update_stmt.get(), sqlite3_reset); + else + stmt = unique_stmt_ptr(put_stmt.get(), sqlite3_reset); if ( auto res = CheckError(sqlite3_bind_blob(stmt.get(), 1, key_data->data(), key_data->size(), SQLITE_STATIC)); res.code != ReturnCode::SUCCESS ) { @@ -317,20 +325,31 @@ OperationResult SQLite::DoPut(ResultCallback* cb, ValPtr key, ValPtr value, bool return res; } - if ( overwrite ) { - if ( auto res = CheckError(sqlite3_bind_blob(stmt.get(), 4, val_data->data(), val_data->size(), SQLITE_STATIC)); - res.code != ReturnCode::SUCCESS ) { - return res; - } - - // This duplicates the above binding, but it's to overwrite the expiration time on the entry. - if ( auto res = CheckError(sqlite3_bind_double(stmt.get(), 5, expiration_time)); - res.code != ReturnCode::SUCCESS ) { - return res; - } + if ( auto res = CheckError(sqlite3_bind_blob(stmt.get(), 4, val_data->data(), val_data->size(), SQLITE_STATIC)); + res.code != ReturnCode::SUCCESS ) { + return res; } - return Step(stmt.get(), false); + // This duplicates the above binding, but it's to overwrite the expiration time on the entry. + if ( auto res = CheckError(sqlite3_bind_double(stmt.get(), 5, expiration_time)); res.code != ReturnCode::SUCCESS ) { + return res; + } + + if ( ! overwrite ) + if ( auto res = CheckError(sqlite3_bind_double(stmt.get(), 6, run_state::network_time)); + res.code != ReturnCode::SUCCESS ) { + return res; + } + + auto step_result = Step(stmt.get(), false); + if ( ! overwrite ) + if ( step_result.code == ReturnCode::SUCCESS ) { + int changed = sqlite3_changes(db); + if ( changed == 0 ) + step_result.code = ReturnCode::KEY_EXISTS; + } + + return step_result; } /** diff --git a/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-expiration-implicit/output b/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-expiration-implicit/output index 0730873d8c..829eded2c7 100644 --- a/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-expiration-implicit/output +++ b/testing/btest/Baseline/scripts.base.frameworks.storage.sqlite-expiration-implicit/output @@ -1,3 +1,4 @@ ### BTest baseline data generated by btest-diff. Do not edit. Use "btest -U/-u" to update. Requires BTest >= 0.63. BEFORE, [code=Storage::SUCCESS, error_str=, value=v] AFTER, [code=Storage::KEY_NOT_FOUND, error_str=, value=] +OVERWRITE, [code=Storage::SUCCESS, error_str=, value=vv] diff --git a/testing/btest/scripts/base/frameworks/storage/sqlite-expiration-implicit.zeek b/testing/btest/scripts/base/frameworks/storage/sqlite-expiration-implicit.zeek index aa9e4d9a98..f6197ea9ad 100644 --- a/testing/btest/scripts/base/frameworks/storage/sqlite-expiration-implicit.zeek +++ b/testing/btest/scripts/base/frameworks/storage/sqlite-expiration-implicit.zeek @@ -47,4 +47,15 @@ event zeek_init() # An expired value does not exist. get = Storage::Sync::get(h, key); print "AFTER", get; + + # Even though the entry still exists in the backend we can put a + # new value in its place without specifying overwrite. + Storage::Sync::put( + h, + Storage::PutArgs( + $key=key, + $value=value+value)); + + get = Storage::Sync::get(h, key); + print "OVERWRITE", get; }