diff --git a/src/Expr.cc b/src/Expr.cc index 32445ccd69..18deec804c 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -2382,6 +2382,18 @@ IndexExpr::IndexExpr(ExprPtr arg_op1, ListExprPtr arg_op2, bool arg_is_slice, bo if ( IsError() ) return; + if ( op1->GetType()->Tag() == TYPE_TABLE ) { // Check for a table[pattern] being indexed by a string + auto table_type = op1->GetType()->AsTableType(); + auto& it = table_type->GetIndexTypes(); + auto& rhs_type = op2->GetType()->AsTypeList()->GetTypes(); + if ( it.size() == 1 && it[0]->Tag() == TYPE_PATTERN && table_type->Yield() && rhs_type.size() == 1 && + rhs_type[0]->Tag() == TYPE_STRING ) { + is_pattern_table = true; + SetType(make_intrusive(op1->GetType()->Yield())); + return; + } + } + int match_type = op1->GetType()->MatchesIndex(op2->AsListExpr()); if ( match_type == DOES_NOT_MATCH_INDEX ) { @@ -2532,7 +2544,12 @@ ValPtr IndexExpr::Fold(Val* v1, Val* v2) const { return index_slice(vect, lv); } break; - case TYPE_TABLE: v = v1->AsTableVal()->FindOrDefault({NewRef{}, v2}); break; + case TYPE_TABLE: + if ( is_pattern_table ) + return v1->AsTableVal()->LookupPattern(v2->AsListVal()->Idx(0)->AsStringVal()); + + v = v1->AsTableVal()->FindOrDefault({NewRef{}, v2}); + break; case TYPE_STRING: return index_string(v1->AsString(), v2->AsListVal()); diff --git a/src/Expr.h b/src/Expr.h index 3bec879dbb..a45a1528b8 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -1012,6 +1012,7 @@ protected: bool is_slice; bool is_inside_when; + bool is_pattern_table = false; }; // The following execute the heart of IndexExpr functionality for diff --git a/src/RE.cc b/src/RE.cc index 67144e9bd4..87b69d29a2 100644 --- a/src/RE.cc +++ b/src/RE.cc @@ -251,6 +251,33 @@ int Specific_RE_Matcher::Match(const u_char* bv, int n) { return 0; } +void Specific_RE_Matcher::MatchDisjunction(const String* s, std::vector& matches) { + auto bv = s->Bytes(); + auto n = s->Len(); + + ASSERT(dfa); + + DFA_State* d = dfa->StartState(); + d = d->Xtion(ecs[SYM_BOL], dfa); + + while ( d ) { + if ( --n < 0 ) + break; + + int ec = ecs[*(bv++)]; + d = d->Xtion(ec, dfa); + } + + if ( d ) + d = d->Xtion(ecs[SYM_EOL], dfa); + + if ( d ) + if ( auto a_set = d->Accept() ) + for ( auto a : *a_set ) + matches.push_back(a); +} + + void Specific_RE_Matcher::Dump(FILE* f) { dfa->Dump(f); } inline void RE_Match_State::AddMatches(const AcceptingSet& as, MatchPos position) { @@ -425,6 +452,23 @@ void RE_Matcher::MakeSingleLine() { bool RE_Matcher::Compile(bool lazy) { return re_anywhere->Compile(lazy) && re_exact->Compile(lazy); } +RE_DisjunctiveMatcher::RE_DisjunctiveMatcher(const std::vector& REs) { + matcher = std::make_unique(detail::MATCH_EXACTLY); + + std::string disjunction; + for ( auto re : REs ) + disjunction += std::string("||") + re->PatternText(); + + matcher->SetPat(disjunction.c_str()); + auto status = matcher->Compile(); + ASSERT(status); +} + +void RE_DisjunctiveMatcher::Match(const String* s, std::vector& matches) { + matches.clear(); + return matcher->MatchDisjunction(s, matches); +} + TEST_SUITE("re_matcher") { TEST_CASE("simple_pattern") { RE_Matcher match("[0-9]+"); diff --git a/src/RE.h b/src/RE.h index f68f3482bb..8d7b28da30 100644 --- a/src/RE.h +++ b/src/RE.h @@ -36,6 +36,7 @@ extern CCL* curr_ccl; extern NFA_Machine* nfa; extern Specific_RE_Matcher* rem; extern const char* RE_parse_input; +extern int RE_accept_num; extern int clower(int); extern void synerr(const char str[]); @@ -104,6 +105,17 @@ public: int Match(const String* s); int Match(const u_char* bv, int n); + // A disjunction is a collection of regular expressions (that under + // the hood are matches as a single RE, not serially) for which + // the match operation returns *all* of the matches. Disjunctions + // are constructed using the internal "||" RE operator, and the + // matches are returned as indices into the position, left-to-right, + // of which REs matched. IMPORTANT: the first RE is numbered 1, not 0. + // + // Note that there's no guarantee regarding the ordering of the + // returned matches if there is more than one. + void MatchDisjunction(const String* s, std::vector& matches); + int LongestMatch(const char* s); int LongestMatch(const String* s); int LongestMatch(const u_char* bv, int n, bool bol = true, bool eol = true); @@ -244,4 +256,17 @@ protected: bool is_single_line = false; }; +class RE_DisjunctiveMatcher final { +public: + // Takes a collection of individual REs and builds a disjunctive + // matcher for the set. + RE_DisjunctiveMatcher(const std::vector& REs); + + // See MatchDisjunction() above. + void Match(const String* s, std::vector& matches); + +private: + std::unique_ptr matcher; +}; + } // namespace zeek diff --git a/src/Val.cc b/src/Val.cc index 397c1d777c..d861d69386 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -1427,6 +1427,93 @@ static void find_nested_record_types(const TypePtr& t, std::set* fo } } +using PatternValPtr = IntrusivePtr; + +// Support class for returning multiple values from a table[pattern] +// when indexed with a string. +class TablePatternMatcher { +public: + TablePatternMatcher(const TableVal* _tbl, TypePtr _yield) : tbl(_tbl) { + vtype = make_intrusive(std::move(_yield)); + } + ~TablePatternMatcher() { Clear(); } + + void Insert(ValPtr pat, ValPtr yield) { Clear(); } + void Remove(ValPtr pat) { Clear(); } + + void Clear() { + delete matcher; + matcher = nullptr; + } + + VectorValPtr Lookup(const StringVal* s); + +private: + void Build(); + + const TableVal* tbl; + VectorTypePtr vtype; + + // If matcher is nil then we know we need to build it. This gives + // us an easy way to cache matchers in the common case that these + // sorts of tables don't change their elements very often (indeed, + // they'll frequently be constructed just once), and also keeps us + // from having to re-build the matcher on every insert/delete in + // the common case that a whole bunch of those are done in a single + // batch. + RE_DisjunctiveMatcher* matcher = nullptr; + + // Maps matcher values to corresponding yields. When building the + // matcher we insert a nil at the head to accommodate how + // disjunctive matchers use numbering starting at 1 rather than 0. + std::vector matcher_yields; +}; + +VectorValPtr TablePatternMatcher::Lookup(const StringVal* s) { + auto results = make_intrusive(vtype); + + if ( ! matcher ) { + if ( tbl->Get()->Length() == 0 ) + return results; + + Build(); + } + + std::vector matches; + matcher->Match(s->AsString(), matches); + + for ( auto m : matches ) + results->Append(matcher_yields[m]); + + return results; +} + +void TablePatternMatcher::Build() { + matcher_yields.clear(); + matcher_yields.push_back(nullptr); + + auto& tbl_dict = *tbl->Get(); + auto& tbl_hash = *tbl->GetTableHash(); + std::vector patterns; + + // We need to hold on to recovered hash key values so they don't + // get lost once a loop iteration goes out of scope. + std::vector hash_key_vals; + + for ( auto& iter : tbl_dict ) { + auto k = iter.GetHashKey(); + auto v = iter.value; + auto vl = tbl_hash.RecoverVals(*k); + + patterns.push_back(vl->AsListVal()->Idx(0)->AsPattern()); + matcher_yields.push_back(v->GetVal()); + + hash_key_vals.push_back(std::move(vl)); + } + + matcher = new RE_DisjunctiveMatcher(patterns); +} + TableVal::TableVal(TableTypePtr t, detail::AttributesPtr a) : Val(t) { bool ordered = (a != nullptr && a->Find(detail::ATTR_ORDERED) != nullptr); Init(std::move(t), ordered); @@ -1460,6 +1547,10 @@ void TableVal::Init(TableTypePtr t, bool ordered) { else subnets = nullptr; + auto& it = table_type->GetIndexTypes(); + if ( it.size() == 1 && it[0]->Tag() == TYPE_PATTERN && table_type->Yield() ) + pattern_matcher = new TablePatternMatcher(this, table_type->Yield()); + table_hash = new detail::CompositeHash(table_type->GetIndices()); if ( ordered ) table_val = new PDict(DictOrder::ORDERED); @@ -1476,6 +1567,7 @@ TableVal::~TableVal() { delete table_hash; delete table_val; delete subnets; + delete pattern_matcher; delete expire_iterator; } @@ -1486,6 +1578,9 @@ void TableVal::RemoveAll() { delete table_val; table_val = new PDict; table_val->SetDeleteFunc(table_entry_val_delete_func); + + if ( pattern_matcher ) + pattern_matcher->Clear(); } int TableVal::Size() const { return table_val->Length(); } @@ -1570,6 +1665,9 @@ bool TableVal::Assign(ValPtr index, ValPtr new_val, bool broker_forward, bool* i return false; } + if ( pattern_matcher ) + pattern_matcher->Insert(index->AsListVal()->Idx(0), new_val); + return Assign(std::move(index), std::move(k), std::move(new_val), broker_forward, iterators_invalidated); } @@ -1925,6 +2023,13 @@ TableValPtr TableVal::LookupSubnetValues(const SubNetVal* search) { return nt; } +VectorValPtr TableVal::LookupPattern(const StringVal* s) { + if ( ! pattern_matcher ) + reporter->InternalError("LookupPattern called on wrong table type"); + + return pattern_matcher->Lookup(s); +} + bool TableVal::UpdateTimestamp(Val* index) { TableEntryVal* v; @@ -2105,8 +2210,14 @@ ValPtr TableVal::Remove(const Val& index, bool broker_forward, bool* iterators_i va = v->GetVal() ? v->GetVal() : IntrusivePtr{NewRef{}, this}; if ( subnets && ! subnets->Remove(&index) ) + // VP: not clear to me this should be an internal warning, + // since Zeek doesn't otherwise complain about removing + // non-existent table elements. reporter->InternalWarning("index not in prefix table"); + if ( pattern_matcher ) + pattern_matcher->Remove(index.AsListVal()->Idx(0)); + delete v; Modified(); diff --git a/src/Val.h b/src/Val.h index 9ddeec842f..e0c67e91f8 100644 --- a/src/Val.h +++ b/src/Val.h @@ -718,6 +718,8 @@ protected: TableVal* table; }; +class TablePatternMatcher; + class TableVal final : public Val, public notifier::detail::Modifiable { public: explicit TableVal(TableTypePtr t, detail::AttributesPtr attrs = nullptr); @@ -863,6 +865,11 @@ public: // Causes an internal error if called for any other kind of table. TableValPtr LookupSubnetValues(const SubNetVal* s); + // For a table[pattern], return a vector of all yields matching + // the given string. + // Causes an internal error if called for any other kind of table. + VectorValPtr LookupPattern(const StringVal* s); + // Sets the timestamp for the given index to network time. // Returns false if index does not exist. bool UpdateTimestamp(Val* index); @@ -1032,6 +1039,7 @@ protected: TableValTimer* timer; RobustDictIterator* expire_iterator; detail::PrefixTable* subnets; + TablePatternMatcher* pattern_matcher = nullptr; ValPtr def_val; detail::ExprPtr change_func; std::string broker_store; diff --git a/src/re-parse.y b/src/re-parse.y index 2d6672df8d..2ee6bec9e2 100644 --- a/src/re-parse.y +++ b/src/re-parse.y @@ -21,6 +21,7 @@ void yyerror(const char msg[]); %} %token TOK_CHAR TOK_NUMBER TOK_CCL TOK_CCE TOK_CASE_INSENSITIVE TOK_SINGLE_LINE +%token TOK_DISJUNCTION %union { int int_val; @@ -32,7 +33,7 @@ void yyerror(const char msg[]); %type TOK_CHAR TOK_NUMBER %type TOK_CCE %type TOK_CCL ccl full_ccl -%type re singleton series string +%type re singleton series string disjunction %destructor { delete $$; } @@ -40,6 +41,9 @@ void yyerror(const char msg[]); flexrule : re { $1->AddAccept(1); zeek::detail::nfa = $1; } + | disjunction + { zeek::detail::nfa = $1; } + | error { return 1; } ; @@ -51,6 +55,18 @@ re : re '|' series { $$ = new zeek::detail::NFA_Machine(new zeek::detail::EpsilonState()); } ; +disjunction : disjunction TOK_DISJUNCTION re + { + $3->AddAccept(++zeek::detail::RE_accept_num); + $$ = zeek::detail::make_alternate($1, $3); + } + | TOK_DISJUNCTION re + { + $2->AddAccept(++zeek::detail::RE_accept_num); + $$ = $2; + } + ; + series : series singleton { $1->AppendMachine($2); $$ = $1; } | singleton diff --git a/src/re-scan.l b/src/re-scan.l index f382393477..7df4665640 100644 --- a/src/re-scan.l +++ b/src/re-scan.l @@ -23,6 +23,7 @@ #include "re-parse.h" const char* zeek::detail::RE_parse_input = nullptr; +int zeek::detail::RE_accept_num = 0; #define RET_CCE(func) \ BEGIN(SC_CCL); \ @@ -143,6 +144,8 @@ CCL_EXPR ("[:"[[:alpha:]]+":]") } } + "||" return TOK_DISJUNCTION; + [|*+?.(){}] return yytext[0]; . yylval.int_val = yytext[0]; return TOK_CHAR; \n return 0; // treat as end of pattern @@ -237,6 +240,7 @@ YY_BUFFER_STATE RE_buf; void RE_set_input(const char* str) { zeek::detail::RE_parse_input = str; + zeek::detail::RE_accept_num = 0; RE_buf = yy_scan_string(str); }