diff --git a/src/BloomFilter.cc b/src/BloomFilter.cc index 3c7bac80f1..889c7bafe1 100644 --- a/src/BloomFilter.cc +++ b/src/BloomFilter.cc @@ -70,8 +70,13 @@ size_t BasicBloomFilter::K(size_t cells, size_t capacity) BasicBloomFilter* BasicBloomFilter::Merge(const BasicBloomFilter* x, const BasicBloomFilter* y) { - // TODO: Ensure that x and y use the same Hasher before proceeding. + if ( ! x->hasher_->Equals(y->hasher_) ) + { + reporter->InternalError("incompatible hashers during Bloom filter merge"); + return NULL; + } BasicBloomFilter* result = new BasicBloomFilter(); + result->hasher_ = x->hasher_->Clone(); result->bits_ = new BitVector(*x->bits_ | *y->bits_); return result; } @@ -119,10 +124,17 @@ size_t BasicBloomFilter::CountImpl(const Hasher::digest_vector& h) const CountingBloomFilter* CountingBloomFilter::Merge(const CountingBloomFilter* x, const CountingBloomFilter* y) -{ - assert(! "not yet implemented"); - return NULL; -} + { + if ( ! x->hasher_->Equals(y->hasher_) ) + { + reporter->InternalError("incompatible hashers during Bloom filter merge"); + return NULL; + } + CountingBloomFilter* result = new CountingBloomFilter(); + result->hasher_ = x->hasher_->Clone(); + result->cells_ = new CounterVector(*x->cells_ | *y->cells_); + return result; + } CountingBloomFilter::CountingBloomFilter() : cells_(NULL) diff --git a/src/BloomFilter.h b/src/BloomFilter.h index 92f15c6070..070aa2dc25 100644 --- a/src/BloomFilter.h +++ b/src/BloomFilter.h @@ -57,7 +57,6 @@ protected: virtual void AddImpl(const Hasher::digest_vector& hashes) = 0; virtual size_t CountImpl(const Hasher::digest_vector& hashes) const = 0; -private: const Hasher* hasher_; }; diff --git a/src/CounterVector.cc b/src/CounterVector.cc index 75c62b208a..cf3083de9e 100644 --- a/src/CounterVector.cc +++ b/src/CounterVector.cc @@ -10,6 +10,12 @@ CounterVector::CounterVector(size_t width, size_t cells) { } +CounterVector::CounterVector(const CounterVector& other) + : bits_(new BitVector(*other.bits_)), + width_(other.width_) + { + } + CounterVector::~CounterVector() { delete bits_; diff --git a/src/CounterVector.h b/src/CounterVector.h index 4ab221ff6b..eced5956d4 100644 --- a/src/CounterVector.h +++ b/src/CounterVector.h @@ -9,6 +9,7 @@ class BitVector; * A vector of counters, each of which have a fixed number of bits. */ class CounterVector : public SerialObj { + CounterVector& operator=(const CounterVector&); public: typedef size_t size_type; typedef uint64 count_type; @@ -24,6 +25,13 @@ public: */ CounterVector(size_t width, size_t cells = 1024); + /** + * Copy-constructs a counter vector. + * + * @param other The counter vector to copy. + */ + CounterVector(const CounterVector& other); + ~CounterVector(); /** diff --git a/src/Hasher.cc b/src/Hasher.cc index 7a8d9a67e0..2a889c7e09 100644 --- a/src/Hasher.cc +++ b/src/Hasher.cc @@ -64,7 +64,7 @@ DefaultHasher* DefaultHasher::Clone() const return new DefaultHasher(*this); } -bool DefaultHasher::Equals(const Hasher* other) const /* final */ +bool DefaultHasher::Equals(const Hasher* other) const { if ( typeid(*this) != typeid(*other) ) return false; @@ -94,7 +94,7 @@ DoubleHasher* DoubleHasher::Clone() const return new DoubleHasher(*this); } -bool DoubleHasher::Equals(const Hasher* other) const /* final */ +bool DoubleHasher::Equals(const Hasher* other) const { if ( typeid(*this) != typeid(*other) ) return false; diff --git a/src/OpaqueVal.cc b/src/OpaqueVal.cc index 5a673c4a40..36038d679a 100644 --- a/src/OpaqueVal.cc +++ b/src/OpaqueVal.cc @@ -1,6 +1,5 @@ #include "OpaqueVal.h" -#include "BloomFilter.h" #include "NetVar.h" #include "Reporter.h" #include "Serializer.h" @@ -587,6 +586,7 @@ BloomFilterVal* BloomFilterVal::Merge(const BloomFilterVal* x, else if ( (result = DoMerge(x, y)) ) return result; + reporter->InternalError("failed to merge Bloom filters"); return NULL; } diff --git a/src/OpaqueVal.h b/src/OpaqueVal.h index 2362fdacfc..22c3dbfade 100644 --- a/src/OpaqueVal.h +++ b/src/OpaqueVal.h @@ -3,6 +3,7 @@ #ifndef OPAQUEVAL_H #define OPAQUEVAL_H +#include "BloomFilter.h" #include "RandTest.h" #include "Val.h" #include "digest.h" @@ -137,9 +138,23 @@ private: static BloomFilterVal* DoMerge(const BloomFilterVal* x, const BloomFilterVal* y) { - const T* a = dynamic_cast(x->bloom_filter_); - const T* b = dynamic_cast(y->bloom_filter_); - return a && b ? new BloomFilterVal(T::Merge(a, b)) : NULL; + if ( typeid(*x->bloom_filter_) != typeid(*y->bloom_filter_) ) + { + reporter->InternalError("cannot merge different Bloom filter types"); + return NULL; + } + if ( typeid(T) != typeid(*x->bloom_filter_) ) + return NULL; + const T* a = static_cast(x->bloom_filter_); + const T* b = static_cast(y->bloom_filter_); + BloomFilterVal* merged = new BloomFilterVal(T::Merge(a, b)); + assert(merged); + if ( ! merged->Typify(x->Type()) ) + { + reporter->InternalError("failed to set type on merged Bloom filter"); + return NULL; + } + return merged; } BroType* type_; diff --git a/testing/btest/Baseline/bifs.bloomfilter/output b/testing/btest/Baseline/bifs.bloomfilter/output index 80847a81b9..4fe2ae1ecc 100644 --- a/testing/btest/Baseline/bifs.bloomfilter/output +++ b/testing/btest/Baseline/bifs.bloomfilter/output @@ -7,8 +7,15 @@ 1 1 1 +1 +1 +1 +1 2 3 3 2 3 +3 +3 +2 diff --git a/testing/btest/bifs/bloomfilter.bro b/testing/btest/bifs/bloomfilter.bro index ab0bf86c22..f69ddbda0c 100644 --- a/testing/btest/bifs/bloomfilter.bro +++ b/testing/btest/bifs/bloomfilter.bro @@ -35,11 +35,21 @@ function test_basic_bloom_filter() # Invalid parameters. local bf_bug0 = bloomfilter_basic_init(-0.5, 42); local bf_bug1 = bloomfilter_basic_init(1.1, 42); + + # Merging + local bf_cnt2 = bloomfilter_basic_init(0.1, 1000); + bloomfilter_add(bf_cnt2, 42); + bloomfilter_add(bf_cnt, 100); + local bf_merged = bloomfilter_merge(bf_cnt, bf_cnt2); + print bloomfilter_lookup(bf_merged, 42); + print bloomfilter_lookup(bf_merged, 84); + print bloomfilter_lookup(bf_merged, 100); + print bloomfilter_lookup(bf_merged, 168); } function test_counting_bloom_filter() { - local bf = bloomfilter_counting_init(3, 16, 3); + local bf = bloomfilter_counting_init(3, 32, 3); bloomfilter_add(bf, "foo"); print bloomfilter_lookup(bf, "foo"); # 1 bloomfilter_add(bf, "foo"); @@ -49,10 +59,21 @@ function test_counting_bloom_filter() bloomfilter_add(bf, "foo"); print bloomfilter_lookup(bf, "foo"); # still 3 + bloomfilter_add(bf, "bar"); bloomfilter_add(bf, "bar"); print bloomfilter_lookup(bf, "bar"); # 2 print bloomfilter_lookup(bf, "foo"); # still 3 + + # Merging + local bf2 = bloomfilter_counting_init(3, 32, 3); + bloomfilter_add(bf2, "baz"); + bloomfilter_add(bf2, "baz"); + bloomfilter_add(bf2, "bar"); + local bf_merged = bloomfilter_merge(bf, bf2); + print bloomfilter_lookup(bf_merged, "foo"); + print bloomfilter_lookup(bf_merged, "bar"); + print bloomfilter_lookup(bf_merged, "baz"); } event bro_init()