captures for "when" statements

update Triggers to IntrusivePtr's and simpler AST traversal
introduce IDSet type, migrate associated "ID*" types to "const ID*"
This commit is contained in:
Vern Paxson 2022-01-07 14:50:35 -08:00
parent fa142438fe
commit f895008c34
24 changed files with 648 additions and 202 deletions

View file

@ -20,6 +20,7 @@
#include "zeek/Var.h"
#include "zeek/logging/Manager.h"
#include "zeek/logging/logging.bif.h"
#include "zeek/script_opt/ProfileFunc.h"
#include "zeek/script_opt/StmtOptInfo.h"
namespace zeek::detail
@ -1558,7 +1559,7 @@ ReturnStmt::ReturnStmt(ExprPtr arg_e) : ExprStmt(STMT_RETURN, std::move(arg_e))
else if ( ! e )
{
if ( ft->Flavor() != FUNC_FLAVOR_HOOK )
if ( ft->Flavor() != FUNC_FLAVOR_HOOK && ! ft->ExpressionlessReturnOkay() )
Error("return statement needs expression");
}
@ -1794,45 +1795,257 @@ TraversalCode NullStmt::Traverse(TraversalCallback* cb) const
HANDLE_TC_STMT_POST(tc);
}
WhenStmt::WhenStmt(ExprPtr arg_cond, StmtPtr arg_s1, StmtPtr arg_s2, ExprPtr arg_timeout,
bool arg_is_return)
: Stmt(STMT_WHEN), cond(std::move(arg_cond)), s1(std::move(arg_s1)), s2(std::move(arg_s2)),
timeout(std::move(arg_timeout)), is_return(arg_is_return)
WhenInfo::WhenInfo(ExprPtr _cond, FuncType::CaptureList* _cl, bool _is_return)
: cond(std::move(_cond)), cl(_cl), is_return(_is_return)
{
assert(cond);
assert(s1);
prior_vars = current_scope()->Vars();
ProfileFunc cond_pf(cond.get());
if ( ! cl )
{
for ( auto& wl : cond_pf.WhenLocals() )
prior_vars.erase(wl->Name());
return;
}
when_expr_locals = cond_pf.Locals();
when_expr_globals = cond_pf.Globals();
// Make any when-locals part of our captures, if not already present,
// to enable sharing between the condition and the body/timeout code.
for ( auto& wl : cond_pf.WhenLocals() )
{
bool is_present = false;
for ( auto& c : *cl )
if ( c.id == wl )
{
is_present = true;
break;
}
if ( ! is_present )
{
IDPtr wl_ptr = {NewRef{}, const_cast<ID*>(wl)};
cl->emplace_back(FuncType::Capture{wl_ptr, false});
}
// In addition, don't treat them as external locals that
// existed at the onset.
when_expr_locals.erase(wl);
}
// Create the internal lambda we'll use to manage the captures.
static int num_params = 0; // to ensure each is distinct
lambda_param_id = util::fmt("when-param-%d", ++num_params);
auto param_list = new type_decl_list();
auto count_t = base_type(TYPE_COUNT);
param_list->push_back(new TypeDecl(util::copy_string(lambda_param_id.c_str()), count_t));
auto params = make_intrusive<RecordType>(param_list);
auto ft = make_intrusive<FuncType>(params, base_type(TYPE_ANY), FUNC_FLAVOR_FUNCTION);
ft->SetCaptures(*cl);
if ( ! is_return )
ft->SetExpressionlessReturnOkay(true);
auto id = current_scope()->GenerateTemporary("when-internal");
// This begin_func will be completed by WhenInfo::Build().
begin_func(id, current_module.c_str(), FUNC_FLAVOR_FUNCTION, false, ft);
}
void WhenInfo::Build(StmtPtr ws)
{
if ( ! cl )
{
// Old-style semantics.
auto locals = when_expr_locals;
ProfileFunc cond_pf(cond.get());
for ( auto& bl : cond_pf.Locals() )
if ( prior_vars.count(bl->Name()) > 0 )
locals.insert(bl);
ProfileFunc body_pf(s.get());
for ( auto& bl : body_pf.Locals() )
if ( prior_vars.count(bl->Name()) > 0 )
locals.insert(bl);
if ( timeout_s )
{
ProfileFunc to_pf(timeout_s.get());
for ( auto& tl : to_pf.Locals() )
if ( prior_vars.count(tl->Name()) > 0 )
locals.insert(tl);
}
if ( ! locals.empty() )
{
std::string vars;
for ( auto& l : locals )
{
if ( ! vars.empty() )
vars += ", ";
vars += l->Name();
}
std::string msg = util::fmt("\"when\" statement referring to locals without an "
"explicit [] capture is deprecated: %s",
vars.c_str());
ws->Warn(msg.c_str());
}
return;
}
// Our general strategy is to construct a single lambda (so that
// the values of captures are shared across all of its elements)
// that's used for all three of the "when" components: condition,
// body, and timeout body. The idea is that the lambda is passed
// a single argument that specifies the particular functionality
// to execute (1 = condition, 2 = body, 3 = timeout). It gets tricky
// in that the condition needs to return a boolean, whereas the body
// and timeout *might* return a value (for "return when") constructs,
// or might not (for vanilla "when"). We address that issue by
// (1) making the return type be "any", and (2) introducing elsehwere
// the notion of functions marked as being allowed to have bare
// returns (no associated expression) even though they have a return
// type (to deal with the vanilla "when" case).
// Build the AST elements of the lambda.
// First, the constants we'll need.
auto true_const = make_intrusive<ConstExpr>(val_mgr->True());
auto one_const = make_intrusive<ConstExpr>(val_mgr->Count(1));
auto two_const = make_intrusive<ConstExpr>(val_mgr->Count(2));
auto three_const = make_intrusive<ConstExpr>(val_mgr->Count(3));
invoke_cond = make_intrusive<ListExpr>(one_const);
invoke_s = make_intrusive<ListExpr>(two_const);
invoke_timeout = make_intrusive<ListExpr>(three_const);
// Access to the parameter that selects which action we're doing.
auto param_id = lookup_ID(lambda_param_id.c_str(), current_module.c_str());
ASSERT(param_id);
auto param = make_intrusive<NameExpr>(param_id);
// Expressions for testing for the latter constants.
auto one_test = make_intrusive<EqExpr>(EXPR_EQ, param, one_const);
auto two_test = make_intrusive<EqExpr>(EXPR_EQ, param, two_const);
auto empty = make_intrusive<NullStmt>();
auto test_cond = make_intrusive<ReturnStmt>(cond);
auto do_test = make_intrusive<IfStmt>(one_test, test_cond, empty);
auto else_branch = timeout_s ? timeout_s : empty;
auto do_bodies = make_intrusive<IfStmt>(two_test, s, else_branch);
auto dummy_return = make_intrusive<ReturnStmt>(true_const);
auto shebang = make_intrusive<StmtList>(do_test, do_bodies, dummy_return);
auto ingredients = std::make_unique<function_ingredients>(current_scope(), shebang);
auto outer_ids = gather_outer_ids(pop_scope(), ingredients->body);
lambda = make_intrusive<LambdaExpr>(std::move(ingredients), std::move(outer_ids), ws);
}
void WhenInfo::Instantiate(Frame* f)
{
if ( cl )
curr_lambda = make_intrusive<ConstExpr>(lambda->Eval(f));
}
ExprPtr WhenInfo::Cond()
{
if ( ! curr_lambda )
return cond;
return make_intrusive<CallExpr>(curr_lambda, invoke_cond);
}
StmtPtr WhenInfo::WhenBody()
{
if ( ! curr_lambda )
return s;
auto invoke = make_intrusive<CallExpr>(curr_lambda, invoke_s);
return make_intrusive<ReturnStmt>(invoke, true);
}
StmtPtr WhenInfo::TimeoutStmt()
{
if ( ! curr_lambda )
return timeout_s;
auto invoke = make_intrusive<CallExpr>(curr_lambda, invoke_timeout);
return make_intrusive<ReturnStmt>(invoke, true);
}
WhenStmt::WhenStmt(WhenInfo* _wi) : Stmt(STMT_WHEN), wi(_wi)
{
wi->Build(ThisPtr());
auto cond = wi->Cond();
if ( ! cond->IsError() && ! IsBool(cond->GetType()->Tag()) )
cond->Error("conditional in test must be boolean");
if ( timeout )
auto te = wi->TimeoutExpr();
if ( te )
{
if ( timeout->IsError() )
if ( te->IsError() )
return;
TypeTag bt = timeout->GetType()->Tag();
TypeTag bt = te->GetType()->Tag();
if ( bt != TYPE_TIME && bt != TYPE_INTERVAL )
cond->Error("when timeout requires a time or time interval");
te->Error("when timeout requires a time or time interval");
}
}
WhenStmt::~WhenStmt() = default;
WhenStmt::~WhenStmt()
{
delete wi;
}
ValPtr WhenStmt::Exec(Frame* f, StmtFlowType& flow)
{
RegisterAccess();
flow = FLOW_NEXT;
// The new trigger object will take care of its own deletion.
new trigger::Trigger(IntrusivePtr{cond}.release(), IntrusivePtr{s1}.release(),
IntrusivePtr{s2}.release(), IntrusivePtr{timeout}.release(), f, is_return,
location);
wi->Instantiate(f);
if ( wi->Captures() )
{
std::vector<ValPtr> local_aggrs;
for ( auto& l : wi->WhenExprLocals() )
{
IDPtr l_ptr = {NewRef{}, const_cast<ID*>(l)};
auto v = f->GetElementByID(l_ptr);
if ( v && v->Modifiable() )
local_aggrs.emplace_back(std::move(v));
}
new trigger::Trigger(wi, wi->WhenExprGlobals(), local_aggrs, f, location);
}
else
// The new trigger object will take care of its own deletion.
new trigger::Trigger(wi->Cond(), wi->WhenBody(), wi->TimeoutStmt(), wi->TimeoutExpr(), f,
wi->IsReturn(), location);
return nullptr;
}
bool WhenStmt::IsPure() const
{
return cond->IsPure() && s1->IsPure() && (! s2 || s2->IsPure());
return wi->Cond()->IsPure() && wi->WhenBody()->IsPure() &&
(! wi->TimeoutStmt() || wi->TimeoutStmt()->IsPure());
}
void WhenStmt::StmtDescribe(ODesc* d) const
@ -1842,33 +2055,33 @@ void WhenStmt::StmtDescribe(ODesc* d) const
if ( d->IsReadable() )
d->Add("(");
cond->Describe(d);
wi->Cond()->Describe(d);
if ( d->IsReadable() )
d->Add(")");
d->SP();
d->PushIndent();
s1->AccessStats(d);
s1->Describe(d);
wi->WhenBody()->AccessStats(d);
wi->WhenBody()->Describe(d);
d->PopIndent();
if ( s2 )
if ( wi->TimeoutStmt() )
{
if ( d->IsReadable() )
{
d->SP();
d->Add("timeout");
d->SP();
timeout->Describe(d);
wi->TimeoutExpr()->Describe(d);
d->SP();
d->PushIndent();
s2->AccessStats(d);
s2->Describe(d);
wi->TimeoutStmt()->AccessStats(d);
wi->TimeoutStmt()->Describe(d);
d->PopIndent();
}
else
s2->Describe(d);
wi->TimeoutStmt()->Describe(d);
}
}
@ -1877,15 +2090,15 @@ TraversalCode WhenStmt::Traverse(TraversalCallback* cb) const
TraversalCode tc = cb->PreStmt(this);
HANDLE_TC_STMT_PRE(tc);
tc = cond->Traverse(cb);
tc = wi->Cond()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
tc = s1->Traverse(cb);
tc = wi->WhenBody()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
if ( s2 )
if ( wi->TimeoutStmt() )
{
tc = s2->Traverse(cb);
tc = wi->TimeoutStmt()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
}