simplify WhenInfo and Trigger classes given removal of old capture semantics

This commit is contained in:
Vern Paxson 2023-04-02 11:36:39 -07:00
parent 4af6b52876
commit 84906171ba
4 changed files with 62 additions and 179 deletions

View file

@ -1867,7 +1867,8 @@ TraversalCode NullStmt::Traverse(TraversalCallback* cb) const
WhenInfo::WhenInfo(ExprPtr arg_cond, FuncType::CaptureList* arg_cl, bool arg_is_return)
: cond(std::move(arg_cond)), cl(arg_cl), is_return(arg_is_return)
{
prior_vars = current_scope()->Vars();
if ( ! cl )
cl = new zeek::FuncType::CaptureList;
ProfileFunc cond_pf(cond.get());
@ -1881,20 +1882,17 @@ WhenInfo::WhenInfo(ExprPtr arg_cond, FuncType::CaptureList* arg_cl, bool arg_is_
{
bool is_present = false;
if ( cl )
{
for ( auto& c : *cl )
if ( c.id == wl )
{
is_present = true;
break;
}
if ( ! is_present )
for ( auto& c : *cl )
if ( c.id == wl )
{
IDPtr wl_ptr = {NewRef{}, const_cast<ID*>(wl)};
cl->emplace_back(FuncType::Capture{wl_ptr, false});
is_present = true;
break;
}
if ( ! is_present )
{
IDPtr wl_ptr = {NewRef{}, const_cast<ID*>(wl)};
cl->emplace_back(FuncType::Capture{wl_ptr, false});
}
// In addition, don't treat them as external locals that
@ -1926,32 +1924,12 @@ WhenInfo::WhenInfo(ExprPtr arg_cond, FuncType::CaptureList* arg_cl, bool arg_is_
WhenInfo::WhenInfo(bool arg_is_return) : is_return(arg_is_return)
{
// This won't be needed once we remove the deprecated semantics.
cl = new zeek::FuncType::CaptureList;
BuildInvokeElems();
}
void WhenInfo::Build(StmtPtr ws)
{
// This will call ws->Error() if it's deprecated and we can
// short-circuit.
if ( IsDeprecatedSemantics(ws) )
return;
if ( ! cl )
{
// This instance is compatible with new-style semantics,
// so create a capture list for it and populate with any
// when-locals.
cl = new zeek::FuncType::CaptureList;
for ( auto& wl : when_new_locals )
{
IDPtr wl_ptr = {NewRef{}, const_cast<ID*>(wl)};
cl->emplace_back(FuncType::Capture{wl_ptr, false});
}
}
lambda_ft->SetCaptures(*cl);
// Our general strategy is to construct a single lambda (so that
@ -1996,38 +1974,30 @@ void WhenInfo::Build(StmtPtr ws)
auto shebang = make_intrusive<StmtList>(do_test, do_bodies, dummy_return);
auto ingredients = std::make_unique<function_ingredients>(current_scope(), shebang,
current_module);
auto outer_ids = gather_outer_ids(pop_scope(), ingredients->body);
auto ingredients = std::make_unique<FunctionIngredients>(current_scope(), shebang,
current_module);
auto outer_ids = gather_outer_ids(pop_scope(), ingredients->Body());
lambda = make_intrusive<LambdaExpr>(std::move(ingredients), std::move(outer_ids), ws);
}
void WhenInfo::Instantiate(Frame* f)
{
if ( cl )
Instantiate(lambda->Eval(f));
Instantiate(lambda->Eval(f));
}
void WhenInfo::Instantiate(ValPtr func)
{
if ( cl )
curr_lambda = make_intrusive<ConstExpr>(std::move(func));
curr_lambda = make_intrusive<ConstExpr>(std::move(func));
}
ExprPtr WhenInfo::Cond()
{
if ( ! curr_lambda )
return cond;
return make_intrusive<CallExpr>(curr_lambda, invoke_cond);
}
StmtPtr WhenInfo::WhenBody()
{
if ( ! curr_lambda )
return s;
auto invoke = make_intrusive<CallExpr>(curr_lambda, invoke_s);
return make_intrusive<ReturnStmt>(invoke, true);
}
@ -2046,61 +2016,10 @@ double WhenInfo::TimeoutVal(Frame* f)
StmtPtr WhenInfo::TimeoutStmt()
{
if ( ! curr_lambda )
return timeout_s;
auto invoke = make_intrusive<CallExpr>(curr_lambda, invoke_timeout);
return make_intrusive<ReturnStmt>(invoke, true);
}
bool WhenInfo::IsDeprecatedSemantics(StmtPtr ws)
{
if ( cl )
return false;
// Which locals of the outer function are used in any of the "when"
// elements.
IDSet locals;
for ( auto& wl : when_new_locals )
prior_vars.erase(wl->Name());
for ( auto& bl : when_expr_locals )
if ( prior_vars.count(bl->Name()) > 0 )
locals.insert(bl);
ProfileFunc body_pf(s.get());
for ( auto& bl : body_pf.Locals() )
if ( prior_vars.count(bl->Name()) > 0 )
locals.insert(bl);
if ( timeout_s )
{
ProfileFunc to_pf(timeout_s.get());
for ( auto& tl : to_pf.Locals() )
if ( prior_vars.count(tl->Name()) > 0 )
locals.insert(tl);
}
if ( locals.empty() )
return false;
std::string vars;
for ( auto& l : locals )
{
if ( ! vars.empty() )
vars += ", ";
vars += l->Name();
}
std::string msg = util::fmt("\"when\" statement referring to locals without an "
"explicit [] capture: %s",
vars.c_str());
ws->Error(msg.c_str());
return true;
}
void WhenInfo::BuildInvokeElems()
{
one_const = make_intrusive<ConstExpr>(val_mgr->Count(1));
@ -2116,12 +2035,12 @@ WhenStmt::WhenStmt(WhenInfo* arg_wi) : Stmt(STMT_WHEN), wi(arg_wi)
{
wi->Build(ThisPtr());
auto cond = wi->Cond();
auto cond = wi->OrigCond();
if ( ! cond->IsError() && ! IsBool(cond->GetType()->Tag()) )
cond->Error("conditional in test must be boolean");
auto te = wi->TimeoutExpr();
auto te = wi->OrigTimeout();
if ( te )
{
@ -2148,32 +2067,24 @@ ValPtr WhenStmt::Exec(Frame* f, StmtFlowType& flow)
auto timeout = wi->TimeoutVal(f);
if ( wi->Captures() )
std::vector<ValPtr> local_aggrs;
for ( auto& l : wi->WhenExprLocals() )
{
std::vector<ValPtr> local_aggrs;
for ( auto& l : wi->WhenExprLocals() )
{
IDPtr l_ptr = {NewRef{}, const_cast<ID*>(l)};
auto v = f->GetElementByID(l_ptr);
if ( v && v->Modifiable() )
local_aggrs.emplace_back(std::move(v));
}
new trigger::Trigger(wi, timeout, wi->WhenExprGlobals(), local_aggrs, f, location);
IDPtr l_ptr = {NewRef{}, const_cast<ID*>(l)};
auto v = f->GetElementByID(l_ptr);
if ( v && v->Modifiable() )
local_aggrs.emplace_back(std::move(v));
}
else
// The new trigger object will take care of its own deletion.
new trigger::Trigger(wi->Cond(), wi->WhenBody(), wi->TimeoutStmt(), timeout, f,
wi->IsReturn(), location);
// The new trigger object will take care of its own deletion.
new trigger::Trigger(wi, timeout, wi->WhenExprGlobals(), local_aggrs, f, location);
return nullptr;
}
bool WhenStmt::IsPure() const
{
return wi->Cond()->IsPure() && wi->WhenBody()->IsPure() &&
(! wi->TimeoutStmt() || wi->TimeoutStmt()->IsPure());
return false;
}
void WhenStmt::StmtDescribe(ODesc* d) const
@ -2183,35 +2094,35 @@ void WhenStmt::StmtDescribe(ODesc* d) const
if ( d->IsReadable() )
d->Add("(");
wi->Cond()->Describe(d);
wi->OrigCond()->Describe(d);
if ( d->IsReadable() )
d->Add(")");
d->SP();
d->PushIndent();
wi->WhenBody()->AccessStats(d);
wi->WhenBody()->Describe(d);
wi->OrigBody()->AccessStats(d);
wi->OrigBody()->Describe(d);
d->PopIndent();
if ( wi->TimeoutExpr() )
if ( wi->OrigTimeout() )
{
if ( d->IsReadable() )
{
d->SP();
d->Add("timeout");
d->SP();
wi->TimeoutExpr()->Describe(d);
wi->OrigTimeout()->Describe(d);
d->SP();
d->PushIndent();
wi->TimeoutStmt()->AccessStats(d);
wi->TimeoutStmt()->Describe(d);
wi->OrigTimeoutStmt()->AccessStats(d);
wi->OrigTimeoutStmt()->Describe(d);
d->PopIndent();
}
else
{
wi->TimeoutExpr()->Describe(d);
wi->TimeoutStmt()->Describe(d);
wi->OrigTimeout()->Describe(d);
wi->OrigTimeoutStmt()->Describe(d);
}
}
}
@ -2231,22 +2142,22 @@ TraversalCode WhenStmt::Traverse(TraversalCallback* cb) const
else
{
tc = wi->Cond()->Traverse(cb);
tc = wi->OrigCond()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
tc = wi->WhenBody()->Traverse(cb);
tc = wi->OrigBody()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
if ( wi->TimeoutStmt() )
if ( wi->OrigTimeoutStmt() )
{
tc = wi->TimeoutStmt()->Traverse(cb);
tc = wi->OrigTimeoutStmt()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
}
}
if ( wi->TimeoutExpr() )
if ( wi->OrigTimeout() )
{
tc = wi->TimeoutExpr()->Traverse(cb);
tc = wi->OrigTimeout()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
}