mirror of
https://github.com/zeek/zeek.git
synced 2025-10-10 10:38:20 +00:00
captures for "when" statements
update Triggers to IntrusivePtr's and simpler AST traversal introduce IDSet type, migrate associated "ID*" types to "const ID*"
This commit is contained in:
parent
fa142438fe
commit
f895008c34
24 changed files with 648 additions and 202 deletions
271
src/Stmt.cc
271
src/Stmt.cc
|
@ -20,6 +20,7 @@
|
|||
#include "zeek/Var.h"
|
||||
#include "zeek/logging/Manager.h"
|
||||
#include "zeek/logging/logging.bif.h"
|
||||
#include "zeek/script_opt/ProfileFunc.h"
|
||||
#include "zeek/script_opt/StmtOptInfo.h"
|
||||
|
||||
namespace zeek::detail
|
||||
|
@ -1558,7 +1559,7 @@ ReturnStmt::ReturnStmt(ExprPtr arg_e) : ExprStmt(STMT_RETURN, std::move(arg_e))
|
|||
|
||||
else if ( ! e )
|
||||
{
|
||||
if ( ft->Flavor() != FUNC_FLAVOR_HOOK )
|
||||
if ( ft->Flavor() != FUNC_FLAVOR_HOOK && ! ft->ExpressionlessReturnOkay() )
|
||||
Error("return statement needs expression");
|
||||
}
|
||||
|
||||
|
@ -1794,45 +1795,257 @@ TraversalCode NullStmt::Traverse(TraversalCallback* cb) const
|
|||
HANDLE_TC_STMT_POST(tc);
|
||||
}
|
||||
|
||||
WhenStmt::WhenStmt(ExprPtr arg_cond, StmtPtr arg_s1, StmtPtr arg_s2, ExprPtr arg_timeout,
|
||||
bool arg_is_return)
|
||||
: Stmt(STMT_WHEN), cond(std::move(arg_cond)), s1(std::move(arg_s1)), s2(std::move(arg_s2)),
|
||||
timeout(std::move(arg_timeout)), is_return(arg_is_return)
|
||||
WhenInfo::WhenInfo(ExprPtr _cond, FuncType::CaptureList* _cl, bool _is_return)
|
||||
: cond(std::move(_cond)), cl(_cl), is_return(_is_return)
|
||||
{
|
||||
assert(cond);
|
||||
assert(s1);
|
||||
prior_vars = current_scope()->Vars();
|
||||
|
||||
ProfileFunc cond_pf(cond.get());
|
||||
|
||||
if ( ! cl )
|
||||
{
|
||||
for ( auto& wl : cond_pf.WhenLocals() )
|
||||
prior_vars.erase(wl->Name());
|
||||
return;
|
||||
}
|
||||
|
||||
when_expr_locals = cond_pf.Locals();
|
||||
when_expr_globals = cond_pf.Globals();
|
||||
|
||||
// Make any when-locals part of our captures, if not already present,
|
||||
// to enable sharing between the condition and the body/timeout code.
|
||||
for ( auto& wl : cond_pf.WhenLocals() )
|
||||
{
|
||||
bool is_present = false;
|
||||
|
||||
for ( auto& c : *cl )
|
||||
if ( c.id == wl )
|
||||
{
|
||||
is_present = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if ( ! is_present )
|
||||
{
|
||||
IDPtr wl_ptr = {NewRef{}, const_cast<ID*>(wl)};
|
||||
cl->emplace_back(FuncType::Capture{wl_ptr, false});
|
||||
}
|
||||
|
||||
// In addition, don't treat them as external locals that
|
||||
// existed at the onset.
|
||||
when_expr_locals.erase(wl);
|
||||
}
|
||||
|
||||
// Create the internal lambda we'll use to manage the captures.
|
||||
static int num_params = 0; // to ensure each is distinct
|
||||
lambda_param_id = util::fmt("when-param-%d", ++num_params);
|
||||
|
||||
auto param_list = new type_decl_list();
|
||||
auto count_t = base_type(TYPE_COUNT);
|
||||
param_list->push_back(new TypeDecl(util::copy_string(lambda_param_id.c_str()), count_t));
|
||||
auto params = make_intrusive<RecordType>(param_list);
|
||||
|
||||
auto ft = make_intrusive<FuncType>(params, base_type(TYPE_ANY), FUNC_FLAVOR_FUNCTION);
|
||||
ft->SetCaptures(*cl);
|
||||
|
||||
if ( ! is_return )
|
||||
ft->SetExpressionlessReturnOkay(true);
|
||||
|
||||
auto id = current_scope()->GenerateTemporary("when-internal");
|
||||
|
||||
// This begin_func will be completed by WhenInfo::Build().
|
||||
begin_func(id, current_module.c_str(), FUNC_FLAVOR_FUNCTION, false, ft);
|
||||
}
|
||||
|
||||
void WhenInfo::Build(StmtPtr ws)
|
||||
{
|
||||
if ( ! cl )
|
||||
{
|
||||
// Old-style semantics.
|
||||
auto locals = when_expr_locals;
|
||||
|
||||
ProfileFunc cond_pf(cond.get());
|
||||
for ( auto& bl : cond_pf.Locals() )
|
||||
if ( prior_vars.count(bl->Name()) > 0 )
|
||||
locals.insert(bl);
|
||||
|
||||
ProfileFunc body_pf(s.get());
|
||||
for ( auto& bl : body_pf.Locals() )
|
||||
if ( prior_vars.count(bl->Name()) > 0 )
|
||||
locals.insert(bl);
|
||||
|
||||
if ( timeout_s )
|
||||
{
|
||||
ProfileFunc to_pf(timeout_s.get());
|
||||
for ( auto& tl : to_pf.Locals() )
|
||||
if ( prior_vars.count(tl->Name()) > 0 )
|
||||
locals.insert(tl);
|
||||
}
|
||||
|
||||
if ( ! locals.empty() )
|
||||
{
|
||||
std::string vars;
|
||||
for ( auto& l : locals )
|
||||
{
|
||||
if ( ! vars.empty() )
|
||||
vars += ", ";
|
||||
vars += l->Name();
|
||||
}
|
||||
|
||||
std::string msg = util::fmt("\"when\" statement referring to locals without an "
|
||||
"explicit [] capture is deprecated: %s",
|
||||
vars.c_str());
|
||||
ws->Warn(msg.c_str());
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Our general strategy is to construct a single lambda (so that
|
||||
// the values of captures are shared across all of its elements)
|
||||
// that's used for all three of the "when" components: condition,
|
||||
// body, and timeout body. The idea is that the lambda is passed
|
||||
// a single argument that specifies the particular functionality
|
||||
// to execute (1 = condition, 2 = body, 3 = timeout). It gets tricky
|
||||
// in that the condition needs to return a boolean, whereas the body
|
||||
// and timeout *might* return a value (for "return when") constructs,
|
||||
// or might not (for vanilla "when"). We address that issue by
|
||||
// (1) making the return type be "any", and (2) introducing elsehwere
|
||||
// the notion of functions marked as being allowed to have bare
|
||||
// returns (no associated expression) even though they have a return
|
||||
// type (to deal with the vanilla "when" case).
|
||||
|
||||
// Build the AST elements of the lambda.
|
||||
|
||||
// First, the constants we'll need.
|
||||
auto true_const = make_intrusive<ConstExpr>(val_mgr->True());
|
||||
auto one_const = make_intrusive<ConstExpr>(val_mgr->Count(1));
|
||||
auto two_const = make_intrusive<ConstExpr>(val_mgr->Count(2));
|
||||
auto three_const = make_intrusive<ConstExpr>(val_mgr->Count(3));
|
||||
|
||||
invoke_cond = make_intrusive<ListExpr>(one_const);
|
||||
invoke_s = make_intrusive<ListExpr>(two_const);
|
||||
invoke_timeout = make_intrusive<ListExpr>(three_const);
|
||||
|
||||
// Access to the parameter that selects which action we're doing.
|
||||
auto param_id = lookup_ID(lambda_param_id.c_str(), current_module.c_str());
|
||||
ASSERT(param_id);
|
||||
auto param = make_intrusive<NameExpr>(param_id);
|
||||
|
||||
// Expressions for testing for the latter constants.
|
||||
auto one_test = make_intrusive<EqExpr>(EXPR_EQ, param, one_const);
|
||||
auto two_test = make_intrusive<EqExpr>(EXPR_EQ, param, two_const);
|
||||
|
||||
auto empty = make_intrusive<NullStmt>();
|
||||
|
||||
auto test_cond = make_intrusive<ReturnStmt>(cond);
|
||||
auto do_test = make_intrusive<IfStmt>(one_test, test_cond, empty);
|
||||
|
||||
auto else_branch = timeout_s ? timeout_s : empty;
|
||||
|
||||
auto do_bodies = make_intrusive<IfStmt>(two_test, s, else_branch);
|
||||
auto dummy_return = make_intrusive<ReturnStmt>(true_const);
|
||||
|
||||
auto shebang = make_intrusive<StmtList>(do_test, do_bodies, dummy_return);
|
||||
|
||||
auto ingredients = std::make_unique<function_ingredients>(current_scope(), shebang);
|
||||
auto outer_ids = gather_outer_ids(pop_scope(), ingredients->body);
|
||||
|
||||
lambda = make_intrusive<LambdaExpr>(std::move(ingredients), std::move(outer_ids), ws);
|
||||
}
|
||||
|
||||
void WhenInfo::Instantiate(Frame* f)
|
||||
{
|
||||
if ( cl )
|
||||
curr_lambda = make_intrusive<ConstExpr>(lambda->Eval(f));
|
||||
}
|
||||
|
||||
ExprPtr WhenInfo::Cond()
|
||||
{
|
||||
if ( ! curr_lambda )
|
||||
return cond;
|
||||
|
||||
return make_intrusive<CallExpr>(curr_lambda, invoke_cond);
|
||||
}
|
||||
|
||||
StmtPtr WhenInfo::WhenBody()
|
||||
{
|
||||
if ( ! curr_lambda )
|
||||
return s;
|
||||
|
||||
auto invoke = make_intrusive<CallExpr>(curr_lambda, invoke_s);
|
||||
return make_intrusive<ReturnStmt>(invoke, true);
|
||||
}
|
||||
|
||||
StmtPtr WhenInfo::TimeoutStmt()
|
||||
{
|
||||
if ( ! curr_lambda )
|
||||
return timeout_s;
|
||||
|
||||
auto invoke = make_intrusive<CallExpr>(curr_lambda, invoke_timeout);
|
||||
return make_intrusive<ReturnStmt>(invoke, true);
|
||||
}
|
||||
|
||||
WhenStmt::WhenStmt(WhenInfo* _wi) : Stmt(STMT_WHEN), wi(_wi)
|
||||
{
|
||||
wi->Build(ThisPtr());
|
||||
|
||||
auto cond = wi->Cond();
|
||||
|
||||
if ( ! cond->IsError() && ! IsBool(cond->GetType()->Tag()) )
|
||||
cond->Error("conditional in test must be boolean");
|
||||
|
||||
if ( timeout )
|
||||
auto te = wi->TimeoutExpr();
|
||||
|
||||
if ( te )
|
||||
{
|
||||
if ( timeout->IsError() )
|
||||
if ( te->IsError() )
|
||||
return;
|
||||
|
||||
TypeTag bt = timeout->GetType()->Tag();
|
||||
TypeTag bt = te->GetType()->Tag();
|
||||
if ( bt != TYPE_TIME && bt != TYPE_INTERVAL )
|
||||
cond->Error("when timeout requires a time or time interval");
|
||||
te->Error("when timeout requires a time or time interval");
|
||||
}
|
||||
}
|
||||
|
||||
WhenStmt::~WhenStmt() = default;
|
||||
WhenStmt::~WhenStmt()
|
||||
{
|
||||
delete wi;
|
||||
}
|
||||
|
||||
ValPtr WhenStmt::Exec(Frame* f, StmtFlowType& flow)
|
||||
{
|
||||
RegisterAccess();
|
||||
flow = FLOW_NEXT;
|
||||
|
||||
// The new trigger object will take care of its own deletion.
|
||||
new trigger::Trigger(IntrusivePtr{cond}.release(), IntrusivePtr{s1}.release(),
|
||||
IntrusivePtr{s2}.release(), IntrusivePtr{timeout}.release(), f, is_return,
|
||||
location);
|
||||
wi->Instantiate(f);
|
||||
|
||||
if ( wi->Captures() )
|
||||
{
|
||||
std::vector<ValPtr> local_aggrs;
|
||||
for ( auto& l : wi->WhenExprLocals() )
|
||||
{
|
||||
IDPtr l_ptr = {NewRef{}, const_cast<ID*>(l)};
|
||||
auto v = f->GetElementByID(l_ptr);
|
||||
if ( v && v->Modifiable() )
|
||||
local_aggrs.emplace_back(std::move(v));
|
||||
}
|
||||
|
||||
new trigger::Trigger(wi, wi->WhenExprGlobals(), local_aggrs, f, location);
|
||||
}
|
||||
|
||||
else
|
||||
// The new trigger object will take care of its own deletion.
|
||||
new trigger::Trigger(wi->Cond(), wi->WhenBody(), wi->TimeoutStmt(), wi->TimeoutExpr(), f,
|
||||
wi->IsReturn(), location);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool WhenStmt::IsPure() const
|
||||
{
|
||||
return cond->IsPure() && s1->IsPure() && (! s2 || s2->IsPure());
|
||||
return wi->Cond()->IsPure() && wi->WhenBody()->IsPure() &&
|
||||
(! wi->TimeoutStmt() || wi->TimeoutStmt()->IsPure());
|
||||
}
|
||||
|
||||
void WhenStmt::StmtDescribe(ODesc* d) const
|
||||
|
@ -1842,33 +2055,33 @@ void WhenStmt::StmtDescribe(ODesc* d) const
|
|||
if ( d->IsReadable() )
|
||||
d->Add("(");
|
||||
|
||||
cond->Describe(d);
|
||||
wi->Cond()->Describe(d);
|
||||
|
||||
if ( d->IsReadable() )
|
||||
d->Add(")");
|
||||
|
||||
d->SP();
|
||||
d->PushIndent();
|
||||
s1->AccessStats(d);
|
||||
s1->Describe(d);
|
||||
wi->WhenBody()->AccessStats(d);
|
||||
wi->WhenBody()->Describe(d);
|
||||
d->PopIndent();
|
||||
|
||||
if ( s2 )
|
||||
if ( wi->TimeoutStmt() )
|
||||
{
|
||||
if ( d->IsReadable() )
|
||||
{
|
||||
d->SP();
|
||||
d->Add("timeout");
|
||||
d->SP();
|
||||
timeout->Describe(d);
|
||||
wi->TimeoutExpr()->Describe(d);
|
||||
d->SP();
|
||||
d->PushIndent();
|
||||
s2->AccessStats(d);
|
||||
s2->Describe(d);
|
||||
wi->TimeoutStmt()->AccessStats(d);
|
||||
wi->TimeoutStmt()->Describe(d);
|
||||
d->PopIndent();
|
||||
}
|
||||
else
|
||||
s2->Describe(d);
|
||||
wi->TimeoutStmt()->Describe(d);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1877,15 +2090,15 @@ TraversalCode WhenStmt::Traverse(TraversalCallback* cb) const
|
|||
TraversalCode tc = cb->PreStmt(this);
|
||||
HANDLE_TC_STMT_PRE(tc);
|
||||
|
||||
tc = cond->Traverse(cb);
|
||||
tc = wi->Cond()->Traverse(cb);
|
||||
HANDLE_TC_STMT_PRE(tc);
|
||||
|
||||
tc = s1->Traverse(cb);
|
||||
tc = wi->WhenBody()->Traverse(cb);
|
||||
HANDLE_TC_STMT_PRE(tc);
|
||||
|
||||
if ( s2 )
|
||||
if ( wi->TimeoutStmt() )
|
||||
{
|
||||
tc = s2->Traverse(cb);
|
||||
tc = wi->TimeoutStmt()->Traverse(cb);
|
||||
HANDLE_TC_STMT_PRE(tc);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue