diff --git a/scripts/base/frameworks/sumstats/plugins/topk.bro b/scripts/base/frameworks/sumstats/plugins/topk.bro
index f64e9fb18d..6107a252ae 100644
--- a/scripts/base/frameworks/sumstats/plugins/topk.bro
+++ b/scripts/base/frameworks/sumstats/plugins/topk.bro
@@ -22,3 +22,18 @@ hook observe_hook(r: Reducer, val: double, obs: Observation, rv: ResultVal)
 
 	}
 
+hook compose_resultvals_hook(result: ResultVal, rv1: ResultVal, rv2: ResultVal)
+	{
+	# NOTE(review): topk is presumably only set on results whose reducer
+	# applied TOPK, so test each side before dereferencing it.
+	if ( rv1?$topk || rv2?$topk )
+		{
+		result$topk = topk_init(500);
+
+		if ( rv1?$topk )
+			topk_merge(result$topk, rv1$topk);
+
+		if ( rv2?$topk )
+			topk_merge(result$topk, rv2$topk);
+		}
+	}
diff --git a/src/Topk.cc b/src/Topk.cc
index a31f49adf4..8ad2113235 100644
--- a/src/Topk.cc
+++ b/src/Topk.cc
@@ -71,6 +71,121 @@ TopkVal::~TopkVal()
 	type = 0;
 	}
 
+// Merges the contents of another top-k summary into this one.
+//
+// Every element tracked by value is added here together with its
+// count and epsilon. Afterwards elements are removed from the front
+// bucket onwards until no more than size elements remain. A summary
+// that has no type yet adopts the type of value; merging summaries of
+// differing types reports an error and changes nothing.
+void TopkVal::Merge(const TopkVal* value)
+	{
+	if ( value == this )
+		{
+		// The loop below iterates over value's buckets while
+		// IncrementCounter() inserts into ours; for a self-merge those
+		// are the same lists, so refuse instead.
+		reporter->Error("Tried to merge a top-k data structure into itself. Aborted");
+		return;
+		}
+
+	if ( value->type == 0 )
+		{
+		// The source has no type yet, i.e. it is still empty (see the
+		// assert in the branch below): nothing to merge or compare.
+		assert(value->numElements == 0);
+		return;
+		}
+
+	if ( type == 0 )
+		{
+		assert(numElements == 0);
+		type = value->type->Ref();
+		}
+	else
+		if ( !same_type(type, value->type) )
+			{
+			reporter->Error("Tried to merge top-k elements of differing types. Aborted");
+			return;
+			}
+
+	std::list<Bucket*>::const_iterator it = value->buckets.begin();
+	while ( it != value->buckets.end() )
+		{
+		Bucket* b = *it;
+		uint64_t currcount = b->count;
+		std::list<Element*>::const_iterator eit = b->elements.begin();
+
+		while ( eit != b->elements.end() )
+			{
+			Element* e = *eit;
+			// lookup if we already know this one...
+			HashKey* key = GetHash(e->value);
+			Element* olde = (Element*) elementDict->Lookup(key);
+
+			if ( olde == 0 )
+				{
+				olde = new Element();
+				olde->epsilon=0;
+				olde->value = e->value->Ref();
+				// insert at bucket position 0
+				if ( buckets.size() > 0 )
+					{
+					assert (buckets.front()-> count > 0 );
+					}
+
+				Bucket* newbucket = new Bucket();
+				newbucket->count = 0;
+				newbucket->bucketPos = buckets.insert(buckets.begin(), newbucket);
+
+				olde->parent = newbucket;
+				newbucket->elements.insert(newbucket->elements.end(), olde);
+
+				elementDict->Insert(key, olde);
+				numElements++;
+				}
+
+			// now that we are sure that the old element is present - increment epsilon
+			olde->epsilon += e->epsilon;
+			// and increment position...
+			IncrementCounter(olde, currcount);
+			delete key;
+
+			eit++;
+			}
+
+		it++;
+		}
+
+	// now we have added everything. And our top-k table could be too big.
+	// prune everything...
+
+	assert(size > 0);
+	while ( numElements > size )
+		{
+		assert(buckets.size() > 0 );
+		Bucket* b = buckets.front();
+		assert(b->elements.size() > 0);
+
+		Element* e = b->elements.front();
+		HashKey* key = GetHash(e->value);
+		elementDict->RemoveEntry(key);
+		// Release the lookup key like the merge loop above does;
+		// otherwise every pruned element leaks one HashKey.
+		delete key;
+		delete e;
+
+		b->elements.pop_front();
+
+		if ( b->elements.size() == 0 )
+			{
+			delete b;
+			buckets.pop_front();
+			}
+
+		numElements--;
+		}
+	}
 
 bool TopkVal::DoSerialize(SerialInfo* info) const
 	{
@@ -318,7 +433,9 @@ void TopkVal::Encountered(Val* encountered)
 
 	}
 
-void TopkVal::IncrementCounter(Element* e)
+// Raises the counter of e by count and re-buckets it accordingly;
+// callers that omit count advance it by one.
+void TopkVal::IncrementCounter(Element* e, uint64_t count)
 	{
 	Bucket* currBucket = e->parent;
 	uint64 currcount = currBucket->count;
@@ -330,11 +447,11 @@ void TopkVal::IncrementCounter(Element* e)
 
 	bucketIter++;
 
-	if ( bucketIter != buckets.end() )
-		{
-		if ( (*bucketIter)->count == currcount+1 )
-			nextBucket = *bucketIter;
-		}
+	while ( bucketIter != buckets.end() && (*bucketIter)->count < currcount+count )
+		bucketIter++;
+
+	if ( bucketIter != buckets.end() && (*bucketIter)->count == currcount+count )
+		nextBucket = *bucketIter;
 
 	if ( nextBucket == 0 )
 		{
@@ -342,7 +459,7 @@ void TopkVal::IncrementCounter(Element* e)
 		// create it...
 
 		Bucket* b = new Bucket();
-		b->count = currcount+1;
+		b->count = currcount+count;
 		std::list<Bucket*>::iterator nextBucketPos = buckets.insert(bucketIter, b);
 		b->bucketPos = nextBucketPos; // and give it the iterator we know now.
 
diff --git a/src/Topk.h b/src/Topk.h
index 0e38319380..30e87f7a99 100644
--- a/src/Topk.h
+++ b/src/Topk.h
@@ -40,12 +40,14 @@ public:
 	VectorVal* getTopK(int k) const; // returns vector
 	uint64_t getCount(Val* value) const;
 	uint64_t getEpsilon(Val* value) const;
+	// Merges the elements of value into this summary.
+	void Merge(const TopkVal* value);
 
 protected:
 	TopkVal(); // for deserialize
 
 private:
-	void IncrementCounter(Element* e);
+	void IncrementCounter(Element* e, uint64_t count = 1);
 	HashKey* GetHash(Val*) const; // this probably should go somewhere else.
 
 	BroType* type;
diff --git a/src/bro.bif b/src/bro.bif
index e8e78c7872..b6f101c025 100644
--- a/src/bro.bif
+++ b/src/bro.bif
@@ -5684,3 +5684,21 @@ function topk_epsilon%(handle: opaque of topk, value: any%): count
 	return new Val(h->getEpsilon(value), TYPE_COUNT);
 	%}
 
+## Merges the second top-k data structure into the first one.
+##
+## handle1: the top-k instance that receives the merged elements.
+##
+## handle2: the top-k instance that is merged into *handle1*.
+function topk_merge%(handle1: opaque of topk, handle2: opaque of topk%): any
+	%{
+	assert(handle1);
+	assert(handle2);
+
+	Topk::TopkVal* h1 = (Topk::TopkVal*) handle1;
+	Topk::TopkVal* h2 = (Topk::TopkVal*) handle2;
+
+	h1->Merge(h2);
+
+	return 0;
+	%}
+
diff --git a/testing/btest/Baseline/bifs.topk/.stderr b/testing/btest/Baseline/bifs.topk/.stderr
index f57e35ca51..f2bd316fd8 100644
--- a/testing/btest/Baseline/bifs.topk/.stderr
+++ b/testing/btest/Baseline/bifs.topk/.stderr
@@ -4,3 +4,7 @@ error: getCount for element that is not in top-k
 error: getEpsilon for element that is not in top-k
 error: getCount for element that is not in top-k
 error: getEpsilon for element that is not in top-k
+error: getCount for element that is not in top-k
+error: getEpsilon for element that is not in top-k
+error: getCount for element that is not in top-k
+error: getEpsilon for element that is not in top-k
diff --git a/testing/btest/Baseline/bifs.topk/out b/testing/btest/Baseline/bifs.topk/out
index 2116a30a12..8db55eeca8 100644
--- a/testing/btest/Baseline/bifs.topk/out
+++ b/testing/btest/Baseline/bifs.topk/out
@@ -35,3 +35,23 @@
 5
 4
 [c, e, d]
+6
+0
+5
+0
+4
+0
+[c, e]
+6
+0
+5
+0
+0
+0
+[c, e]
+12
+0
+10
+0
+0
+0
diff --git a/testing/btest/bifs/topk.bro b/testing/btest/bifs/topk.bro
index 9d936ce2f4..92a68999cc 100644
--- a/testing/btest/bifs/topk.bro
+++ b/testing/btest/bifs/topk.bro
@@ -87,6 +87,33 @@ event bro_init()
 	topk_add(k1, "f");
 	s = topk_get_top(k1, 3);
 	print s;
+	print topk_count(k1, "c");
+	print topk_epsilon(k1, "c");
+	print topk_count(k1, "e");
+	print topk_epsilon(k1, "e");
+	print topk_count(k1, "d");
+	print topk_epsilon(k1, "d");
 
 	local k3 = topk_init(2);
+	topk_merge(k3, k1);
+
+	s = topk_get_top(k3, 3);
+	print s;
+	print topk_count(k3, "c");
+	print topk_epsilon(k3, "c");
+	print topk_count(k3, "e");
+	print topk_epsilon(k3, "e");
+	print topk_count(k3, "d");
+	print topk_epsilon(k3, "d");
+
+	topk_merge(k3, k1);
+
+	s = topk_get_top(k3, 3);
+	print s;
+	print topk_count(k3, "c");
+	print topk_epsilon(k3, "c");
+	print topk_count(k3, "e");
+	print topk_epsilon(k3, "e");
+	print topk_count(k3, "d");
+	print topk_epsilon(k3, "d");
 }