Move VectorDispatcher to be the only dispatcher

This commit is contained in:
Tim Wojtulewicz 2020-07-16 09:24:15 -07:00
parent d22481aef3
commit b46e600775
8 changed files with 36 additions and 397 deletions

View file

@ -6,10 +6,10 @@ include_directories(BEFORE
)
add_subdirectory(protocol)
add_subdirectory(dispatchers)
set(llanalyzer_SRCS
Analyzer.cc
Dispatcher.cc
Manager.cc
Component.cc
Tag.cc

View file

@ -2,16 +2,16 @@
#include <algorithm>
#include "VectorDispatcher.h"
#include "Dispatcher.h"
namespace zeek::packet_analysis {
VectorDispatcher::~VectorDispatcher()
Dispatcher::~Dispatcher()
{
FreeValues();
}
bool VectorDispatcher::Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher)
bool Dispatcher::Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher)
{
// If the table has size 1 and the entry is nullptr, there was nothing added yet. Just add it.
if ( table.size() == 1 && table[0] == nullptr )
@ -55,7 +55,7 @@ bool VectorDispatcher::Register(identifier_t identifier, AnalyzerPtr analyzer, D
return false;
}
void VectorDispatcher::Register(const register_map& data)
void Dispatcher::Register(const register_map& data)
{
// Search smallest and largest identifier and resize vector
const auto& lowest_new =
@ -77,7 +77,7 @@ void VectorDispatcher::Register(const register_map& data)
}
}
ValuePtr VectorDispatcher::Lookup(identifier_t identifier) const
ValuePtr Dispatcher::Lookup(identifier_t identifier) const
{
int64_t index = identifier - lowest_identifier;
if ( index >= 0 && index < static_cast<int64_t>(table.size()) && table[index] != nullptr )
@ -86,24 +86,24 @@ ValuePtr VectorDispatcher::Lookup(identifier_t identifier) const
return nullptr;
}
size_t VectorDispatcher::Size() const
size_t Dispatcher::Size() const
{
return std::count_if(table.begin(), table.end(), [](ValuePtr v) { return v != nullptr; });
}
void VectorDispatcher::Clear()
void Dispatcher::Clear()
{
FreeValues();
table.clear();
}
void VectorDispatcher::FreeValues()
void Dispatcher::FreeValues()
{
for ( auto& current : table )
current = nullptr;
}
void VectorDispatcher::DumpDebug() const
void Dispatcher::DumpDebug() const
{
#ifdef DEBUG
DBG_LOG(DBG_PACKET_ANALYSIS, " Dispatcher elements (used/total): %lu/%lu", Size(), table.size());

View file

@ -2,11 +2,9 @@
#pragma once
#include <map>
#include <utility>
#include <vector>
#include "Analyzer.h"
#include "Defines.h"
namespace zeek::packet_analysis {
@ -31,21 +29,31 @@ using ValuePtr = std::shared_ptr<Value>;
class Dispatcher {
public:
virtual ~Dispatcher() = default;
Dispatcher()
: table(std::vector<ValuePtr>(1, nullptr))
{ }
virtual bool Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher) = 0;
virtual void Register(const register_map& data)
~Dispatcher();
bool Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher);
void Register(const register_map& data);
ValuePtr Lookup(identifier_t identifier) const;
size_t Size() const;
void Clear();
void DumpDebug() const;
private:
identifier_t lowest_identifier = 0;
std::vector<ValuePtr> table;
void FreeValues();
inline identifier_t GetHighestIdentifier() const
{
for ( auto& current : data )
Register(current.first, current.second.first, current.second.second);
return lowest_identifier + table.size() - 1;
}
virtual ValuePtr Lookup(identifier_t identifier) const = 0;
virtual size_t Size() const = 0;
virtual void Clear() = 0;
virtual void DumpDebug() const = 0;
};
}

View file

@ -9,7 +9,7 @@
#include "NetVar.h"
#include "plugin/Manager.h"
#include "Analyzer.h"
#include "dispatchers/VectorDispatcher.h"
#include "Dispatcher.h"
using namespace zeek::packet_analysis;
@ -266,7 +266,7 @@ DispatcherPtr Manager::GetDispatcher(Config& configuration, const std::string& d
const auto& mappings = dispatcher_config->get().GetMappings();
DispatcherPtr dispatcher = std::make_shared<VectorDispatcher>();
DispatcherPtr dispatcher = std::make_shared<Dispatcher>();
dispatchers.emplace(dispatcher_name, dispatcher);
for ( const auto& current_mapping : mappings )

View file

@ -1,13 +0,0 @@
include(ZeekSubdir)
include_directories(BEFORE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}
)
set(dispatcher_SRCS
UniversalDispatcher.cc
VectorDispatcher.cc
)
bro_add_subdir_library(llanalyzer_dispatcher ${dispatcher_SRCS})

View file

@ -1,207 +0,0 @@
// See the file "COPYING" in the main distribution directory for copyright.
#include "UniversalDispatcher.h"
namespace zeek::packet_analysis {
UniversalDispatcher::UniversalDispatcher() : generator(rd())
{
SetBins(2);
table = std::vector<pair_t>(ONE << m, {0, nullptr});
// Initialize random engine
distribution_a = std::uniform_int_distribution<uint64_t>(1, ~static_cast<uint64_t>(0));
distribution_b = std::uniform_int_distribution<uint64_t>(0, (ONE << w_minus_m) - ONE);
// Initialize random parameters
RandomizeAB();
}
UniversalDispatcher::~UniversalDispatcher()
{
FreeValues();
}
bool UniversalDispatcher::Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher)
{
#if DEBUG > 1
std::shared_ptr<void> deferred(nullptr, [=](...) {
std::cout << "Inserted " << identifier << std::endl;
});
#endif
uint64_t hashed_id = Hash(identifier);
if ( table[hashed_id].second == nullptr )
{
// Free bin, insert the value
table[hashed_id] = std::make_pair(identifier, std::make_shared<Value>(analyzer, dispatcher));
return true;
}
else if ( table[hashed_id].first != identifier )
{
// The bin is not empty, but the content isn't the to-be-inserted identifier --> resolve collision
// Create intermediate representation with the new element in it, then rehash with that data
std::vector<pair_t> intermediate = CreateIntermediate();
intermediate.emplace_back(identifier, std::make_shared<Value>(analyzer, dispatcher));
// Try increasing the #bins until it works or it can't get any larger.
Rehash(intermediate);
return true;
}
// Analyzer with this ID is already registered.
return false;
}
void UniversalDispatcher::Register(const register_map& data)
{
// Analyzer already registered
for ( const auto& current : data )
{
if ( table[Hash(current.first)].second != nullptr )
throw std::invalid_argument("Analyzer " + std::to_string(current.first) + " already registered!");
}
// Create intermediate representation of current analyzer set, then add all new ones
std::vector<pair_t> intermediate = CreateIntermediate();
for ( const auto& current : data )
intermediate.emplace_back(current.first, std::make_shared<Value>(current.second.first, current.second.second));
Rehash(intermediate);
}
ValuePtr UniversalDispatcher::Lookup(identifier_t identifier) const
{
uint64_t hashed_id = Hash(identifier);
// The hashed_id can't be larger than the number of bins
assert(hashed_id < table.size() && "Hashed ID is outside of the hash table range!");
pair_t entry = table[hashed_id];
if ( entry.second != nullptr && entry.first == identifier )
return entry.second;
return nullptr;
}
size_t UniversalDispatcher::Size() const
{
size_t result = 0;
for ( const auto& current : table )
{
if ( current.second != nullptr )
result++;
}
return result;
}
void UniversalDispatcher::Clear()
{
// Free all analyzers
FreeValues();
SetBins(2);
table = std::vector<pair_t>(ONE << m, {0, nullptr});
RandomizeAB();
}
size_t UniversalDispatcher::BucketCount()
{
return table.size();
}
void UniversalDispatcher::Rehash()
{
// Intermediate representation is just the current table without nulls
Rehash(CreateIntermediate());
}
void UniversalDispatcher::DumpDebug() const
{
#ifdef DEBUG
DBG_LOG(DBG_PACKET_ANALYSIS, " Dispatcher elements (used/total): %lu/%lu", Size(), table.size());
for ( size_t i = 0; i < table.size(); i++ )
{
if ( table[i].second != nullptr )
DBG_LOG(DBG_PACKET_ANALYSIS, " %#8x => %s, %p", table[i].first, table[i].second->analyzer->GetAnalyzerName(), table[i].second->dispatcher.get());
}
#endif
}
// #######################
// ####### PRIVATE #######
// #######################
void UniversalDispatcher::FreeValues()
{
for ( auto& current : table )
current.second = nullptr;
}
void UniversalDispatcher::Rehash(const std::vector<pair_t>& intermediate)
{
while ( ! FindCollisionFreeHashFunction(intermediate) )
{
DBG_LOG(DBG_PACKET_ANALYSIS, "Rehashing did not work. Increasing #bins to %" PRIu64 " (%" PRIu64 " bit).", (uint64_t)std::pow(2, m + 1), m + 1);
SetBins(m + 1);
}
}
bool UniversalDispatcher::FindCollisionFreeHashFunction(const std::vector<pair_t>& intermediate)
{
// Don't even try if the number of values is larger than the number of buckets
if ( ONE << m < intermediate.size() )
return false;
// Remember the hash function parameters to not break the table if rehashing doesn't work
uint64_t stored_a = a;
uint64_t stored_b = b;
// Because the hash function hashes all values in the universe uniformly to m bins with probability 1/m
// we should at least try a multiple of #bins times.
for ( size_t i = 1; i <= (ONE << m); i++ )
{
// Step 1: Re-randomize hash function parameters
RandomizeAB();
// Step 2: Create new table
std::vector<pair_t> new_table(ONE << m, {0, nullptr});
// Step 3: Try to insert all elements into the new table with the new hash function
bool finished = true;
for ( const auto& current : intermediate )
{
uint64_t hashed_id = Hash(current.first);
assert(hashed_id < new_table.size());
if ( new_table[hashed_id].second == nullptr )
{
// Free bin, insert the value
new_table[hashed_id] = current;
}
else
{
// The bin is not empty which means there is a collision
// (there are no duplicates in the intermediate representation so that can't be the case)
finished = false;
break;
}
}
// Step 4: If the inserting finished without collisions, overwrite the previous table and exit
if ( finished )
{
DBG_LOG(DBG_PACKET_ANALYSIS, "Took %lu rehash(es) to resolve.", i);
table = new_table;
return true;
}
}
// Finding a collision free hash function failed. Revert the hash function parameters.
a = stored_a;
b = stored_b;
return false;
}
}

View file

@ -1,108 +0,0 @@
// See the file "COPYING" in the main distribution directory for copyright.
#pragma once
#include <random>
#include "Dispatcher.h"
namespace zeek::packet_analysis {
class UniversalDispatcher : public Dispatcher {
public:
UniversalDispatcher();
~UniversalDispatcher() override;
bool Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher) override;
void Register(const register_map& data) override;
ValuePtr Lookup(identifier_t identifier) const override;
size_t Size() const override;
void Clear() override;
void DumpDebug() const override;
size_t BucketCount();
// Rehashes the hash table including re-randomization of the hash function.
void Rehash();
private:
using pair_t = std::pair<identifier_t, ValuePtr>;
static const uint64_t ONE = 1u;
// Chosen random constants for the currently selected collision free random hash function
uint64_t a = 0; // Needs to be a random odd positive value < 2^(sizeof(uint64_t) * 8)
uint64_t b = 0; // Needs to be a random non-negative value < 2^(((sizeof(uint64_t) * 8) - M)
// Current bits that define the number of bins. Initially 2 which means there are 2^2 = 4 bins.
uint64_t m = 2;
// Current shift value which is the number of bits that are "insignificant" because of the universe size.
uint64_t w_minus_m = 0;
// RNG
std::random_device rd;
std::mt19937_64 generator;
std::uniform_int_distribution<uint64_t> distribution_a;
std::uniform_int_distribution<uint64_t> distribution_b;
// Debug
#if DEBUG > 0
size_t nptr_counter = 0;
size_t mismatch_counter = 0;
size_t all_counter = 0;
#endif
std::vector<pair_t> table;
void FreeValues();
void Rehash(const std::vector<pair_t>& intermediate);
/**
* Tries to find a collision free hash function with the current number of buckets.
*
* @param intermediate The key-value set to store in the hashtable.
* @return true, iff it found a collision-free hash function.
*/
bool FindCollisionFreeHashFunction(const std::vector<pair_t>& intermediate);
[[nodiscard]] inline uint64_t Hash(const uint64_t value) const
{
return (a * value + b) >> w_minus_m;
}
inline void RandomizeAB()
{
do {
a = distribution_a(generator);
} while ( a % 2 == 0 );
b = distribution_b(generator);
}
inline void SetBins(uint64_t new_m)
{
if ( new_m > (sizeof(uint64_t) * 8) )
throw std::runtime_error("Number of bits for bin count too large.");
m = new_m;
w_minus_m = sizeof(uint64_t) * 8 - m;
distribution_b = std::uniform_int_distribution<uint64_t>(0, ((uint64_t)(1u) << w_minus_m) - (uint64_t)(1u));
}
inline std::vector<pair_t> CreateIntermediate()
{
std::vector<pair_t> intermediate;
for ( const auto& current : table )
{
if ( current.second != nullptr )
{
assert(current.second->analyzer != nullptr);
intermediate.emplace_back(current.first, current.second);
}
}
return intermediate;
}
};
}

View file

@ -1,41 +0,0 @@
// See the file "COPYING" in the main distribution directory for copyright.
#pragma once
#include <utility>
#include "Dispatcher.h"
namespace zeek::packet_analysis {
class VectorDispatcher : public Dispatcher {
public:
VectorDispatcher()
: table(std::vector<ValuePtr>(1, nullptr))
{ }
~VectorDispatcher() override;
bool Register(identifier_t identifier, AnalyzerPtr analyzer, DispatcherPtr dispatcher) override;
void Register(const register_map& data) override;
ValuePtr Lookup(identifier_t identifier) const override;
size_t Size() const override;
void Clear() override;
protected:
void DumpDebug() const override;
private:
identifier_t lowest_identifier = 0;
std::vector<ValuePtr> table;
void FreeValues();
inline identifier_t GetHighestIdentifier() const
{
return lowest_identifier + table.size() - 1;
}
};
}