Implement standard-library-compatible iterators for Dictionary

This commit is contained in:
Tim Wojtulewicz 2020-09-22 13:59:13 -07:00
parent 9e9998c6e5
commit 892124378c
16 changed files with 834 additions and 254 deletions

View file

@ -24,7 +24,7 @@
namespace zeek {
class IterCookie {
class [[deprecated("Remove in v5.1. Use the standard-library-compatible version of iteration.")]] IterCookie {
public:
IterCookie(Dictionary* d) : d(d) {}
@ -184,6 +184,8 @@ TEST_CASE("dict iteration")
dict.Insert(key, &val);
dict.Insert(key2, &val2);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
detail::HashKey* it_key;
IterCookie* it = dict.InitForIteration();
CHECK(it != nullptr);
@ -191,32 +193,269 @@ TEST_CASE("dict iteration")
while ( uint32_t* entry = dict.NextEntry(it_key, it) )
{
if ( count == 0 )
switch ( count )
{
// The DictEntry constructor typecasts this down to a uint32_t, so
// we can't just check the value directly.
// Explanation: hash_t is 64bit, open-dict only uses 32bit hash to
// save space for each item (24 bytes aligned). OpenDict has table
// size of 2^N and only take the lower bits of the hash. (The
// original hash takes transformation in FibHash() to map into a
// smaller 2^N range).
CHECK(it_key->Hash() == (uint32_t)key2->Hash());
CHECK(*entry == 10);
}
else
{
CHECK(it_key->Hash() == (uint32_t)key->Hash());
CHECK(*entry == 15);
case 0:
// The DictEntry constructor typecasts this down to a uint32_t, so
// we can't just check the value directly.
// Explanation: hash_t is 64bit, open-dict only uses 32bit hash to
// save space for each item (24 bytes aligned). OpenDict has table
// size of 2^N and only take the lower bits of the hash. (The
// original hash takes transformation in FibHash() to map into a
// smaller 2^N range).
CHECK(it_key->Hash() == (uint32_t)key2->Hash());
CHECK(*entry == 10);
break;
case 1:
CHECK(it_key->Hash() == (uint32_t)key->Hash());
CHECK(*entry == 15);
break;
default:
break;
}
count++;
delete it_key;
}
CHECK(count == 2);
#pragma GCC diagnostic pop
delete key;
delete key2;
}
TEST_CASE("dict new iteration")
{
PDict<uint32_t> dict;
uint32_t val = 15;
uint32_t key_val = 5;
detail::HashKey* key = new detail::HashKey(key_val);
uint32_t val2 = 10;
uint32_t key_val2 = 25;
detail::HashKey* key2 = new detail::HashKey(key_val2);
dict.Insert(key, &val);
dict.Insert(key2, &val2);
int count = 0;
for ( const auto& entry : dict )
{
auto* v = static_cast<uint32_t*>(entry.value);
uint64_t k = *(uint32_t*) entry.GetKey();
switch ( count )
{
case 0:
CHECK(k == key_val2);
CHECK(*v == val2);
break;
case 1:
CHECK(k == key_val);
CHECK(*v == val);
break;
default:
break;
}
count++;
}
CHECK(count == 2);
delete key;
delete key2;
}
TEST_CASE("dict robust iteration")
{
PDict<uint32_t> dict;
uint32_t val = 15;
uint32_t key_val = 5;
detail::HashKey* key = new detail::HashKey(key_val);
uint32_t val2 = 10;
uint32_t key_val2 = 25;
detail::HashKey* key2 = new detail::HashKey(key_val2);
uint32_t val3 = 20;
uint32_t key_val3 = 35;
detail::HashKey* key3 = new detail::HashKey(key_val3);
dict.Insert(key, &val);
dict.Insert(key2, &val2);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
detail::HashKey* it_key;
IterCookie* it = dict.InitForIteration();
CHECK(it != nullptr);
dict.MakeRobustCookie(it);
int count = 0;
while ( uint32_t* entry = dict.NextEntry(it_key, it) )
{
switch ( count )
{
case 0:
CHECK(it_key->Hash() == (uint32_t)key2->Hash());
CHECK(*entry == 10);
dict.Insert(key3, &val3);
break;
case 1:
CHECK(it_key->Hash() == (uint32_t)key->Hash());
CHECK(*entry == 15);
break;
case 2:
CHECK(it_key->Hash() == (uint32_t)key3->Hash());
CHECK(*entry == 20);
break;
default:
// We shouldn't get here.
CHECK(false);
break;
}
count++;
delete it_key;
}
CHECK(count == 3);
IterCookie* it2 = dict.InitForIteration();
CHECK(it2 != nullptr);
dict.MakeRobustCookie(it2);
count = 0;
while ( uint32_t* entry = dict.NextEntry(it_key, it2) )
{
switch ( count )
{
case 0:
CHECK(it_key->Hash() == (uint32_t)key2->Hash());
CHECK(*entry == 10);
dict.Remove(key3);
break;
case 1:
CHECK(it_key->Hash() == (uint32_t)key->Hash());
CHECK(*entry == 15);
break;
default:
// We shouldn't get here.
CHECK(false);
break;
}
count++;
delete it_key;
}
CHECK(count == 2);
#pragma GCC diagnostic pop
delete key;
delete key2;
delete key3;
}
TEST_CASE("dict new robust iteration")
{
PDict<uint32_t> dict;
uint32_t val = 15;
uint32_t key_val = 5;
detail::HashKey* key = new detail::HashKey(key_val);
uint32_t val2 = 10;
uint32_t key_val2 = 25;
detail::HashKey* key2 = new detail::HashKey(key_val2);
uint32_t val3 = 20;
uint32_t key_val3 = 35;
detail::HashKey* key3 = new detail::HashKey(key_val3);
dict.Insert(key, &val);
dict.Insert(key2, &val2);
{
int count = 0;
auto it = dict.begin_robust();
for ( ; it != dict.end_robust(); ++it )
{
auto* v = it->GetValue<uint32_t*>();
uint64_t k = *(uint32_t*) it->GetKey();
switch ( count )
{
case 0:
CHECK(k == key_val2);
CHECK(*v == val2);
dict.Insert(key3, &val3);
break;
case 1:
CHECK(k == key_val);
CHECK(*v == val);
break;
case 2:
CHECK(k == key_val3);
CHECK(*v == val3);
break;
default:
// We shouldn't get here.
CHECK(false);
break;
}
count++;
}
CHECK(count == 3);
}
{
int count = 0;
auto it = dict.begin_robust();
for ( ; it != dict.end_robust(); ++it )
{
auto* v = it->GetValue<uint32_t*>();
uint64_t k = *(uint32_t*) it->GetKey();
switch ( count )
{
case 0:
CHECK(k == key_val2);
CHECK(*v == val2);
dict.Insert(key3, &val3);
dict.Remove(key3);
break;
case 1:
CHECK(k == key_val);
CHECK(*v == val);
break;
default:
// We shouldn't get here.
CHECK(false);
break;
}
count++;
}
CHECK(count == 2);
}
delete key;
delete key2;
delete key3;
}
TEST_CASE("dict iterator invalidation")
{
PDict<uint32_t> dict;
@ -238,6 +477,8 @@ TEST_CASE("dict iterator invalidation")
detail::HashKey* it_key;
bool iterators_invalidated = false;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
IterCookie* it = dict.InitForIteration();
CHECK(it != nullptr);
@ -277,6 +518,7 @@ TEST_CASE("dict iterator invalidation")
dict.StopIteration(it);
break;
}
#pragma GCC diagnostic pop
CHECK(dict.Length() == 2);
CHECK(*static_cast<uint32_t*>(dict.Lookup(key)) == val2);
@ -415,7 +657,7 @@ int Dictionary::OffsetInClusterByPosition(int position) const
return position - head;
}
// Find the next valid entry after the position. Positiion can be -1, which means
// Find the next valid entry after the position. Position can be -1, which means
// look for the next valid entry point altogether.
int Dictionary::Next(int position) const
{
@ -613,7 +855,7 @@ void Dictionary::Dump(int level) const
}
}
//////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
//Initialization.
////////////////////////////////////////////////////////////////////////////////////////////////////
Dictionary::Dictionary(DictOrder ordering, int initial_size)
@ -661,6 +903,11 @@ void Dictionary::Clear()
delete cookies;
cookies = nullptr;
}
if ( iterators )
{
delete iterators;
iterators = nullptr;
}
log2_buckets = 0;
num_iterators = 0;
remaps = 0;
@ -787,7 +1034,7 @@ int Dictionary::LookupIndex(const void* key, int key_size, detail::hash_t hash,
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Insert
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
void* Dictionary::Insert(void* key, int key_size, detail::hash_t hash, void* val, bool copy_key, bool* iterators_invalidated)
{
@ -827,6 +1074,15 @@ void* Dictionary::Insert(void* key, int key_size, detail::hash_t hash, void* val
if ( it != c->inserted->end() )
it->value = val;
}
if ( iterators && ! iterators->empty() )
//need to set new v for iterators too.
for ( auto c: *iterators )
{
auto it = std::find(c->inserted->begin(), c->inserted->end(), table[position]);
if ( it != c->inserted->end() )
it->value = val;
}
}
else
{
@ -879,6 +1135,12 @@ void Dictionary::InsertRelocateAndAdjust(detail::DictEntry& entry, int insert_po
if ( cookies && ! cookies->empty() )
for ( auto c: *cookies )
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
AdjustOnInsert(c, entry, insert_position, last_affected_position);
#pragma GCC diagnostic pop
if ( iterators && ! iterators->empty() )
for ( auto c: *iterators )
AdjustOnInsert(c, entry, insert_position, last_affected_position);
}
@ -916,6 +1178,9 @@ void Dictionary::InsertAndRelocate(detail::DictEntry& entry, int insert_position
}
}
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
/// Adjust Cookies on Insert.
void Dictionary::AdjustOnInsert(IterCookie* c, const detail::DictEntry& entry, int insert_position, int last_affected_position)
{
@ -931,6 +1196,21 @@ void Dictionary::AdjustOnInsert(IterCookie* c, const detail::DictEntry& entry, i
}
}
#pragma GCC diagnostic pop
void Dictionary::AdjustOnInsert(RobustDictIterator* c, const detail::DictEntry& entry,
int insert_position, int last_affected_position)
{
if ( insert_position < c->next )
c->inserted->push_back(entry);
if ( insert_position < c->next && c->next <= last_affected_position )
{
int k = TailOfClusterByPosition(c->next);
ASSERT(k >= 0 && k < Capacity());
c->visited->push_back(table[k]);
}
}
void Dictionary::SizeUp()
{
int prev_capacity = Capacity();
@ -953,7 +1233,7 @@ void Dictionary::SizeUp()
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Remove
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
void* Dictionary::Remove(const void* key, int key_size, detail::hash_t hash, bool dont_delete, bool* iterators_invalidated)
{//cookie adjustment: maintain inserts here. maintain next in lower level version.
@ -999,6 +1279,13 @@ detail::DictEntry Dictionary::RemoveRelocateAndAdjust(int position)
if ( cookies && ! cookies->empty() )
for ( auto c: *cookies )
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
AdjustOnRemove(c, entry, position, last_affected_position);
#pragma GCC diagnostic pop
if ( iterators && ! iterators->empty() )
for ( auto c: *iterators )
AdjustOnRemove(c, entry, position, last_affected_position);
return entry;
@ -1029,6 +1316,9 @@ detail::DictEntry Dictionary::RemoveAndRelocate(int position, int* last_affected
return entry;
}
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
void Dictionary::AdjustOnRemove(IterCookie* c, const detail::DictEntry& entry, int position, int last_affected_position)
{
ASSERT_VALID(c);
@ -1046,6 +1336,25 @@ void Dictionary::AdjustOnRemove(IterCookie* c, const detail::DictEntry& entry, i
c->next = Next(c->next);
}
#pragma GCC diagnostic pop
void Dictionary::AdjustOnRemove(RobustDictIterator* c, const detail::DictEntry& entry,
int position, int last_affected_position)
{
c->inserted->erase(std::remove(c->inserted->begin(), c->inserted->end(), entry), c->inserted->end());
if ( position < c->next && c->next <= last_affected_position )
{
int moved = HeadOfClusterByPosition(c->next-1);
if ( moved < position )
moved = position;
c->inserted->push_back(table[moved]);
}
//if not already the end of the dictionary, adjust next to a valid one.
if ( c->next < Capacity() && table[c->next].Empty() )
c->next = Next(c->next);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
//Remap
///////////////////////////////////////////////////////////////////////////////////////////////////
@ -1057,7 +1366,7 @@ void Dictionary::Remap()
///remap from bottom up.
///remap creates two parts of the dict: [0,remap_end] (remap_end, ...]. the former is mixed with old/new entries; the latter contains all new entries.
///
if ( num_iterators )
if ( num_iterators > 0 )
return;
int left = detail::DICT_REMAP_ENTRIES;
@ -1076,7 +1385,7 @@ bool Dictionary::Remap(int position, int* new_position)
{
ASSERT_VALID(this);
///Remap changes item positions by remove() and insert(). to avoid excessive operation. avoid it when safe iteration is in progress.
ASSERT(! cookies || cookies->empty());
ASSERT( ( ! cookies || cookies->empty() ) && ( ! iterators || iterators->empty() ) );
int current = BucketByPosition(position);//current bucket
int expected = BucketByHash(table[position].hash, log2_buckets); //expected bucket in new table.
//equal because 1: it's a new item, 2: it's an old item, but new bucket is the same as old. 50% of old items act this way due to fibhash.
@ -1097,10 +1406,6 @@ bool Dictionary::Remap(int position, int* new_position)
return true;
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Iteration
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
void* Dictionary::NthEntry(int n, const void*& key, int& key_size) const
{
if ( ! order || n < 0 || n >= Length() )
@ -1111,6 +1416,13 @@ void* Dictionary::NthEntry(int n, const void*& key, int& key_size) const
return entry.value;
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Iteration
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
void Dictionary::MakeRobustCookie(IterCookie* cookie)
{ //make sure c->next >= 0.
if ( ! cookies )
@ -1164,21 +1476,10 @@ void* Dictionary::NextEntryNonConst(detail::HashKey*& h, IterCookie*& c, bool re
if ( c->next < 0 )
c->next = Next(-1);
// if resize happens during iteration. before sizeup, c->next points to Capacity(),
// but now Capacity() doubles up and c->next doesn't point to the end anymore.
// this is fine because c->next may be filled now.
// however, c->next can also be empty.
// before sizeup, we use c->next >= Capacity() to indicate the end of the iteration.
// now this guard is invalid, we may face c->next is valid but empty now.F
//fix it here.
int capacity = Capacity();
if ( c->next < capacity && table[c->next].Empty() )
{
ASSERT(false); //stop to check the condition here. why it's happening.
c->next = Next(c->next);
}
ASSERT(c->next >= Capacity() || ! table[c->next].Empty());
//filter out visited keys.
int capacity = Capacity();
if ( c->visited && ! c->visited->empty() )
//filter out visited entries.
while ( c->next < capacity )
@ -1229,4 +1530,193 @@ void Dictionary::StopIteration(IterCookie* cookie) const
dp->StopIterationNonConst(cookie);
}
#pragma GCC diagnostic pop
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// New Iteration
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DictIterator::DictIterator(const Dictionary* d, detail::DictEntry* begin, detail::DictEntry* end)
: curr(begin), end(end)
{
// Make sure that we're starting on a non-empty element.
while ( curr != end && curr->Empty() )
++curr;
// Cast away the constness so that the number of iterators can be modified in the dictionary. This does
// violate the constness guarantees of const-begin()/end() and cbegin()/cend(), but we're not modifying the
// actual data in the collection, just a counter in the wrapper of the collection.
dict = const_cast<Dictionary*>(d);
dict->num_iterators++;
}
DictIterator::~DictIterator()
{
assert(dict->num_iterators > 0);
dict->num_iterators--;
}
DictIterator& DictIterator::operator++()
{
// The non-robust case is easy. Just advanced the current position forward until you find
// one isn't empty and isn't the end.
do {
++curr;
}
while ( curr != end && curr->Empty() );
return *this;
}
RobustDictIterator Dictionary::MakeRobustIterator()
{
if ( ! iterators )
iterators = new std::vector<RobustDictIterator*>;
return { this };
}
detail::DictEntry Dictionary::GetNextRobustIteration(RobustDictIterator* iter)
{
// If there are any inserted entries, return them first.
// That keeps the list small and helps avoiding searching
// a large list when deleting an entry.
if ( ! table )
{
iter->Complete();
return detail::DictEntry(nullptr); // end of iteration
}
if ( iter->inserted && ! iter->inserted->empty() )
{
// Return the last one. Order doesn't matter,
// and removing from the tail is cheaper.
detail::DictEntry e = iter->inserted->back();
iter->inserted->pop_back();
return e;
}
if ( iter->next < 0 )
iter->next = Next(-1);
ASSERT(iter->next >= Capacity() || ! table[iter->next].Empty());
// Filter out visited keys.
int capacity = Capacity();
if ( iter->visited && ! iter->visited->empty() )
// Filter out visited entries.
while ( iter->next < capacity )
{
ASSERT(! table[iter->next].Empty());
auto it = std::find(iter->visited->begin(), iter->visited->end(), table[iter->next]);
if ( it == iter->visited->end() )
break;
iter->visited->erase(it);
iter->next = Next(iter->next);
}
if ( iter->next >= capacity )
{
iter->Complete();
return detail::DictEntry(nullptr); // end of iteration
}
ASSERT(! table[iter->next].Empty());
detail::DictEntry e = table[iter->next];
//prepare for next time.
iter->next = Next(iter->next);
return e;
}
RobustDictIterator::RobustDictIterator(Dictionary* d) : curr(nullptr), dict(d)
{
next = -1;
inserted = new std::vector<detail::DictEntry>();
visited = new std::vector<detail::DictEntry>();
dict->num_iterators++;
dict->iterators->push_back(this);
// Advance the iterator one step so that we're at the first element.
curr = dict->GetNextRobustIteration(this);
}
RobustDictIterator::RobustDictIterator(const RobustDictIterator& other) : curr(nullptr)
{
dict = nullptr;
if ( other.dict )
{
next = other.next;
inserted = new std::vector<detail::DictEntry>();
visited = new std::vector<detail::DictEntry>();
if ( other.inserted )
std::copy(other.inserted->begin(), other.inserted->end(), std::back_inserter(*inserted));
if ( other.visited)
std::copy(other.visited->begin(), other.visited->end(), std::back_inserter(*visited));
dict = other.dict;
dict->num_iterators++;
dict->iterators->push_back(this);
curr = other.curr;
}
}
RobustDictIterator::RobustDictIterator(RobustDictIterator&& other) : curr(nullptr)
{
dict = nullptr;
if ( other.dict )
{
next = other.next;
inserted = other.inserted;
visited = other.visited;
dict = other.dict;
dict->iterators->push_back(this);
dict->iterators->erase(std::remove(dict->iterators->begin(), dict->iterators->end(), &other),
dict->iterators->end());
other.dict = nullptr;
curr = std::move(other.curr);
}
}
RobustDictIterator::~RobustDictIterator()
{
Complete();
}
void RobustDictIterator::Complete()
{
if ( dict )
{
assert(dict->num_iterators > 0);
dict->num_iterators--;
dict->iterators->erase(std::remove(dict->iterators->begin(), dict->iterators->end(), this),
dict->iterators->end());
delete inserted;
delete visited;
inserted = nullptr;
visited = nullptr;
dict = nullptr;
}
}
RobustDictIterator& RobustDictIterator::operator++()
{
curr = dict->GetNextRobustIteration(this);
return *this;
}
} // namespace zeek