diff --git a/src/Type.cc b/src/Type.cc index 531fea4e5b..bd8a0651f7 100644 --- a/src/Type.cc +++ b/src/Type.cc @@ -858,8 +858,16 @@ const char* RecordType::AddFields(type_decl_list* others, attr_list* attr) { if ( ! td->FindAttr(ATTR_DEFAULT) && ! td->FindAttr(ATTR_OPTIONAL) ) + { + delete others; return "extension field must be &optional or have &default"; + } + } + TableVal::SaveParseTimeTableState(this); + + for ( const auto& td : *others ) + { if ( log ) { if ( ! td->attrs ) @@ -875,6 +883,7 @@ const char* RecordType::AddFields(type_decl_list* others, attr_list* attr) num_fields = types->length(); RecordVal::ResizeParseTimeRecords(this); + TableVal::RebuildParseTimeTables(); return 0; } diff --git a/src/Val.cc b/src/Val.cc index aa46f971f3..758461382c 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -14,6 +14,7 @@ #include #include +#include #include "Attr.h" #include "BroString.h" @@ -1310,10 +1311,62 @@ static void table_entry_val_delete_func(void* val) delete tv; } +static void find_nested_record_types(BroType* t, std::set* found) + { + if ( ! t ) + return; + + switch ( t->Tag() ) { + case TYPE_RECORD: + { + auto rt = t->AsRecordType(); + found->emplace(rt); + + for ( auto i = 0; i < rt->NumFields(); ++i ) + find_nested_record_types(rt->FieldDecl(i)->type, found); + } + return; + case TYPE_TABLE: + find_nested_record_types(t->AsTableType()->Indices(), found); + find_nested_record_types(t->AsTableType()->YieldType(), found); + return; + case TYPE_LIST: + { + for ( auto& t : *t->AsTypeList()->Types() ) + find_nested_record_types(t, found); + } + return; + case TYPE_FUNC: + find_nested_record_types(t->AsFuncType()->Args(), found); + find_nested_record_types(t->AsFuncType()->YieldType(), found); + return; + case TYPE_VECTOR: + find_nested_record_types(t->AsVectorType()->YieldType(), found); + return; + case TYPE_TYPE: + find_nested_record_types(t->AsTypeType()->Type(), found); + return; + default: + return; + } + } + TableVal::TableVal(TableType* t, Attributes* a) : Val(t) { Init(t); SetAttrs(a); + + if ( ! is_parsing ) + return; + + for ( const auto& t : *t->IndexTypes() ) + { + std::set found; + find_nested_record_types(t, &found); + + for ( auto rt : found ) + parse_time_table_record_dependencies[rt].emplace_back(NewRef{}, this); + } } void TableVal::Init(TableType* t) @@ -2548,6 +2601,68 @@ HashKey* TableVal::ComputeHash(const Val* index) const return table_hash->ComputeHash(index, 1); } +void TableVal::SaveParseTimeTableState(RecordType* rt) + { + auto it = parse_time_table_record_dependencies.find(rt); + + if ( it == parse_time_table_record_dependencies.end() ) + return; + + auto& table_vals = it->second; + + for ( auto& tv : table_vals ) + parse_time_table_states[tv.get()] = tv->DumpTableState(); + } + +void TableVal::RebuildParseTimeTables() + { + for ( auto& [tv, ptts] : parse_time_table_states ) + tv->RebuildTable(std::move(ptts)); + + parse_time_table_states.clear(); + } + +void TableVal::DoneParsing() + { + parse_time_table_record_dependencies.clear(); + } + +TableVal::ParseTimeTableState TableVal::DumpTableState() + { + const PDict* tbl = AsTable(); + IterCookie* cookie = tbl->InitForIteration(); + + HashKey* key; + TableEntryVal* val; + + ParseTimeTableState rval; + + while ( (val = tbl->NextEntry(key, cookie)) ) + { + rval.emplace_back(IntrusivePtr{AdoptRef{}, RecoverIndex(key)}, + IntrusivePtr{NewRef{}, val->Value()}); + + delete key; + } + + RemoveAll(); + return rval; + } + +void TableVal::RebuildTable(ParseTimeTableState ptts) + { + delete table_hash; + table_hash = new CompositeHash(IntrusivePtr(NewRef{}, + table_type->Indices())); + + for ( auto& [key, val] : ptts ) + Assign(key.get(), val.release()); + } + +TableVal::ParseTimeTableStates TableVal::parse_time_table_states; + +TableVal::TableRecordDependencies TableVal::parse_time_table_record_dependencies; + RecordVal::RecordTypeValMap RecordVal::parse_time_records; RecordVal::RecordVal(RecordType* t, bool init_fields) : Val(t) diff --git a/src/Val.h b/src/Val.h index 0bb588d310..467c6e542f 100644 --- a/src/Val.h +++ b/src/Val.h @@ -823,7 +823,29 @@ public: notifier::Modifiable* Modifiable() override { return this; } + // Retrieves and saves all table state (key-value pairs) for + // tables whose index type depends on the given RecordType. + static void SaveParseTimeTableState(RecordType* rt); + + // Rebuilds all TableVals whose state was previously saved by + // SaveParseTimeTableState(). This is used to re-recreate the tables + // in the event that a record type gets redefined while parsing. + static void RebuildParseTimeTables(); + + // Clears all state that was used to track TableVals that depending + // on RecordTypes. + static void DoneParsing(); + protected: + + using TableRecordDependencies = std::unordered_map>>; + + using ParseTimeTableState = std::vector, IntrusivePtr>>; + using ParseTimeTableStates = std::unordered_map; + + ParseTimeTableState DumpTableState(); + void RebuildTable(ParseTimeTableState ptts); + void Init(TableType* t); void CheckExpireAttr(attr_tag at); @@ -865,6 +887,9 @@ protected: Expr* change_func = nullptr; // prevent recursion of change functions bool in_change_func = false; + + static TableRecordDependencies parse_time_table_record_dependencies; + static ParseTimeTableStates parse_time_table_states; }; class RecordVal : public Val, public notifier::Modifiable { diff --git a/src/main.cc b/src/main.cc index 09171ccb63..0c4e22e28a 100644 --- a/src/main.cc +++ b/src/main.cc @@ -650,6 +650,7 @@ int main(int argc, char** argv) is_parsing = false; RecordVal::DoneParsing(); + TableVal::DoneParsing(); init_general_global_var(); init_net_var(); diff --git a/testing/btest/Baseline/language.table-record-idx-redef/out b/testing/btest/Baseline/language.table-record-idx-redef/out new file mode 100644 index 0000000000..3f935f9fd8 --- /dev/null +++ b/testing/btest/Baseline/language.table-record-idx-redef/out @@ -0,0 +1,5 @@ +F, T, F +{ +[[r=[rr=101, rrr=], a=37, b=]] = 1, +[[r=[rr=101, rrr=blue pill], a=13, b=28]] = 42 +} diff --git a/testing/btest/language/table-record-idx-redef.zeek b/testing/btest/language/table-record-idx-redef.zeek new file mode 100644 index 0000000000..2fa2732dd4 --- /dev/null +++ b/testing/btest/language/table-record-idx-redef.zeek @@ -0,0 +1,33 @@ +# @TEST-EXEC: zeek -b %INPUT >out +# @TEST-EXEC: btest-diff out + +type recrec: record { + rr: count &default = 101; +}; + +type myrec: record { + r: recrec &default=recrec(); + a: count &default=13; +}; + +global mr = myrec($a = 37); +global active: table[myrec] of count = table([mr] = 1); + +redef record myrec += { + b: count &default=28; +}; + +redef record recrec += { + rrr: string &default="blue pill"; +}; + +global check1: bool = myrec() in active; +global check2: bool = mr in active; +global check3: bool = myrec($a=37, $b=0) in active; + +event zeek_init() + { + print check1, check2, check3; + active[myrec()] = 42; + print active; + }