WhenStmt/WhenInfo restructuring in support of ZAM "when" statements

This commit is contained in:
Vern Paxson 2023-06-16 16:10:43 -07:00 committed by Arne Welzel
parent 7d5760ac74
commit 1dd2270272
5 changed files with 200 additions and 90 deletions

View file

@ -2005,9 +2005,68 @@ WhenInfo::WhenInfo(ExprPtr arg_cond, FuncType::CaptureList* arg_cl, bool arg_is_
if ( ! cl ) if ( ! cl )
cl = new zeek::FuncType::CaptureList; cl = new zeek::FuncType::CaptureList;
BuildProfile();
// Create the internal lambda we'll use to manage the captures.
static int num_params = 0; // to ensure each is distinct
lambda_param_id = util::fmt("when-param-%d", ++num_params);
auto param_list = new type_decl_list();
auto count_t = base_type(TYPE_COUNT);
param_list->push_back(new TypeDecl(util::copy_string(lambda_param_id.c_str()), count_t));
auto params = make_intrusive<RecordType>(param_list);
lambda_ft = make_intrusive<FuncType>(params, base_type(TYPE_ANY), FUNC_FLAVOR_FUNCTION);
if ( ! is_return )
lambda_ft->SetExpressionlessReturnOkay(true);
lambda_ft->SetCaptures(*cl);
auto id = current_scope()->GenerateTemporary("when-internal");
id->SetType(lambda_ft);
push_scope(std::move(id), nullptr);
param_id = install_ID(lambda_param_id.c_str(), current_module.c_str(), false, false);
param_id->SetType(count_t);
}
WhenInfo::WhenInfo(const WhenInfo* orig)
{
if ( orig->cl )
{
cl = new FuncType::CaptureList;
*cl = *orig->cl;
}
cond = orig->OrigCond()->Duplicate();
// We don't duplicate these, as they'll be compiled separately.
s = orig->OrigBody();
timeout_s = orig->OrigBody();
timeout = orig->OrigTimeout();
if ( timeout )
timeout = timeout->Duplicate();
lambda = cast_intrusive<LambdaExpr>(orig->Lambda()->Duplicate());
is_return = orig->IsReturn();
BuildProfile();
}
WhenInfo::WhenInfo(bool arg_is_return) : is_return(arg_is_return)
{
cl = new zeek::FuncType::CaptureList;
BuildInvokeElems();
}
void WhenInfo::BuildProfile()
{
ProfileFunc cond_pf(cond.get()); ProfileFunc cond_pf(cond.get());
when_expr_locals = cond_pf.Locals(); auto when_expr_locals_set = cond_pf.Locals();
when_expr_globals = cond_pf.AllGlobals(); when_expr_globals = cond_pf.AllGlobals();
when_new_locals = cond_pf.WhenLocals(); when_new_locals = cond_pf.WhenLocals();
@ -2032,41 +2091,20 @@ WhenInfo::WhenInfo(ExprPtr arg_cond, FuncType::CaptureList* arg_cl, bool arg_is_
// In addition, don't treat them as external locals that // In addition, don't treat them as external locals that
// existed at the onset. // existed at the onset.
when_expr_locals.erase(wl); when_expr_locals_set.erase(wl);
} }
// Create the internal lambda we'll use to manage the captures. for ( auto& w : when_expr_locals_set )
static int num_params = 0; // to ensure each is distinct {
lambda_param_id = util::fmt("when-param-%d", ++num_params); // We need IDPtr versions of the locals so we can manipulate
// them during script optimization.
auto param_list = new type_decl_list(); auto non_const_w = const_cast<ID*>(w);
auto count_t = base_type(TYPE_COUNT); when_expr_locals.push_back({NewRef{}, non_const_w});
param_list->push_back(new TypeDecl(util::copy_string(lambda_param_id.c_str()), count_t)); }
auto params = make_intrusive<RecordType>(param_list);
lambda_ft = make_intrusive<FuncType>(params, base_type(TYPE_ANY), FUNC_FLAVOR_FUNCTION);
if ( ! is_return )
lambda_ft->SetExpressionlessReturnOkay(true);
auto id = current_scope()->GenerateTemporary("when-internal");
id->SetType(lambda_ft);
push_scope(std::move(id), nullptr);
auto arg_id = install_ID(lambda_param_id.c_str(), current_module.c_str(), false, false);
arg_id->SetType(count_t);
}
WhenInfo::WhenInfo(bool arg_is_return) : is_return(arg_is_return)
{
cl = new zeek::FuncType::CaptureList;
BuildInvokeElems();
} }
void WhenInfo::Build(StmtPtr ws) void WhenInfo::Build(StmtPtr ws)
{ {
lambda_ft->SetCaptures(*cl);
// Our general strategy is to construct a single lambda (so that // Our general strategy is to construct a single lambda (so that
// the values of captures are shared across all of its elements) // the values of captures are shared across all of its elements)
// that's used for all three of the "when" components: condition, // that's used for all three of the "when" components: condition,
@ -2086,10 +2124,13 @@ void WhenInfo::Build(StmtPtr ws)
// First, the constants we'll need. // First, the constants we'll need.
BuildInvokeElems(); BuildInvokeElems();
if ( lambda )
// No need to build the lambda.
return;
auto true_const = make_intrusive<ConstExpr>(val_mgr->True()); auto true_const = make_intrusive<ConstExpr>(val_mgr->True());
// Access to the parameter that selects which action we're doing. // Access to the parameter that selects which action we're doing.
auto param_id = lookup_ID(lambda_param_id.c_str(), current_module.c_str());
ASSERT(param_id); ASSERT(param_id);
auto param = make_intrusive<NameExpr>(param_id); auto param = make_intrusive<NameExpr>(param_id);
@ -2109,11 +2150,14 @@ void WhenInfo::Build(StmtPtr ws)
auto shebang = make_intrusive<StmtList>(do_test, do_bodies, dummy_return); auto shebang = make_intrusive<StmtList>(do_test, do_bodies, dummy_return);
auto ingredients = std::make_unique<FunctionIngredients>(current_scope(), shebang, auto ingredients = std::make_shared<FunctionIngredients>(current_scope(), shebang,
current_module); current_module);
auto outer_ids = gather_outer_ids(pop_scope(), ingredients->Body()); auto outer_ids = gather_outer_ids(pop_scope(), ingredients->Body());
lambda = make_intrusive<LambdaExpr>(std::move(ingredients), std::move(outer_ids), ws); lambda = make_intrusive<LambdaExpr>(std::move(ingredients), std::move(outer_ids), "", ws);
lambda->SetPrivateCaptures(when_new_locals);
analyze_when_lambda(lambda.get());
} }
void WhenInfo::Instantiate(Frame* f) void WhenInfo::Instantiate(Frame* f)
@ -2205,8 +2249,7 @@ ValPtr WhenStmt::Exec(Frame* f, StmtFlowType& flow)
std::vector<ValPtr> local_aggrs; std::vector<ValPtr> local_aggrs;
for ( auto& l : wi->WhenExprLocals() ) for ( auto& l : wi->WhenExprLocals() )
{ {
IDPtr l_ptr = {NewRef{}, const_cast<ID*>(l)}; auto v = f->GetElementByID(l);
auto v = f->GetElementByID(l_ptr);
if ( v && v->Modifiable() ) if ( v && v->Modifiable() )
local_aggrs.emplace_back(std::move(v)); local_aggrs.emplace_back(std::move(v));
} }
@ -2226,6 +2269,23 @@ void WhenStmt::StmtDescribe(ODesc* d) const
{ {
Stmt::StmtDescribe(d); Stmt::StmtDescribe(d);
auto cl = wi->Captures();
if ( d->IsReadable() && ! cl->empty() )
{
d->Add("[");
for ( auto& c : *cl )
{
if ( &c != &(*cl)[0] )
d->AddSP(",");
if ( c.IsDeepCopy() )
d->Add("copy ");
d->Add(c.Id()->Name());
}
d->Add("]");
}
if ( d->IsReadable() ) if ( d->IsReadable() )
d->Add("("); d->Add("(");
@ -2267,32 +2327,13 @@ TraversalCode WhenStmt::Traverse(TraversalCallback* cb) const
TraversalCode tc = cb->PreStmt(this); TraversalCode tc = cb->PreStmt(this);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
auto wl = wi->Lambda(); tc = wi->Lambda()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
if ( wl ) auto e = wi->TimeoutExpr();
if ( e )
{ {
tc = wl->Traverse(cb); tc = e->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
}
else
{
tc = wi->OrigCond()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
tc = wi->OrigBody()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
if ( wi->OrigTimeoutStmt() )
{
tc = wi->OrigTimeoutStmt()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc);
}
}
if ( wi->OrigTimeout() )
{
tc = wi->OrigTimeout()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
} }

View file

@ -576,6 +576,9 @@ public:
// Takes ownership of the CaptureList. // Takes ownership of the CaptureList.
WhenInfo(ExprPtr cond, FuncType::CaptureList* cl, bool is_return); WhenInfo(ExprPtr cond, FuncType::CaptureList* cl, bool is_return);
// Used for duplication to support inlining.
WhenInfo(const WhenInfo* orig);
// Constructor used by script optimization to create a stub. // Constructor used by script optimization to create a stub.
WhenInfo(bool is_return); WhenInfo(bool is_return);
@ -616,6 +619,7 @@ public:
StmtPtr TimeoutStmt(); StmtPtr TimeoutStmt();
ExprPtr TimeoutExpr() const { return timeout; } ExprPtr TimeoutExpr() const { return timeout; }
void SetTimeoutExpr(ExprPtr e) { timeout = std::move(e); }
double TimeoutVal(Frame* f); double TimeoutVal(Frame* f);
FuncType::CaptureList* Captures() { return cl; } FuncType::CaptureList* Captures() { return cl; }
@ -625,10 +629,22 @@ public:
// The locals and globals used in the conditional expression // The locals and globals used in the conditional expression
// (other than newly introduced locals), necessary for registering // (other than newly introduced locals), necessary for registering
// the associated triggers for when their values change. // the associated triggers for when their values change.
const IDSet& WhenExprLocals() const { return when_expr_locals; } const auto& WhenExprLocals() const { return when_expr_locals; }
const IDSet& WhenExprGlobals() const { return when_expr_globals; } const auto& WhenExprGlobals() const { return when_expr_globals; }
// The locals introduced in the conditional expression.
const auto& WhenNewLocals() const { return when_new_locals; }
// Used for script optimization when in-lining needs to revise
// identifiers.
bool HasUnreducedIDs(Reducer* c) const;
void UpdateIDs(Reducer* c);
private: private:
// Profile the original AST elements to extract things like
// globals and locals used.
void BuildProfile();
// Build those elements we'll need for invoking our lambda. // Build those elements we'll need for invoking our lambda.
void BuildInvokeElems(); void BuildInvokeElems();
@ -640,8 +656,10 @@ private:
bool is_return = false; bool is_return = false;
// The name of parameter passed to the lambda. // The name of parameter passed to the lambda, and the corresponding
// identifier.
std::string lambda_param_id; std::string lambda_param_id;
IDPtr param_id;
// The expression for constructing the lambda, and its type. // The expression for constructing the lambda, and its type.
LambdaExprPtr lambda; LambdaExprPtr lambda;
@ -662,7 +680,7 @@ private:
ConstExprPtr two_const; ConstExprPtr two_const;
ConstExprPtr three_const; ConstExprPtr three_const;
IDSet when_expr_locals; std::vector<IDPtr> when_expr_locals;
IDSet when_expr_globals; IDSet when_expr_globals;
// Locals introduced via "local" in the "when" clause itself. // Locals introduced via "local" in the "when" clause itself.
@ -685,7 +703,7 @@ public:
StmtPtr TimeoutBody() const { return wi->TimeoutStmt(); } StmtPtr TimeoutBody() const { return wi->TimeoutStmt(); }
bool IsReturn() const { return wi->IsReturn(); } bool IsReturn() const { return wi->IsReturn(); }
const WhenInfo* Info() const { return wi; } WhenInfo* Info() const { return wi; }
void StmtDescribe(ODesc* d) const override; void StmtDescribe(ODesc* d) const override;
@ -693,9 +711,9 @@ public:
// Optimization-related: // Optimization-related:
StmtPtr Duplicate() override; StmtPtr Duplicate() override;
void Inline(Inliner* inl) override;
bool IsReduced(Reducer* c) const override; bool IsReduced(Reducer* c) const override;
StmtPtr DoReduce(Reducer* c) override;
private: private:
WhenInfo* wi; WhenInfo* wi;

View file

@ -768,11 +768,10 @@ TraversalCode OuterIDBindingFinder::PreStmt(const Stmt* stmt)
if ( stmt->Tag() != STMT_WHEN ) if ( stmt->Tag() != STMT_WHEN )
return TC_CONTINUE; return TC_CONTINUE;
// The semantics of identifiers for the "when" statement are those
// of the lambda it's transformed into.
auto ws = static_cast<const WhenStmt*>(stmt); auto ws = static_cast<const WhenStmt*>(stmt);
ws->Info()->Lambda()->Traverse(this);
for ( auto& cl : ws->Info()->WhenExprLocals() )
outer_id_references.insert(const_cast<ID*>(cl.get()));
return TC_ABORTSTMT; return TC_ABORTSTMT;
} }
@ -789,18 +788,19 @@ TraversalCode OuterIDBindingFinder::PreExpr(const Expr* expr)
if ( expr->Tag() != EXPR_NAME ) if ( expr->Tag() != EXPR_NAME )
return TC_CONTINUE; return TC_CONTINUE;
auto* e = static_cast<const NameExpr*>(expr); auto e = static_cast<const NameExpr*>(expr);
auto id = e->Id();
if ( e->Id()->IsGlobal() ) if ( id->IsGlobal() )
return TC_CONTINUE; return TC_CONTINUE;
for ( const auto& scope : scopes ) for ( const auto& scope : scopes )
if ( scope->Find(e->Id()->Name()) ) if ( scope->Find(id->Name()) )
// Shadowing is not allowed, so if it's found at inner scope, it's // Shadowing is not allowed, so if it's found at inner scope, it's
// not something we have to worry about also being at outer scope. // not something we have to worry about also being at outer scope.
return TC_CONTINUE; return TC_CONTINUE;
outer_id_references.insert(e->Id()); outer_id_references.insert(id);
return TC_CONTINUE; return TC_CONTINUE;
} }

View file

@ -2342,7 +2342,7 @@ ExprPtr CallExpr::Duplicate()
auto func_type = func->GetType(); auto func_type = func->GetType();
auto in_hook = func_type->AsFuncType()->Flavor() == FUNC_FLAVOR_HOOK; auto in_hook = func_type->AsFuncType()->Flavor() == FUNC_FLAVOR_HOOK;
return SetSucc(new CallExpr(func_d, args_d, in_hook)); return SetSucc(new CallExpr(func_d, args_d, in_hook, in_when));
} }
ExprPtr CallExpr::Inline(Inliner* inl) ExprPtr CallExpr::Inline(Inliner* inl)

View file

@ -939,34 +939,85 @@ StmtPtr AssertStmt::DoReduce(Reducer* c)
return make_intrusive<NullStmt>(); return make_intrusive<NullStmt>();
} }
StmtPtr WhenStmt::Duplicate() bool WhenInfo::HasUnreducedIDs(Reducer* c) const
{ {
FuncType::CaptureList* cl_dup = nullptr; for ( auto& cp : *cl )
if ( wi->Captures() )
{ {
cl_dup = new FuncType::CaptureList; auto cid = cp.Id();
*cl_dup = *wi->Captures();
if ( when_new_locals.count(cid.get()) == 0 && ! c->ID_IsReduced(cp.Id()) )
return true;
} }
auto new_wi = new WhenInfo(Cond(), cl_dup, IsReturn()); for ( auto& l : when_expr_locals )
new_wi->AddBody(Body()); if ( ! c->ID_IsReduced(l) )
new_wi->AddTimeout(TimeoutExpr(), TimeoutBody()); return true;
return SetSucc(new WhenStmt(wi)); return false;
} }
void WhenStmt::Inline(Inliner* inl) void WhenInfo::UpdateIDs(Reducer* c)
{ {
// Don't inline, since we currently don't correctly capture for ( auto& cp : *cl )
// the frames of closures. {
auto& cid = cp.Id();
if ( when_new_locals.count(cid.get()) == 0 )
cp.SetID(c->UpdateID(cid));
}
for ( auto& l : when_expr_locals )
l = c->UpdateID(l);
}
StmtPtr WhenStmt::Duplicate()
{
return SetSucc(new WhenStmt(new WhenInfo(wi)));
} }
bool WhenStmt::IsReduced(Reducer* c) const bool WhenStmt::IsReduced(Reducer* c) const
{ {
// We consider these always reduced because they're not if ( wi->HasUnreducedIDs(c) )
// candidates for any further optimization. return false;
return true;
if ( ! wi->Lambda()->IsReduced(c) )
return false;
if ( ! wi->TimeoutExpr() )
return true;
return wi->TimeoutExpr()->IsReduced(c);
}
StmtPtr WhenStmt::DoReduce(Reducer* c)
{
if ( ! c->Optimizing() )
{
wi->UpdateIDs(c);
(void)wi->Lambda()->ReduceToSingletons(c);
}
auto e = wi->TimeoutExpr();
if ( ! e )
return ThisPtr();
if ( c->Optimizing() )
wi->SetTimeoutExpr(c->OptExpr(e));
else if ( ! e->IsSingleton(c) )
{
StmtPtr red_e_stmt;
auto new_e = e->ReduceToSingleton(c, red_e_stmt);
wi->SetTimeoutExpr(new_e);
if ( red_e_stmt )
{
auto s = make_intrusive<StmtList>(red_e_stmt, ThisPtr());
return TransformMe(s, c);
}
}
return ThisPtr();
} }
CatchReturnStmt::CatchReturnStmt(StmtPtr _block, NameExprPtr _ret_var) : Stmt(STMT_CATCH_RETURN) CatchReturnStmt::CatchReturnStmt(StmtPtr _block, NameExprPtr _ret_var) : Stmt(STMT_CATCH_RETURN)