diff --git a/src/Attr.cc b/src/Attr.cc index da28845fe9..7e80736a31 100644 --- a/src/Attr.cc +++ b/src/Attr.cc @@ -166,6 +166,21 @@ void Attr::AddTag(ODesc* d) const d->Add(attr_name(Tag())); } +detail::TraversalCode Attr::Traverse(detail::TraversalCallback* cb) const + { + auto tc = cb->PreAttr(this); + HANDLE_TC_ATTR_PRE(tc); + + if ( expr ) + { + auto tc = expr->Traverse(cb); + HANDLE_TC_ATTR_PRE(tc); + } + + tc = cb->PostAttr(this); + HANDLE_TC_ATTR_POST(tc); + } + Attributes::Attributes(TypePtr t, bool arg_in_record, bool is_global) : Attributes(std::vector{}, std::move(t), arg_in_record, is_global) { @@ -756,4 +771,19 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r return false; } +detail::TraversalCode Attributes::Traverse(detail::TraversalCallback* cb) const + { + auto tc = cb->PreAttrs(this); + HANDLE_TC_ATTRS_PRE(tc); + + for ( const auto& a : attrs ) + { + tc = a->Traverse(cb); + HANDLE_TC_ATTRS_PRE(tc); + } + + tc = cb->PostAttrs(this); + HANDLE_TC_ATTRS_POST(tc); + } + } diff --git a/src/Attr.h b/src/Attr.h index 6bc37ea726..9c1ec16f1b 100644 --- a/src/Attr.h +++ b/src/Attr.h @@ -7,6 +7,7 @@ #include "zeek/IntrusivePtr.h" #include "zeek/Obj.h" +#include "zeek/Traverse.h" #include "zeek/ZeekList.h" // Note that there are two kinds of attributes: the kind (here) which @@ -98,6 +99,8 @@ public: return true; } + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; + protected: void AddTag(ODesc* d) const; @@ -129,6 +132,8 @@ public: bool operator==(const Attributes& other) const; + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; + protected: void CheckAttr(Attr* attr); diff --git a/src/Expr.cc b/src/Expr.cc index 052f276863..4cc3f9d572 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -3716,6 +3716,24 @@ TableConstructorExpr::TableConstructorExpr(ListExprPtr constructor_list, } } +TraversalCode TableConstructorExpr::Traverse(TraversalCallback* cb) const + { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = op->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + if ( attrs ) + { + tc = attrs->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + } + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); + } + ValPtr TableConstructorExpr::Eval(Frame* f) const { if ( IsError() ) @@ -3834,6 +3852,24 @@ SetConstructorExpr::SetConstructorExpr(ListExprPtr constructor_list, } } +TraversalCode SetConstructorExpr::Traverse(TraversalCallback* cb) const + { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = op->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + if ( attrs ) + { + tc = attrs->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + } + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); + } + ValPtr SetConstructorExpr::Eval(Frame* f) const { if ( IsError() ) @@ -4700,15 +4736,15 @@ LambdaExpr::LambdaExpr(std::unique_ptr arg_ing, IDPList ar } // Install that in the global_scope - auto id = install_ID(my_name.c_str(), current_module.c_str(), true, false); + lambda_id = install_ID(my_name.c_str(), current_module.c_str(), true, false); // Update lamb's name dummy_func->SetName(my_name.c_str()); auto v = make_intrusive(std::move(dummy_func)); - id->SetVal(std::move(v)); - id->SetType(ingredients->id->GetType()); - id->SetConst(); + lambda_id->SetVal(std::move(v)); + lambda_id->SetType(ingredients->id->GetType()); + lambda_id->SetConst(); } void LambdaExpr::CheckCaptures(StmtPtr when_parent) @@ -4823,8 +4859,11 @@ TraversalCode LambdaExpr::Traverse(TraversalCallback* cb) const TraversalCode tc = cb->PreExpr(this); HANDLE_TC_EXPR_PRE(tc); + tc = lambda_id->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + tc = ingredients->body->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); + HANDLE_TC_EXPR_PRE(tc); tc = cb->PostExpr(this); HANDLE_TC_EXPR_POST(tc); @@ -4895,6 +4934,22 @@ TraversalCode EventExpr::Traverse(TraversalCallback* cb) const TraversalCode tc = cb->PreExpr(this); HANDLE_TC_EXPR_PRE(tc); + auto& f = handler->GetFunc(); + if ( f ) + { + // We don't traverse the function, because that can lead + // to infinite traversals. We do, however, see if we can + // locate the corresponding identifier, and traverse that. + + auto& id = lookup_ID(f->Name(), GLOBAL_MODULE_NAME, false, false, false); + + if ( id ) + { + tc = id->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + } + } + tc = args->Traverse(cb); HANDLE_TC_EXPR_PRE(tc); diff --git a/src/Expr.h b/src/Expr.h index d000d88538..6341a659ad 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -1208,6 +1208,8 @@ public: ValPtr Eval(Frame* f) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; + // Optimization-related: ExprPtr Duplicate() override; @@ -1232,6 +1234,8 @@ public: ValPtr Eval(Frame* f) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; + // Optimization-related: ExprPtr Duplicate() override; @@ -1479,6 +1483,7 @@ private: void CheckCaptures(StmtPtr when_parent); std::unique_ptr ingredients; + IDPtr lambda_id; IDPList outer_ids; bool capture_by_ref; // if true, use deprecated reference semantics diff --git a/src/Stmt.cc b/src/Stmt.cc index 0c1265fbfb..a17d59b2de 100644 --- a/src/Stmt.cc +++ b/src/Stmt.cc @@ -2162,22 +2162,27 @@ TraversalCode WhenStmt::Traverse(TraversalCallback* cb) const TraversalCode tc = cb->PreStmt(this); HANDLE_TC_STMT_PRE(tc); - tc = wi->Cond()->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); + auto wl = wi->Lambda(); - tc = wi->WhenBody()->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); - - if ( wi->TimeoutExpr() ) + if ( wl ) { - tc = wi->TimeoutExpr()->Traverse(cb); + tc = wl->Traverse(cb); HANDLE_TC_STMT_PRE(tc); } - if ( wi->TimeoutStmt() ) + else { - tc = wi->TimeoutStmt()->Traverse(cb); + tc = wi->Cond()->Traverse(cb); HANDLE_TC_STMT_PRE(tc); + + tc = wi->WhenBody()->Traverse(cb); + HANDLE_TC_STMT_PRE(tc); + + if ( wi->TimeoutStmt() ) + { + tc = wi->TimeoutStmt()->Traverse(cb); + HANDLE_TC_STMT_PRE(tc); + } } tc = cb->PostStmt(this); diff --git a/src/Traverse.h b/src/Traverse.h index 6d359c177d..0e49d8b000 100644 --- a/src/Traverse.h +++ b/src/Traverse.h @@ -9,6 +9,7 @@ namespace zeek { class Func; +class Type; namespace detail { @@ -16,6 +17,8 @@ namespace detail class Stmt; class Expr; class ID; +class Attributes; +class Attr; class TraversalCallback { @@ -41,6 +44,20 @@ public: virtual TraversalCode PreDecl(const ID*) { return TC_CONTINUE; } virtual TraversalCode PostDecl(const ID*) { return TC_CONTINUE; } + // A caution regarding using the next two: when traversing types, + // there's a possibility of encountering a (directly or indirectly) + // recursive record. So you'll need some way of avoiding that, + // such as remembering which types have already been traversed + // and skipping via TC_ABORTSTMT when seen again. + virtual TraversalCode PreType(const Type*) { return TC_CONTINUE; } + virtual TraversalCode PostType(const Type*) { return TC_CONTINUE; } + + virtual TraversalCode PreAttrs(const Attributes*) { return TC_CONTINUE; } + virtual TraversalCode PostAttrs(const Attributes*) { return TC_CONTINUE; } + + virtual TraversalCode PreAttr(const Attr*) { return TC_CONTINUE; } + virtual TraversalCode PostAttr(const Attr*) { return TC_CONTINUE; } + ScopePtr current_scope; }; diff --git a/src/TraverseTypes.h b/src/TraverseTypes.h index 98d1e1b66f..f1b0ed69be 100644 --- a/src/TraverseTypes.h +++ b/src/TraverseTypes.h @@ -34,14 +34,16 @@ enum TraversalCode return (code); \ } -#define HANDLE_TC_EXPR_PRE(code) \ - { \ - if ( (code) == zeek::detail::TC_ABORTALL ) \ - return (code); \ - else if ( (code) == zeek::detail::TC_ABORTSTMT ) \ - return zeek::detail::TC_CONTINUE; \ - } - +#define HANDLE_TC_EXPR_PRE(code) HANDLE_TC_STMT_PRE(code) #define HANDLE_TC_EXPR_POST(code) return (code); +#define HANDLE_TC_TYPE_PRE(code) HANDLE_TC_STMT_PRE(code) +#define HANDLE_TC_TYPE_POST(code) return (code); + +#define HANDLE_TC_ATTRS_PRE(code) HANDLE_TC_STMT_PRE(code) +#define HANDLE_TC_ATTRS_POST(code) return (code); + +#define HANDLE_TC_ATTR_PRE(code) HANDLE_TC_STMT_PRE(code) +#define HANDLE_TC_ATTR_POST(code) return (code); + } // namespace zeek::detail diff --git a/src/Type.cc b/src/Type.cc index 70128b94f9..0196644288 100644 --- a/src/Type.cc +++ b/src/Type.cc @@ -274,6 +274,15 @@ unsigned int Type::MemoryAllocation() const return padded_sizeof(*this); } +detail::TraversalCode Type::Traverse(detail::TraversalCallback* cb) const + { + auto tc = cb->PreType(this); + HANDLE_TC_TYPE_PRE(tc); + + tc = cb->PostType(this); + HANDLE_TC_TYPE_POST(tc); + } + bool TypeList::AllMatch(const Type* t, bool is_init) const { for ( const auto& type : types ) @@ -340,6 +349,21 @@ unsigned int TypeList::MemoryAllocation() const #pragma GCC diagnostic pop } +detail::TraversalCode TypeList::Traverse(detail::TraversalCallback* cb) const + { + auto tc = cb->PreType(this); + HANDLE_TC_TYPE_PRE(tc); + + for ( const auto& type : types ) + { + tc = type->Traverse(cb); + HANDLE_TC_TYPE_PRE(tc); + } + + tc = cb->PostType(this); + HANDLE_TC_TYPE_POST(tc); + } + int IndexType::MatchesIndex(detail::ListExpr* const index) const { // If we have a type indexed by subnets, addresses are ok. @@ -435,6 +459,27 @@ bool IndexType::IsSubNetIndex() const return false; } +detail::TraversalCode IndexType::Traverse(detail::TraversalCallback* cb) const + { + auto tc = cb->PreType(this); + HANDLE_TC_TYPE_PRE(tc); + + for ( const auto& ind : GetIndexTypes() ) + { + tc = ind->Traverse(cb); + HANDLE_TC_TYPE_PRE(tc); + } + + if ( yield_type ) + { + tc = yield_type->Traverse(cb); + HANDLE_TC_TYPE_PRE(tc); + } + + tc = cb->PostType(this); + HANDLE_TC_TYPE_POST(tc); + } + static bool is_supported_index_type(const TypePtr& t, const char** tname) { if ( t->InternalType() != TYPE_INTERNAL_OTHER ) @@ -865,6 +910,36 @@ std::optional FuncType::FindPrototype(const RecordType& arg return {}; } +detail::TraversalCode FuncType::Traverse(detail::TraversalCallback* cb) const + { + auto tc = cb->PreType(this); + HANDLE_TC_TYPE_PRE(tc); + + tc = args->Traverse(cb); + HANDLE_TC_TYPE_PRE(tc); + + if ( yield ) + { + tc = yield->Traverse(cb); + HANDLE_TC_TYPE_PRE(tc); + } + + tc = cb->PostType(this); + HANDLE_TC_TYPE_POST(tc); + } + +detail::TraversalCode TypeType::Traverse(detail::TraversalCallback* cb) const + { + auto tc = cb->PreType(this); + HANDLE_TC_TYPE_PRE(tc); + + tc = type->Traverse(cb); + HANDLE_TC_TYPE_PRE(tc); + + tc = cb->PostType(this); + HANDLE_TC_TYPE_POST(tc); + } + TypeDecl::TypeDecl(const char* i, TypePtr t, detail::AttributesPtr arg_attrs) : type(std::move(t)), attrs(std::move(arg_attrs)), id(i) { @@ -1458,6 +1533,28 @@ string RecordType::GetFieldDeprecationWarning(int field, bool has_check) const return ""; } +detail::TraversalCode RecordType::Traverse(detail::TraversalCallback* cb) const + { + auto tc = cb->PreType(this); + HANDLE_TC_TYPE_PRE(tc); + + if ( types ) + for ( const auto& td : *types ) + { + tc = td->type->Traverse(cb); + HANDLE_TC_TYPE_PRE(tc); + + if ( td->attrs ) + { + tc = td->attrs->Traverse(cb); + HANDLE_TC_TYPE_PRE(tc); + } + } + + tc = cb->PostType(this); + HANDLE_TC_TYPE_POST(tc); + } + FileType::FileType(TypePtr yield_type) : Type(TYPE_FILE), yield(std::move(yield_type)) { } FileType::~FileType() = default; @@ -1476,6 +1573,18 @@ void FileType::DoDescribe(ODesc* d) const } } +detail::TraversalCode FileType::Traverse(detail::TraversalCallback* cb) const + { + auto tc = cb->PreType(this); + HANDLE_TC_TYPE_PRE(tc); + + tc = yield->Traverse(cb); + HANDLE_TC_TYPE_PRE(tc); + + tc = cb->PostType(this); + HANDLE_TC_TYPE_POST(tc); + } + OpaqueType::OpaqueType(const string& arg_name) : Type(TYPE_OPAQUE) { name = arg_name; @@ -1832,6 +1941,18 @@ void VectorType::DescribeReST(ODesc* d, bool roles_only) const d->Add(util::fmt(":zeek:type:`%s`", yield_type->GetName().c_str())); } +detail::TraversalCode VectorType::Traverse(detail::TraversalCallback* cb) const + { + auto tc = cb->PreType(this); + HANDLE_TC_TYPE_PRE(tc); + + tc = yield_type->Traverse(cb); + HANDLE_TC_TYPE_PRE(tc); + + tc = cb->PostType(this); + HANDLE_TC_TYPE_POST(tc); + } + // Returns true if t1 is initialization-compatible with t2 (i.e., if an // initializer with type t1 can be used to initialize a value with type t2), // false otherwise. Assumes that t1's tag is different from t2's. Note diff --git a/src/Type.h b/src/Type.h index e184a66c8a..7e9b3a1d29 100644 --- a/src/Type.h +++ b/src/Type.h @@ -13,6 +13,7 @@ #include "zeek/ID.h" #include "zeek/IntrusivePtr.h" #include "zeek/Obj.h" +#include "zeek/Traverse.h" #include "zeek/ZeekList.h" namespace zeek @@ -258,6 +259,8 @@ public: void SetName(const std::string& arg_name) { name = arg_name; } const std::string& GetName() const { return name; } + virtual detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; + struct TypePtrComparer { bool operator()(const TypePtr& a, const TypePtr& b) const { return a.get() < b.get(); } @@ -353,6 +356,8 @@ public: "GHI-572.")]] unsigned int MemoryAllocation() const override; + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override; + protected: void DoDescribe(ODesc* d) const override; @@ -376,6 +381,8 @@ public: // Returns true if this table is solely indexed by subnet. bool IsSubNetIndex() const; + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override; + protected: IndexType(TypeTag t, TypeListPtr arg_indices, TypePtr arg_yield_type) : Type(t), indices(std::move(arg_indices)), yield_type(std::move(arg_yield_type)) @@ -533,6 +540,8 @@ public: */ void SetExpressionlessReturnOkay(bool is_ok) { expressionless_return_okay = is_ok; } + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override; + protected: friend FuncTypePtr make_intrusive(); @@ -564,6 +573,8 @@ public: template IntrusivePtr GetType() const { return cast_intrusive(type); } + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override; + protected: TypePtr type; }; @@ -698,6 +709,8 @@ public: std::string GetFieldDeprecationWarning(int field, bool has_check) const; + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override; + protected: RecordType() { types = nullptr; } @@ -731,6 +744,8 @@ public: const TypePtr& Yield() const override { return yield; } + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override; + protected: void DoDescribe(ODesc* d) const override; @@ -844,6 +859,8 @@ public: void DescribeReST(ODesc* d, bool roles_only = false) const override; + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override; + protected: void DoDescribe(ODesc* d) const override; diff --git a/testing/btest/Baseline/language.table-type-checking/out b/testing/btest/Baseline/language.table-type-checking/out index 0421c6d9f1..54a8b0534d 100644 --- a/testing/btest/Baseline/language.table-type-checking/out +++ b/testing/btest/Baseline/language.table-type-checking/out @@ -14,4 +14,4 @@ error in port and <...>/table-type-checking.zeek, line 42: type clash (port and error in <...>/table-type-checking.zeek, line 42: inconsistent types in table constructor (table(thousand-two = 1002)) error in <...>/table-type-checking.zeek, line 48: type clash in assignment (lea = table(thousand-three = 1003)) error in count and <...>/table-type-checking.zeek, line 54: arithmetic mixed with non-arithmetic (count and foo) -error in <...>/table-type-checking.zeek, line 4 and <...>/table-type-checking.zeek, line 54: &default value has inconsistent type (MyTable and table()) +error in <...>/table-type-checking.zeek, line 4 and <...>/table-type-checking.zeek, line 54: &default value has inconsistent type (MyTable and table()&default=foo, &optional)