diff --git a/src/packet_analysis/CMakeLists.txt b/src/packet_analysis/CMakeLists.txt index 88023b8fa2..044370b727 100644 --- a/src/packet_analysis/CMakeLists.txt +++ b/src/packet_analysis/CMakeLists.txt @@ -6,10 +6,10 @@ include_directories(BEFORE ) add_subdirectory(protocol) -add_subdirectory(dispatchers) set(llanalyzer_SRCS Analyzer.cc + Dispatcher.cc Manager.cc Component.cc Tag.cc diff --git a/src/packet_analysis/dispatchers/VectorDispatcher.cc b/src/packet_analysis/Dispatcher.cc similarity index 86% rename from src/packet_analysis/dispatchers/VectorDispatcher.cc rename to src/packet_analysis/Dispatcher.cc index 09256f0a85..7662c42b88 100644 --- a/src/packet_analysis/dispatchers/VectorDispatcher.cc +++ b/src/packet_analysis/Dispatcher.cc @@ -2,16 +2,16 @@ #include -#include "VectorDispatcher.h" +#include "Dispatcher.h" namespace zeek::packet_analysis { -VectorDispatcher::~VectorDispatcher() +Dispatcher::~Dispatcher() { FreeValues(); } -bool VectorDispatcher::Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher) +bool Dispatcher::Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher) { // If the table has size 1 and the entry is nullptr, there was nothing added yet. Just add it. if ( table.size() == 1 && table[0] == nullptr ) @@ -55,7 +55,7 @@ bool VectorDispatcher::Register(identifier_t identifier, AnalyzerPtr analyzer, D return false; } -void VectorDispatcher::Register(const register_map& data) +void Dispatcher::Register(const register_map& data) { // Search smallest and largest identifier and resize vector const auto& lowest_new = @@ -77,7 +77,7 @@ void VectorDispatcher::Register(const register_map& data) } } -ValuePtr VectorDispatcher::Lookup(identifier_t identifier) const +ValuePtr Dispatcher::Lookup(identifier_t identifier) const { int64_t index = identifier - lowest_identifier; if ( index >= 0 && index < static_cast(table.size()) && table[index] != nullptr ) @@ -86,24 +86,24 @@ ValuePtr VectorDispatcher::Lookup(identifier_t identifier) const return nullptr; } -size_t VectorDispatcher::Size() const +size_t Dispatcher::Size() const { return std::count_if(table.begin(), table.end(), [](ValuePtr v) { return v != nullptr; }); } -void VectorDispatcher::Clear() +void Dispatcher::Clear() { FreeValues(); table.clear(); } -void VectorDispatcher::FreeValues() +void Dispatcher::FreeValues() { for ( auto& current : table ) current = nullptr; } -void VectorDispatcher::DumpDebug() const +void Dispatcher::DumpDebug() const { #ifdef DEBUG DBG_LOG(DBG_PACKET_ANALYSIS, " Dispatcher elements (used/total): %lu/%lu", Size(), table.size()); diff --git a/src/packet_analysis/dispatchers/Dispatcher.h b/src/packet_analysis/Dispatcher.h similarity index 56% rename from src/packet_analysis/dispatchers/Dispatcher.h rename to src/packet_analysis/Dispatcher.h index f7fca9e3b1..eb4b0c8c36 100644 --- a/src/packet_analysis/dispatchers/Dispatcher.h +++ b/src/packet_analysis/Dispatcher.h @@ -2,11 +2,9 @@ #pragma once -#include #include - +#include #include "Analyzer.h" -#include "Defines.h" namespace zeek::packet_analysis { @@ -31,21 +29,31 @@ using ValuePtr = std::shared_ptr; class Dispatcher { public: - virtual ~Dispatcher() = default; + Dispatcher() + : table(std::vector(1, nullptr)) + { } - virtual bool Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher) = 0; - virtual void Register(const register_map& data) + ~Dispatcher(); + + bool Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher); + void Register(const register_map& data); + + ValuePtr Lookup(identifier_t identifier) const; + + size_t Size() const; + void Clear(); + void DumpDebug() const; + +private: + identifier_t lowest_identifier = 0; + std::vector table; + + void FreeValues(); + + inline identifier_t GetHighestIdentifier() const { - for ( auto& current : data ) - Register(current.first, current.second.first, current.second.second); + return lowest_identifier + table.size() - 1; } - - virtual ValuePtr Lookup(identifier_t identifier) const = 0; - - virtual size_t Size() const = 0; - virtual void Clear() = 0; - - virtual void DumpDebug() const = 0; - }; +}; } diff --git a/src/packet_analysis/Manager.cc b/src/packet_analysis/Manager.cc index 2f2d4f03b3..b9e17ac916 100644 --- a/src/packet_analysis/Manager.cc +++ b/src/packet_analysis/Manager.cc @@ -9,7 +9,7 @@ #include "NetVar.h" #include "plugin/Manager.h" #include "Analyzer.h" -#include "dispatchers/VectorDispatcher.h" +#include "Dispatcher.h" using namespace zeek::packet_analysis; @@ -266,7 +266,7 @@ DispatcherPtr Manager::GetDispatcher(Config& configuration, const std::string& d const auto& mappings = dispatcher_config->get().GetMappings(); - DispatcherPtr dispatcher = std::make_shared(); + DispatcherPtr dispatcher = std::make_shared(); dispatchers.emplace(dispatcher_name, dispatcher); for ( const auto& current_mapping : mappings ) diff --git a/src/packet_analysis/dispatchers/CMakeLists.txt b/src/packet_analysis/dispatchers/CMakeLists.txt deleted file mode 100644 index ea4183fef8..0000000000 --- a/src/packet_analysis/dispatchers/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -include(ZeekSubdir) - -include_directories(BEFORE - ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_BINARY_DIR} -) - -set(dispatcher_SRCS - UniversalDispatcher.cc - VectorDispatcher.cc -) - -bro_add_subdir_library(llanalyzer_dispatcher ${dispatcher_SRCS}) diff --git a/src/packet_analysis/dispatchers/UniversalDispatcher.cc b/src/packet_analysis/dispatchers/UniversalDispatcher.cc deleted file mode 100644 index 517d8f109d..0000000000 --- a/src/packet_analysis/dispatchers/UniversalDispatcher.cc +++ /dev/null @@ -1,207 +0,0 @@ -// See the file "COPYING" in the main distribution directory for copyright. - -#include "UniversalDispatcher.h" - -namespace zeek::packet_analysis { - -UniversalDispatcher::UniversalDispatcher() : generator(rd()) - { - SetBins(2); - - table = std::vector(ONE << m, {0, nullptr}); - - // Initialize random engine - distribution_a = std::uniform_int_distribution(1, ~static_cast(0)); - distribution_b = std::uniform_int_distribution(0, (ONE << w_minus_m) - ONE); - - // Initialize random parameters - RandomizeAB(); - } - -UniversalDispatcher::~UniversalDispatcher() - { - FreeValues(); - } - -bool UniversalDispatcher::Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher) - { -#if DEBUG > 1 - std::shared_ptr deferred(nullptr, [=](...) { - std::cout << "Inserted " << identifier << std::endl; - }); -#endif - - uint64_t hashed_id = Hash(identifier); - if ( table[hashed_id].second == nullptr ) - { - // Free bin, insert the value - table[hashed_id] = std::make_pair(identifier, std::make_shared(analyzer, dispatcher)); - return true; - } - else if ( table[hashed_id].first != identifier ) - { - // The bin is not empty, but the content isn't the to-be-inserted identifier --> resolve collision - - // Create intermediate representation with the new element in it, then rehash with that data - std::vector intermediate = CreateIntermediate(); - intermediate.emplace_back(identifier, std::make_shared(analyzer, dispatcher)); - - // Try increasing the #bins until it works or it can't get any larger. - Rehash(intermediate); - return true; - } - - // Analyzer with this ID is already registered. - return false; - } - -void UniversalDispatcher::Register(const register_map& data) - { - // Analyzer already registered - for ( const auto& current : data ) - { - if ( table[Hash(current.first)].second != nullptr ) - throw std::invalid_argument("Analyzer " + std::to_string(current.first) + " already registered!"); - } - - // Create intermediate representation of current analyzer set, then add all new ones - std::vector intermediate = CreateIntermediate(); - for ( const auto& current : data ) - intermediate.emplace_back(current.first, std::make_shared(current.second.first, current.second.second)); - - Rehash(intermediate); - } - -ValuePtr UniversalDispatcher::Lookup(identifier_t identifier) const - { - uint64_t hashed_id = Hash(identifier); - - // The hashed_id can't be larger than the number of bins - assert(hashed_id < table.size() && "Hashed ID is outside of the hash table range!"); - - pair_t entry = table[hashed_id]; - if ( entry.second != nullptr && entry.first == identifier ) - return entry.second; - - return nullptr; - } - -size_t UniversalDispatcher::Size() const - { - size_t result = 0; - for ( const auto& current : table ) - { - if ( current.second != nullptr ) - result++; - } - return result; - } - -void UniversalDispatcher::Clear() - { - // Free all analyzers - FreeValues(); - - SetBins(2); - table = std::vector(ONE << m, {0, nullptr}); - RandomizeAB(); - } - -size_t UniversalDispatcher::BucketCount() - { - return table.size(); - } - -void UniversalDispatcher::Rehash() - { - // Intermediate representation is just the current table without nulls - Rehash(CreateIntermediate()); - } - -void UniversalDispatcher::DumpDebug() const - { -#ifdef DEBUG - DBG_LOG(DBG_PACKET_ANALYSIS, " Dispatcher elements (used/total): %lu/%lu", Size(), table.size()); - for ( size_t i = 0; i < table.size(); i++ ) - { - if ( table[i].second != nullptr ) - DBG_LOG(DBG_PACKET_ANALYSIS, " %#8x => %s, %p", table[i].first, table[i].second->analyzer->GetAnalyzerName(), table[i].second->dispatcher.get()); - } -#endif - } - -// ####################### -// ####### PRIVATE ####### -// ####################### - -void UniversalDispatcher::FreeValues() - { - for ( auto& current : table ) - current.second = nullptr; - } - -void UniversalDispatcher::Rehash(const std::vector& intermediate) - { - while ( ! FindCollisionFreeHashFunction(intermediate) ) - { - DBG_LOG(DBG_PACKET_ANALYSIS, "Rehashing did not work. Increasing #bins to %" PRIu64 " (%" PRIu64 " bit).", (uint64_t)std::pow(2, m + 1), m + 1); - SetBins(m + 1); - } - } - -bool UniversalDispatcher::FindCollisionFreeHashFunction(const std::vector& intermediate) - { - // Don't even try if the number of values is larger than the number of buckets - if ( ONE << m < intermediate.size() ) - return false; - - // Remember the hash function parameters to not break the table if rehashing doesn't work - uint64_t stored_a = a; - uint64_t stored_b = b; - - // Because the hash function hashes all values in the universe uniformly to m bins with probability 1/m - // we should at least try a multiple of #bins times. - for ( size_t i = 1; i <= (ONE << m); i++ ) - { - // Step 1: Re-randomize hash function parameters - RandomizeAB(); - - // Step 2: Create new table - std::vector new_table(ONE << m, {0, nullptr}); - - // Step 3: Try to insert all elements into the new table with the new hash function - bool finished = true; - for ( const auto& current : intermediate ) - { - uint64_t hashed_id = Hash(current.first); - assert(hashed_id < new_table.size()); - if ( new_table[hashed_id].second == nullptr ) - { - // Free bin, insert the value - new_table[hashed_id] = current; - } - else - { - // The bin is not empty which means there is a collision - // (there are no duplicates in the intermediate representation so that can't be the case) - finished = false; - break; - } - } - - // Step 4: If the inserting finished without collisions, overwrite the previous table and exit - if ( finished ) - { - DBG_LOG(DBG_PACKET_ANALYSIS, "Took %lu rehash(es) to resolve.", i); - table = new_table; - return true; - } - } - - // Finding a collision free hash function failed. Revert the hash function parameters. - a = stored_a; - b = stored_b; - return false; - } - -} diff --git a/src/packet_analysis/dispatchers/UniversalDispatcher.h b/src/packet_analysis/dispatchers/UniversalDispatcher.h deleted file mode 100644 index f235f438fd..0000000000 --- a/src/packet_analysis/dispatchers/UniversalDispatcher.h +++ /dev/null @@ -1,108 +0,0 @@ -// See the file "COPYING" in the main distribution directory for copyright. - -#pragma once - -#include -#include "Dispatcher.h" - -namespace zeek::packet_analysis { - -class UniversalDispatcher : public Dispatcher { -public: - UniversalDispatcher(); - ~UniversalDispatcher() override; - - bool Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher) override; - void Register(const register_map& data) override; - ValuePtr Lookup(identifier_t identifier) const override; - size_t Size() const override; - void Clear() override; - - void DumpDebug() const override; - size_t BucketCount(); - - // Rehashes the hash table including re-randomization of the hash function. - void Rehash(); - -private: - using pair_t = std::pair; - static const uint64_t ONE = 1u; - - // Chosen random constants for the currently selected collision free random hash function - uint64_t a = 0; // Needs to be a random odd positive value < 2^(sizeof(uint64_t) * 8) - uint64_t b = 0; // Needs to be a random non-negative value < 2^(((sizeof(uint64_t) * 8) - M) - - // Current bits that define the number of bins. Initially 2 which means there are 2^2 = 4 bins. - uint64_t m = 2; - - // Current shift value which is the number of bits that are "insignificant" because of the universe size. - uint64_t w_minus_m = 0; - - // RNG - std::random_device rd; - std::mt19937_64 generator; - std::uniform_int_distribution distribution_a; - std::uniform_int_distribution distribution_b; - -// Debug -#if DEBUG > 0 - size_t nptr_counter = 0; - size_t mismatch_counter = 0; - size_t all_counter = 0; -#endif - - std::vector table; - - void FreeValues(); - - void Rehash(const std::vector& intermediate); - - /** - * Tries to find a collision free hash function with the current number of buckets. - * - * @param intermediate The key-value set to store in the hashtable. - * @return true, iff it found a collision-free hash function. - */ - bool FindCollisionFreeHashFunction(const std::vector& intermediate); - - [[nodiscard]] inline uint64_t Hash(const uint64_t value) const - { - return (a * value + b) >> w_minus_m; - } - - inline void RandomizeAB() - { - do { - a = distribution_a(generator); - } while ( a % 2 == 0 ); - - b = distribution_b(generator); - } - - inline void SetBins(uint64_t new_m) - { - if ( new_m > (sizeof(uint64_t) * 8) ) - throw std::runtime_error("Number of bits for bin count too large."); - - m = new_m; - w_minus_m = sizeof(uint64_t) * 8 - m; - distribution_b = std::uniform_int_distribution(0, ((uint64_t)(1u) << w_minus_m) - (uint64_t)(1u)); - } - - inline std::vector CreateIntermediate() - { - std::vector intermediate; - for ( const auto& current : table ) - { - if ( current.second != nullptr ) - { - assert(current.second->analyzer != nullptr); - intermediate.emplace_back(current.first, current.second); - } - } - return intermediate; - } - -}; - -} diff --git a/src/packet_analysis/dispatchers/VectorDispatcher.h b/src/packet_analysis/dispatchers/VectorDispatcher.h deleted file mode 100644 index ad7bbebe41..0000000000 --- a/src/packet_analysis/dispatchers/VectorDispatcher.h +++ /dev/null @@ -1,41 +0,0 @@ -// See the file "COPYING" in the main distribution directory for copyright. - -#pragma once - -#include -#include "Dispatcher.h" - -namespace zeek::packet_analysis { - -class VectorDispatcher : public Dispatcher { -public: - VectorDispatcher() - : table(std::vector(1, nullptr)) - { } - - ~VectorDispatcher() override; - - bool Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher) override; - void Register(const register_map& data) override; - - ValuePtr Lookup(identifier_t identifier) const override; - - size_t Size() const override; - void Clear() override; - -protected: - void DumpDebug() const override; - -private: - identifier_t lowest_identifier = 0; - std::vector table; - - void FreeValues(); - - inline identifier_t GetHighestIdentifier() const - { - return lowest_identifier + table.size() - 1; - } -}; - -}