broader support for AST traversal, including Attr and Attributes objects

This commit is contained in:
Vern Paxson 2022-05-04 17:07:18 -07:00 committed by Tim Wojtulewicz
parent 9a2200e60a
commit a0fc8ca5e4
10 changed files with 280 additions and 23 deletions

View file

@ -166,6 +166,21 @@ void Attr::AddTag(ODesc* d) const
d->Add(attr_name(Tag())); d->Add(attr_name(Tag()));
} }
detail::TraversalCode Attr::Traverse(detail::TraversalCallback* cb) const
{
auto tc = cb->PreAttr(this);
HANDLE_TC_ATTR_PRE(tc);
if ( expr )
{
auto tc = expr->Traverse(cb);
HANDLE_TC_ATTR_PRE(tc);
}
tc = cb->PostAttr(this);
HANDLE_TC_ATTR_POST(tc);
}
Attributes::Attributes(TypePtr t, bool arg_in_record, bool is_global) Attributes::Attributes(TypePtr t, bool arg_in_record, bool is_global)
: Attributes(std::vector<AttrPtr>{}, std::move(t), arg_in_record, is_global) : Attributes(std::vector<AttrPtr>{}, std::move(t), arg_in_record, is_global)
{ {
@ -756,4 +771,19 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
return false; return false;
} }
detail::TraversalCode Attributes::Traverse(detail::TraversalCallback* cb) const
{
auto tc = cb->PreAttrs(this);
HANDLE_TC_ATTRS_PRE(tc);
for ( const auto& a : attrs )
{
tc = a->Traverse(cb);
HANDLE_TC_ATTRS_PRE(tc);
}
tc = cb->PostAttrs(this);
HANDLE_TC_ATTRS_POST(tc);
}
} }

View file

@ -7,6 +7,7 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
#include "zeek/Obj.h" #include "zeek/Obj.h"
#include "zeek/Traverse.h"
#include "zeek/ZeekList.h" #include "zeek/ZeekList.h"
// Note that there are two kinds of attributes: the kind (here) which // Note that there are two kinds of attributes: the kind (here) which
@ -98,6 +99,8 @@ public:
return true; return true;
} }
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const;
protected: protected:
void AddTag(ODesc* d) const; void AddTag(ODesc* d) const;
@ -129,6 +132,8 @@ public:
bool operator==(const Attributes& other) const; bool operator==(const Attributes& other) const;
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const;
protected: protected:
void CheckAttr(Attr* attr); void CheckAttr(Attr* attr);

View file

@ -3716,6 +3716,24 @@ TableConstructorExpr::TableConstructorExpr(ListExprPtr constructor_list,
} }
} }
TraversalCode TableConstructorExpr::Traverse(TraversalCallback* cb) const
{
TraversalCode tc = cb->PreExpr(this);
HANDLE_TC_EXPR_PRE(tc);
tc = op->Traverse(cb);
HANDLE_TC_EXPR_PRE(tc);
if ( attrs )
{
tc = attrs->Traverse(cb);
HANDLE_TC_EXPR_PRE(tc);
}
tc = cb->PostExpr(this);
HANDLE_TC_EXPR_POST(tc);
}
ValPtr TableConstructorExpr::Eval(Frame* f) const ValPtr TableConstructorExpr::Eval(Frame* f) const
{ {
if ( IsError() ) if ( IsError() )
@ -3834,6 +3852,24 @@ SetConstructorExpr::SetConstructorExpr(ListExprPtr constructor_list,
} }
} }
TraversalCode SetConstructorExpr::Traverse(TraversalCallback* cb) const
{
TraversalCode tc = cb->PreExpr(this);
HANDLE_TC_EXPR_PRE(tc);
tc = op->Traverse(cb);
HANDLE_TC_EXPR_PRE(tc);
if ( attrs )
{
tc = attrs->Traverse(cb);
HANDLE_TC_EXPR_PRE(tc);
}
tc = cb->PostExpr(this);
HANDLE_TC_EXPR_POST(tc);
}
ValPtr SetConstructorExpr::Eval(Frame* f) const ValPtr SetConstructorExpr::Eval(Frame* f) const
{ {
if ( IsError() ) if ( IsError() )
@ -4700,15 +4736,15 @@ LambdaExpr::LambdaExpr(std::unique_ptr<function_ingredients> arg_ing, IDPList ar
} }
// Install that in the global_scope // Install that in the global_scope
auto id = install_ID(my_name.c_str(), current_module.c_str(), true, false); lambda_id = install_ID(my_name.c_str(), current_module.c_str(), true, false);
// Update lamb's name // Update lamb's name
dummy_func->SetName(my_name.c_str()); dummy_func->SetName(my_name.c_str());
auto v = make_intrusive<FuncVal>(std::move(dummy_func)); auto v = make_intrusive<FuncVal>(std::move(dummy_func));
id->SetVal(std::move(v)); lambda_id->SetVal(std::move(v));
id->SetType(ingredients->id->GetType()); lambda_id->SetType(ingredients->id->GetType());
id->SetConst(); lambda_id->SetConst();
} }
void LambdaExpr::CheckCaptures(StmtPtr when_parent) void LambdaExpr::CheckCaptures(StmtPtr when_parent)
@ -4823,8 +4859,11 @@ TraversalCode LambdaExpr::Traverse(TraversalCallback* cb) const
TraversalCode tc = cb->PreExpr(this); TraversalCode tc = cb->PreExpr(this);
HANDLE_TC_EXPR_PRE(tc); HANDLE_TC_EXPR_PRE(tc);
tc = lambda_id->Traverse(cb);
HANDLE_TC_EXPR_PRE(tc);
tc = ingredients->body->Traverse(cb); tc = ingredients->body->Traverse(cb);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_EXPR_PRE(tc);
tc = cb->PostExpr(this); tc = cb->PostExpr(this);
HANDLE_TC_EXPR_POST(tc); HANDLE_TC_EXPR_POST(tc);
@ -4895,6 +4934,22 @@ TraversalCode EventExpr::Traverse(TraversalCallback* cb) const
TraversalCode tc = cb->PreExpr(this); TraversalCode tc = cb->PreExpr(this);
HANDLE_TC_EXPR_PRE(tc); HANDLE_TC_EXPR_PRE(tc);
auto& f = handler->GetFunc();
if ( f )
{
// We don't traverse the function, because that can lead
// to infinite traversals. We do, however, see if we can
// locate the corresponding identifier, and traverse that.
auto& id = lookup_ID(f->Name(), GLOBAL_MODULE_NAME, false, false, false);
if ( id )
{
tc = id->Traverse(cb);
HANDLE_TC_EXPR_PRE(tc);
}
}
tc = args->Traverse(cb); tc = args->Traverse(cb);
HANDLE_TC_EXPR_PRE(tc); HANDLE_TC_EXPR_PRE(tc);

View file

@ -1208,6 +1208,8 @@ public:
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
TraversalCode Traverse(TraversalCallback* cb) const override;
// Optimization-related: // Optimization-related:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
@ -1232,6 +1234,8 @@ public:
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
TraversalCode Traverse(TraversalCallback* cb) const override;
// Optimization-related: // Optimization-related:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
@ -1479,6 +1483,7 @@ private:
void CheckCaptures(StmtPtr when_parent); void CheckCaptures(StmtPtr when_parent);
std::unique_ptr<function_ingredients> ingredients; std::unique_ptr<function_ingredients> ingredients;
IDPtr lambda_id;
IDPList outer_ids; IDPList outer_ids;
bool capture_by_ref; // if true, use deprecated reference semantics bool capture_by_ref; // if true, use deprecated reference semantics

View file

@ -2162,23 +2162,28 @@ TraversalCode WhenStmt::Traverse(TraversalCallback* cb) const
TraversalCode tc = cb->PreStmt(this); TraversalCode tc = cb->PreStmt(this);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
auto wl = wi->Lambda();
if ( wl )
{
tc = wl->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
}
else
{
tc = wi->Cond()->Traverse(cb); tc = wi->Cond()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
tc = wi->WhenBody()->Traverse(cb); tc = wi->WhenBody()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
if ( wi->TimeoutExpr() )
{
tc = wi->TimeoutExpr()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
}
if ( wi->TimeoutStmt() ) if ( wi->TimeoutStmt() )
{ {
tc = wi->TimeoutStmt()->Traverse(cb); tc = wi->TimeoutStmt()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
} }
}
tc = cb->PostStmt(this); tc = cb->PostStmt(this);
HANDLE_TC_STMT_POST(tc); HANDLE_TC_STMT_POST(tc);

View file

@ -9,6 +9,7 @@ namespace zeek
{ {
class Func; class Func;
class Type;
namespace detail namespace detail
{ {
@ -16,6 +17,8 @@ namespace detail
class Stmt; class Stmt;
class Expr; class Expr;
class ID; class ID;
class Attributes;
class Attr;
class TraversalCallback class TraversalCallback
{ {
@ -41,6 +44,20 @@ public:
virtual TraversalCode PreDecl(const ID*) { return TC_CONTINUE; } virtual TraversalCode PreDecl(const ID*) { return TC_CONTINUE; }
virtual TraversalCode PostDecl(const ID*) { return TC_CONTINUE; } virtual TraversalCode PostDecl(const ID*) { return TC_CONTINUE; }
// A caution regarding using the next two: when traversing types,
// there's a possibility of encountering a (directly or indirectly)
// recursive record. So you'll need some way of avoiding that,
// such as remembering which types have already been traversed
// and skipping via TC_ABORTSTMT when seen again.
virtual TraversalCode PreType(const Type*) { return TC_CONTINUE; }
virtual TraversalCode PostType(const Type*) { return TC_CONTINUE; }
virtual TraversalCode PreAttrs(const Attributes*) { return TC_CONTINUE; }
virtual TraversalCode PostAttrs(const Attributes*) { return TC_CONTINUE; }
virtual TraversalCode PreAttr(const Attr*) { return TC_CONTINUE; }
virtual TraversalCode PostAttr(const Attr*) { return TC_CONTINUE; }
ScopePtr current_scope; ScopePtr current_scope;
}; };

View file

@ -34,14 +34,16 @@ enum TraversalCode
return (code); \ return (code); \
} }
#define HANDLE_TC_EXPR_PRE(code) \ #define HANDLE_TC_EXPR_PRE(code) HANDLE_TC_STMT_PRE(code)
{ \
if ( (code) == zeek::detail::TC_ABORTALL ) \
return (code); \
else if ( (code) == zeek::detail::TC_ABORTSTMT ) \
return zeek::detail::TC_CONTINUE; \
}
#define HANDLE_TC_EXPR_POST(code) return (code); #define HANDLE_TC_EXPR_POST(code) return (code);
#define HANDLE_TC_TYPE_PRE(code) HANDLE_TC_STMT_PRE(code)
#define HANDLE_TC_TYPE_POST(code) return (code);
#define HANDLE_TC_ATTRS_PRE(code) HANDLE_TC_STMT_PRE(code)
#define HANDLE_TC_ATTRS_POST(code) return (code);
#define HANDLE_TC_ATTR_PRE(code) HANDLE_TC_STMT_PRE(code)
#define HANDLE_TC_ATTR_POST(code) return (code);
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -274,6 +274,15 @@ unsigned int Type::MemoryAllocation() const
return padded_sizeof(*this); return padded_sizeof(*this);
} }
detail::TraversalCode Type::Traverse(detail::TraversalCallback* cb) const
{
auto tc = cb->PreType(this);
HANDLE_TC_TYPE_PRE(tc);
tc = cb->PostType(this);
HANDLE_TC_TYPE_POST(tc);
}
bool TypeList::AllMatch(const Type* t, bool is_init) const bool TypeList::AllMatch(const Type* t, bool is_init) const
{ {
for ( const auto& type : types ) for ( const auto& type : types )
@ -340,6 +349,21 @@ unsigned int TypeList::MemoryAllocation() const
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
} }
detail::TraversalCode TypeList::Traverse(detail::TraversalCallback* cb) const
{
auto tc = cb->PreType(this);
HANDLE_TC_TYPE_PRE(tc);
for ( const auto& type : types )
{
tc = type->Traverse(cb);
HANDLE_TC_TYPE_PRE(tc);
}
tc = cb->PostType(this);
HANDLE_TC_TYPE_POST(tc);
}
int IndexType::MatchesIndex(detail::ListExpr* const index) const int IndexType::MatchesIndex(detail::ListExpr* const index) const
{ {
// If we have a type indexed by subnets, addresses are ok. // If we have a type indexed by subnets, addresses are ok.
@ -435,6 +459,27 @@ bool IndexType::IsSubNetIndex() const
return false; return false;
} }
detail::TraversalCode IndexType::Traverse(detail::TraversalCallback* cb) const
{
auto tc = cb->PreType(this);
HANDLE_TC_TYPE_PRE(tc);
for ( const auto& ind : GetIndexTypes() )
{
tc = ind->Traverse(cb);
HANDLE_TC_TYPE_PRE(tc);
}
if ( yield_type )
{
tc = yield_type->Traverse(cb);
HANDLE_TC_TYPE_PRE(tc);
}
tc = cb->PostType(this);
HANDLE_TC_TYPE_POST(tc);
}
static bool is_supported_index_type(const TypePtr& t, const char** tname) static bool is_supported_index_type(const TypePtr& t, const char** tname)
{ {
if ( t->InternalType() != TYPE_INTERNAL_OTHER ) if ( t->InternalType() != TYPE_INTERNAL_OTHER )
@ -865,6 +910,36 @@ std::optional<FuncType::Prototype> FuncType::FindPrototype(const RecordType& arg
return {}; return {};
} }
detail::TraversalCode FuncType::Traverse(detail::TraversalCallback* cb) const
{
auto tc = cb->PreType(this);
HANDLE_TC_TYPE_PRE(tc);
tc = args->Traverse(cb);
HANDLE_TC_TYPE_PRE(tc);
if ( yield )
{
tc = yield->Traverse(cb);
HANDLE_TC_TYPE_PRE(tc);
}
tc = cb->PostType(this);
HANDLE_TC_TYPE_POST(tc);
}
detail::TraversalCode TypeType::Traverse(detail::TraversalCallback* cb) const
{
auto tc = cb->PreType(this);
HANDLE_TC_TYPE_PRE(tc);
tc = type->Traverse(cb);
HANDLE_TC_TYPE_PRE(tc);
tc = cb->PostType(this);
HANDLE_TC_TYPE_POST(tc);
}
TypeDecl::TypeDecl(const char* i, TypePtr t, detail::AttributesPtr arg_attrs) TypeDecl::TypeDecl(const char* i, TypePtr t, detail::AttributesPtr arg_attrs)
: type(std::move(t)), attrs(std::move(arg_attrs)), id(i) : type(std::move(t)), attrs(std::move(arg_attrs)), id(i)
{ {
@ -1458,6 +1533,28 @@ string RecordType::GetFieldDeprecationWarning(int field, bool has_check) const
return ""; return "";
} }
detail::TraversalCode RecordType::Traverse(detail::TraversalCallback* cb) const
{
auto tc = cb->PreType(this);
HANDLE_TC_TYPE_PRE(tc);
if ( types )
for ( const auto& td : *types )
{
tc = td->type->Traverse(cb);
HANDLE_TC_TYPE_PRE(tc);
if ( td->attrs )
{
tc = td->attrs->Traverse(cb);
HANDLE_TC_TYPE_PRE(tc);
}
}
tc = cb->PostType(this);
HANDLE_TC_TYPE_POST(tc);
}
FileType::FileType(TypePtr yield_type) : Type(TYPE_FILE), yield(std::move(yield_type)) { } FileType::FileType(TypePtr yield_type) : Type(TYPE_FILE), yield(std::move(yield_type)) { }
FileType::~FileType() = default; FileType::~FileType() = default;
@ -1476,6 +1573,18 @@ void FileType::DoDescribe(ODesc* d) const
} }
} }
detail::TraversalCode FileType::Traverse(detail::TraversalCallback* cb) const
{
auto tc = cb->PreType(this);
HANDLE_TC_TYPE_PRE(tc);
tc = yield->Traverse(cb);
HANDLE_TC_TYPE_PRE(tc);
tc = cb->PostType(this);
HANDLE_TC_TYPE_POST(tc);
}
OpaqueType::OpaqueType(const string& arg_name) : Type(TYPE_OPAQUE) OpaqueType::OpaqueType(const string& arg_name) : Type(TYPE_OPAQUE)
{ {
name = arg_name; name = arg_name;
@ -1832,6 +1941,18 @@ void VectorType::DescribeReST(ODesc* d, bool roles_only) const
d->Add(util::fmt(":zeek:type:`%s`", yield_type->GetName().c_str())); d->Add(util::fmt(":zeek:type:`%s`", yield_type->GetName().c_str()));
} }
detail::TraversalCode VectorType::Traverse(detail::TraversalCallback* cb) const
{
auto tc = cb->PreType(this);
HANDLE_TC_TYPE_PRE(tc);
tc = yield_type->Traverse(cb);
HANDLE_TC_TYPE_PRE(tc);
tc = cb->PostType(this);
HANDLE_TC_TYPE_POST(tc);
}
// Returns true if t1 is initialization-compatible with t2 (i.e., if an // Returns true if t1 is initialization-compatible with t2 (i.e., if an
// initializer with type t1 can be used to initialize a value with type t2), // initializer with type t1 can be used to initialize a value with type t2),
// false otherwise. Assumes that t1's tag is different from t2's. Note // false otherwise. Assumes that t1's tag is different from t2's. Note

View file

@ -13,6 +13,7 @@
#include "zeek/ID.h" #include "zeek/ID.h"
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
#include "zeek/Obj.h" #include "zeek/Obj.h"
#include "zeek/Traverse.h"
#include "zeek/ZeekList.h" #include "zeek/ZeekList.h"
namespace zeek namespace zeek
@ -258,6 +259,8 @@ public:
void SetName(const std::string& arg_name) { name = arg_name; } void SetName(const std::string& arg_name) { name = arg_name; }
const std::string& GetName() const { return name; } const std::string& GetName() const { return name; }
virtual detail::TraversalCode Traverse(detail::TraversalCallback* cb) const;
struct TypePtrComparer struct TypePtrComparer
{ {
bool operator()(const TypePtr& a, const TypePtr& b) const { return a.get() < b.get(); } bool operator()(const TypePtr& a, const TypePtr& b) const { return a.get() < b.get(); }
@ -353,6 +356,8 @@ public:
"GHI-572.")]] unsigned int "GHI-572.")]] unsigned int
MemoryAllocation() const override; MemoryAllocation() const override;
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override;
protected: protected:
void DoDescribe(ODesc* d) const override; void DoDescribe(ODesc* d) const override;
@ -376,6 +381,8 @@ public:
// Returns true if this table is solely indexed by subnet. // Returns true if this table is solely indexed by subnet.
bool IsSubNetIndex() const; bool IsSubNetIndex() const;
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override;
protected: protected:
IndexType(TypeTag t, TypeListPtr arg_indices, TypePtr arg_yield_type) IndexType(TypeTag t, TypeListPtr arg_indices, TypePtr arg_yield_type)
: Type(t), indices(std::move(arg_indices)), yield_type(std::move(arg_yield_type)) : Type(t), indices(std::move(arg_indices)), yield_type(std::move(arg_yield_type))
@ -533,6 +540,8 @@ public:
*/ */
void SetExpressionlessReturnOkay(bool is_ok) { expressionless_return_okay = is_ok; } void SetExpressionlessReturnOkay(bool is_ok) { expressionless_return_okay = is_ok; }
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override;
protected: protected:
friend FuncTypePtr make_intrusive<FuncType>(); friend FuncTypePtr make_intrusive<FuncType>();
@ -564,6 +573,8 @@ public:
template <class T> IntrusivePtr<T> GetType() const { return cast_intrusive<T>(type); } template <class T> IntrusivePtr<T> GetType() const { return cast_intrusive<T>(type); }
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override;
protected: protected:
TypePtr type; TypePtr type;
}; };
@ -698,6 +709,8 @@ public:
std::string GetFieldDeprecationWarning(int field, bool has_check) const; std::string GetFieldDeprecationWarning(int field, bool has_check) const;
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override;
protected: protected:
RecordType() { types = nullptr; } RecordType() { types = nullptr; }
@ -731,6 +744,8 @@ public:
const TypePtr& Yield() const override { return yield; } const TypePtr& Yield() const override { return yield; }
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override;
protected: protected:
void DoDescribe(ODesc* d) const override; void DoDescribe(ODesc* d) const override;
@ -844,6 +859,8 @@ public:
void DescribeReST(ODesc* d, bool roles_only = false) const override; void DescribeReST(ODesc* d, bool roles_only = false) const override;
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const override;
protected: protected:
void DoDescribe(ODesc* d) const override; void DoDescribe(ODesc* d) const override;

View file

@ -14,4 +14,4 @@ error in port and <...>/table-type-checking.zeek, line 42: type clash (port and
error in <...>/table-type-checking.zeek, line 42: inconsistent types in table constructor (table(thousand-two = 1002)) error in <...>/table-type-checking.zeek, line 42: inconsistent types in table constructor (table(thousand-two = 1002))
error in <...>/table-type-checking.zeek, line 48: type clash in assignment (lea = table(thousand-three = 1003)) error in <...>/table-type-checking.zeek, line 48: type clash in assignment (lea = table(thousand-three = 1003))
error in count and <...>/table-type-checking.zeek, line 54: arithmetic mixed with non-arithmetic (count and foo) error in count and <...>/table-type-checking.zeek, line 54: arithmetic mixed with non-arithmetic (count and foo)
error in <...>/table-type-checking.zeek, line 4 and <...>/table-type-checking.zeek, line 54: &default value has inconsistent type (MyTable and table()) error in <...>/table-type-checking.zeek, line 4 and <...>/table-type-checking.zeek, line 54: &default value has inconsistent type (MyTable and table()&default=foo, &optional)