diff --git a/src/BloomFilter.cc b/src/BloomFilter.cc index 6873815f69..4787bef0f0 100644 --- a/src/BloomFilter.cc +++ b/src/BloomFilter.cc @@ -1,23 +1,130 @@ #include "BloomFilter.h" +#include +#include "Serializer.h" + +// Backport C++11's std::round(). +namespace { +template +T round(double x) { return (x > 0.0) ? (x + 0.5) : (x - 0.5); } +} // namespace + + +IMPLEMENT_SERIAL(CounterVector, SER_COUNTERVECTOR) + +bool CounterVector::DoSerialize(SerialInfo* info) const + { + DO_SERIALIZE(SER_COUNTERVECTOR, SerialObj); + if ( ! SERIALIZE(&bits_) ) + return false; + return SERIALIZE(static_cast(width_)); + } + +bool CounterVector::DoUnserialize(UnserialInfo* info) + { + DO_UNSERIALIZE(SerialObj); + return false; + // TODO: Ask Robin how to unserialize non-pointer members. + //if ( ! UNSERIALIZE(&bits_) ) + // return false; + uint64 width; + if ( ! UNSERIALIZE(&width) ) + return false; + width_ = static_cast(width); + return true; + } + + HashPolicy::HashVector DefaultHashing::Hash(const void* x, size_t n) const { - HashVector h(k(), 0); + HashVector h(K(), 0); for ( size_t i = 0; i < h.size(); ++i ) h[i] = hashers_[i](x, n); return h; } + HashPolicy::HashVector DoubleHashing::Hash(const void* x, size_t n) const { HashType h1 = hasher1_(x); HashType h2 = hasher2_(x); - HashVector h(k(), 0); + HashVector h(K(), 0); for ( size_t i = 0; i < h.size(); ++i ) h[i] = h1 + i * h2; return h; } +bool BloomFilter::Serialize(SerialInfo* info) const + { + return SerialObj::Serialize(info); + } + +BloomFilter* BloomFilter::Unserialize(UnserialInfo* info) + { + return reinterpret_cast( + SerialObj::Unserialize(info, SER_BLOOMFILTER)); + } + +// FIXME: should abstract base classes also have IMPLEMENT_SERIAL? +//IMPLEMENT_SERIAL(BloomFilter, SER_BLOOMFILTER) + +bool BloomFilter::DoSerialize(SerialInfo* info) const + { + DO_SERIALIZE(SER_BLOOMFILTER, SerialObj); + // TODO: Make the hash policy serializable. + //if ( ! SERIALIZE(hash_) ) + // return false; + return SERIALIZE(static_cast(elements_)); + } + +bool BloomFilter::DoUnserialize(UnserialInfo* info) + { + DO_UNSERIALIZE(SerialObj); + // TODO: Make the hash policy serializable. + //if ( ! hash_ = HashPolicy::Unserialize(info) ) + // return false; + uint64 elements; + if ( UNSERIALIZE(&elements) ) + return false; + elements_ = static_cast(elements); + return true; + } + +size_t BasicBloomFilter::Cells(double fp, size_t capacity) + { + double ln2 = std::log(2); + return std::ceil(-(capacity * std::log(fp) / ln2 / ln2)); + } + +size_t BasicBloomFilter::K(size_t cells, size_t capacity) + { + double frac = static_cast(cells) / static_cast(capacity); + return round(frac * std::log(2)); + } + +BasicBloomFilter::BasicBloomFilter(size_t cells, HashPolicy* hash) + : BloomFilter(hash), bits_(cells) + { + } + +IMPLEMENT_SERIAL(BasicBloomFilter, SER_BASICBLOOMFILTER) + +bool BasicBloomFilter::DoSerialize(SerialInfo* info) const + { + DO_SERIALIZE(SER_BASICBLOOMFILTER, BloomFilter); + // TODO: Make the hash policy serializable. + //if ( ! SERIALIZE(&bits_) ) + // return false; + return true; + } + +bool BasicBloomFilter::DoUnserialize(UnserialInfo* info) + { + DO_UNSERIALIZE(BloomFilter); + // TODO: Non-pointer member deserialization? + return true; + } + void BasicBloomFilter::AddImpl(const HashPolicy::HashVector& h) { for ( size_t i = 0; i < h.size(); ++i ) @@ -31,3 +138,23 @@ size_t BasicBloomFilter::CountImpl(const HashPolicy::HashVector& h) const return 0; return 1; } + + +void CountingBloomFilter::AddImpl(const HashPolicy::HashVector& h) + { + for ( size_t i = 0; i < h.size(); ++i ) + cells_.Increment(h[i] % h.size(), 1); + } + +size_t CountingBloomFilter::CountImpl(const HashPolicy::HashVector& h) const + { + CounterVector::size_type min = + std::numeric_limits::max(); + for ( size_t i = 0; i < h.size(); ++i ) + { + CounterVector::size_type cnt = cells_.Count(h[i] % h.size()); + if ( cnt < min ) + min = cnt; + } + return min; + } diff --git a/src/BloomFilter.h b/src/BloomFilter.h index dca4eff2bd..82948f30ec 100644 --- a/src/BloomFilter.h +++ b/src/BloomFilter.h @@ -65,7 +65,7 @@ public: protected: DECLARE_SERIAL(CounterVector); - CounterVector(); + CounterVector() { } private: BitVector bits_; @@ -82,7 +82,7 @@ public: typedef std::vector HashVector; virtual ~HashPolicy() { } - size_t k() const { return k_; } + size_t K() const { return k_; } virtual HashVector Hash(const void* x, size_t n) const = 0; protected: @@ -130,7 +130,7 @@ private: }; /** - * The *double-hashing* policy. Uses a linear combination of 2 hash functions. + * The *double-hashing* policy. Uses a linear combination of two hash functions. */ class DoubleHashing : public HashPolicy { public: @@ -185,25 +185,20 @@ public: return elements_; } -protected: - /** - * Default-constructs a Bloom filter. - */ - BloomFilter(); + bool Serialize(SerialInfo* info) const; + static BloomFilter* Unserialize(UnserialInfo* info); - /** - * Constructs a BloomFilter. - * @param hash The hashing policy. - */ - BloomFilter(HashPolicy* hash); +protected: + DECLARE_SERIAL(BloomFilter); + + BloomFilter() { }; + BloomFilter(HashPolicy* hash) : hash_(hash) { } virtual void AddImpl(const HashPolicy::HashVector& hashes) = 0; - virtual size_t CountImpl(const HashPolicy::HashVector& hashes) const = 0; private: - HashPolicy* hash_; // Owned by *this. - + HashPolicy* hash_; size_t elements_; }; @@ -212,12 +207,17 @@ private: */ class BasicBloomFilter : public BloomFilter { public: - BasicBloomFilter(); - BasicBloomFilter(HashPolicy* hash); + static size_t Cells(double fp, size_t capacity); + static size_t K(size_t cells, size_t capacity); + + BasicBloomFilter(size_t cells, HashPolicy* hash); protected: - virtual void AddImpl(const HashPolicy::HashVector& h); + DECLARE_SERIAL(BasicBloomFilter); + BasicBloomFilter() { } + + virtual void AddImpl(const HashPolicy::HashVector& h); virtual size_t CountImpl(const HashPolicy::HashVector& h) const; private: @@ -232,10 +232,11 @@ public: CountingBloomFilter(unsigned width, HashPolicy* hash); protected: + DECLARE_SERIAL(CountingBloomFilter); + CountingBloomFilter(); virtual void AddImpl(const HashPolicy::HashVector& h); - virtual size_t CountImpl(const HashPolicy::HashVector& h) const; private: diff --git a/src/NetVar.cc b/src/NetVar.cc index 3a23e4c9fa..d8c2192af7 100644 --- a/src/NetVar.cc +++ b/src/NetVar.cc @@ -244,6 +244,7 @@ OpaqueType* md5_type; OpaqueType* sha1_type; OpaqueType* sha256_type; OpaqueType* entropy_type; +OpaqueType* bloomfilter_type; #include "const.bif.netvar_def" #include "types.bif.netvar_def" @@ -310,6 +311,7 @@ void init_general_global_var() sha1_type = new OpaqueType("sha1"); sha256_type = new OpaqueType("sha256"); entropy_type = new OpaqueType("entropy"); + bloomfilter_type = new OpaqueType("bloomfilter"); } void init_net_var() diff --git a/src/OpaqueVal.cc b/src/OpaqueVal.cc index 19346e52f2..a5fb65f53b 100644 --- a/src/OpaqueVal.cc +++ b/src/OpaqueVal.cc @@ -1,4 +1,6 @@ #include "OpaqueVal.h" + +#include "BloomFilter.h" #include "NetVar.h" #include "Reporter.h" #include "Serializer.h" @@ -515,3 +517,24 @@ bool EntropyVal::DoUnserialize(UnserialInfo* info) return true; } + +BloomFilterVal::BloomFilterVal(OpaqueType* t) : OpaqueVal(t) + { + } + +IMPLEMENT_SERIAL(BloomFilterVal, SER_BLOOMFILTER_VAL); + +bool BloomFilterVal::DoSerialize(SerialInfo* info) const + { + DO_SERIALIZE(SER_BLOOMFILTER_VAL, OpaqueVal); + // TODO: implement. + return true; + } + +bool BloomFilterVal::DoUnserialize(UnserialInfo* info) + { + DO_UNSERIALIZE(OpaqueVal); + // TODO: implement. + return true; + } + diff --git a/src/OpaqueVal.h b/src/OpaqueVal.h index 78fa5da5e9..1c9c0361cc 100644 --- a/src/OpaqueVal.h +++ b/src/OpaqueVal.h @@ -7,6 +7,8 @@ #include "Val.h" #include "digest.h" +class BloomFilter; + class HashVal : public OpaqueVal { public: virtual bool IsValid() const; @@ -107,4 +109,18 @@ private: RandTest state; }; +class BloomFilterVal : public OpaqueVal { +public: + BloomFilterVal(); + +protected: + friend class Val; + BloomFilterVal(OpaqueType* t); + + DECLARE_SERIAL(BloomFilterVal); + +private: + BloomFilter* bloom_filter_; +}; + #endif diff --git a/src/SerialTypes.h b/src/SerialTypes.h index c9c0c34a33..171113ab6a 100644 --- a/src/SerialTypes.h +++ b/src/SerialTypes.h @@ -50,6 +50,9 @@ SERIAL_IS_BO(CASE, 0x1200) SERIAL_IS(LOCATION, 0x1300) SERIAL_IS(RE_MATCHER, 0x1400) SERIAL_IS(BITVECTOR, 0x1500) +SERIAL_IS(COUNTERVECTOR, 0xa000) +SERIAL_IS(BLOOMFILTER, 0xa100) +SERIAL_IS(BASICBLOOMFILTER, 0xa200) // These are the externally visible types. const SerialType SER_NONE = 0; @@ -105,6 +108,7 @@ SERIAL_VAL(MD5_VAL, 16) SERIAL_VAL(SHA1_VAL, 17) SERIAL_VAL(SHA256_VAL, 18) SERIAL_VAL(ENTROPY_VAL, 19) +SERIAL_VAL(BLOOMFILTER_VAL, 20) #define SERIAL_EXPR(name, val) SERIAL_CONST(name, val, EXPR) SERIAL_EXPR(EXPR, 1) @@ -204,5 +208,8 @@ SERIAL_CONST2(CASE) SERIAL_CONST2(LOCATION) SERIAL_CONST2(RE_MATCHER) SERIAL_CONST2(BITVECTOR) +SERIAL_CONST2(COUNTERVECTOR) +SERIAL_CONST2(BLOOMFILTER) +SERIAL_CONST2(BASICBLOOMFILTER) #endif