captures for "when" statements

update Triggers to IntrusivePtr's and simpler AST traversal
introduce IDSet type, migrate associated "ID*" types to "const ID*"
This commit is contained in:
Vern Paxson 2022-01-07 14:50:35 -08:00
parent fa142438fe
commit f895008c34
24 changed files with 648 additions and 202 deletions

View file

@ -23,21 +23,19 @@ using namespace zeek::detail::trigger;
namespace zeek::detail::trigger
{
// Used to extract the globals and locals seen in a trigger expression.
class TriggerTraversalCallback : public TraversalCallback
{
public:
TriggerTraversalCallback(Trigger* arg_trigger)
TriggerTraversalCallback(IDSet& _globals, IDSet& _locals) : globals(_globals), locals(_locals)
{
Ref(arg_trigger);
trigger = arg_trigger;
}
~TriggerTraversalCallback() { Unref(trigger); }
virtual TraversalCode PreExpr(const Expr*) override;
private:
Trigger* trigger;
IDSet& globals;
IDSet& locals;
};
TraversalCode trigger::TriggerTraversalCallback::PreExpr(const Expr* expr)
@ -50,14 +48,12 @@ TraversalCode trigger::TriggerTraversalCallback::PreExpr(const Expr* expr)
case EXPR_NAME:
{
const auto* e = static_cast<const NameExpr*>(expr);
if ( e->Id()->IsGlobal() )
trigger->Register(e->Id());
auto id = e->Id();
Val* v = e->Id()->GetVal().get();
if ( v && v->Modifiable() )
trigger->Register(v);
break;
if ( id->IsGlobal() )
globals.insert(id);
else
locals.insert(id);
};
default:
@ -103,10 +99,35 @@ protected:
double time;
};
Trigger::Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, Expr* timeout_expr,
Trigger::Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, ExprPtr timeout_expr,
Frame* frame, bool is_return, const Location* location)
{
timeout_value = -1;
GetTimeout(timeout_expr);
Init(cond, body, timeout_stmts, frame, is_return, location);
}
Trigger::Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, double timeout, Frame* frame,
bool is_return, const Location* location)
{
timeout_value = timeout;
Init(cond, body, timeout_stmts, frame, is_return, location);
}
Trigger::Trigger(WhenInfo* wi, const IDSet& _globals, std::vector<ValPtr> _local_aggrs, Frame* f,
const Location* loc)
{
globals = _globals;
local_aggrs = std::move(_local_aggrs);
have_trigger_elems = true;
GetTimeout(wi->TimeoutExpr());
Init(wi->Cond(), wi->WhenBody(), wi->TimeoutStmt(), f, wi->IsReturn(), loc);
}
void Trigger::GetTimeout(const ExprPtr& timeout_expr)
{
timeout_value = -1.0;
if ( timeout_expr )
{
@ -123,24 +144,14 @@ Trigger::Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, Expr* timeou
if ( timeout_val )
timeout_value = timeout_val->AsInterval();
}
Init(cond, body, timeout_stmts, frame, is_return, location);
}
Trigger::Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, double timeout, Frame* frame,
bool is_return, const Location* location)
{
timeout_value = timeout;
Init(cond, body, timeout_stmts, frame, is_return, location);
}
void Trigger::Init(const Expr* arg_cond, Stmt* arg_body, Stmt* arg_timeout_stmts, Frame* arg_frame,
void Trigger::Init(ExprPtr arg_cond, StmtPtr arg_body, StmtPtr arg_timeout_stmts, Frame* arg_frame,
bool arg_is_return, const Location* arg_location)
{
cond = arg_cond;
body = arg_body;
timeout_stmts = arg_timeout_stmts;
frame = arg_frame->Clone();
timer = nullptr;
delayed = false;
disabled = false;
@ -148,6 +159,11 @@ void Trigger::Init(const Expr* arg_cond, Stmt* arg_body, Stmt* arg_timeout_stmts
is_return = arg_is_return;
location = arg_location;
if ( arg_frame )
frame = arg_frame->Clone();
else
frame = nullptr;
DBG_LOG(DBG_NOTIFIERS, "%s: instantiating", Name());
if ( is_return )
@ -217,8 +233,30 @@ void Trigger::ReInit(std::vector<ValPtr> index_expr_results)
{
assert(! disabled);
UnregisterAll();
TriggerTraversalCallback cb(this);
cond->Traverse(&cb);
if ( ! have_trigger_elems )
{
TriggerTraversalCallback cb(globals, locals);
cond->Traverse(&cb);
have_trigger_elems = true;
}
for ( auto g : globals )
{
Register(g);
auto& v = g->GetVal();
if ( v && v->Modifiable() )
Register(v.get());
}
for ( auto l : locals )
{
ASSERT(! l->GetVal());
}
for ( auto& av : local_aggrs )
Register(av.get());
for ( const auto& v : index_expr_results )
Register(v.get());
@ -390,9 +428,10 @@ void Trigger::Timeout()
Unref(this);
}
void Trigger::Register(ID* id)
void Trigger::Register(const ID* const_id)
{
assert(! disabled);
ID* id = const_cast<ID*>(const_id);
notifier::detail::registry.Register(id, this);
Ref(id);