diff --git a/src/Stmt.cc b/src/Stmt.cc index 0f6bd043f0..68d60184ae 100644 --- a/src/Stmt.cc +++ b/src/Stmt.cc @@ -2005,9 +2005,68 @@ WhenInfo::WhenInfo(ExprPtr arg_cond, FuncType::CaptureList* arg_cl, bool arg_is_ if ( ! cl ) cl = new zeek::FuncType::CaptureList; + BuildProfile(); + + // Create the internal lambda we'll use to manage the captures. + static int num_params = 0; // to ensure each is distinct + lambda_param_id = util::fmt("when-param-%d", ++num_params); + + auto param_list = new type_decl_list(); + auto count_t = base_type(TYPE_COUNT); + param_list->push_back(new TypeDecl(util::copy_string(lambda_param_id.c_str()), count_t)); + auto params = make_intrusive(param_list); + + lambda_ft = make_intrusive(params, base_type(TYPE_ANY), FUNC_FLAVOR_FUNCTION); + + if ( ! is_return ) + lambda_ft->SetExpressionlessReturnOkay(true); + + lambda_ft->SetCaptures(*cl); + + auto id = current_scope()->GenerateTemporary("when-internal"); + id->SetType(lambda_ft); + push_scope(std::move(id), nullptr); + + param_id = install_ID(lambda_param_id.c_str(), current_module.c_str(), false, false); + param_id->SetType(count_t); + } + +WhenInfo::WhenInfo(const WhenInfo* orig) + { + if ( orig->cl ) + { + cl = new FuncType::CaptureList; + *cl = *orig->cl; + } + + cond = orig->OrigCond()->Duplicate(); + + // We don't duplicate these, as they'll be compiled separately. + s = orig->OrigBody(); + timeout_s = orig->OrigBody(); + + timeout = orig->OrigTimeout(); + if ( timeout ) + timeout = timeout->Duplicate(); + + lambda = cast_intrusive(orig->Lambda()->Duplicate()); + + is_return = orig->IsReturn(); + + BuildProfile(); + } + +WhenInfo::WhenInfo(bool arg_is_return) : is_return(arg_is_return) + { + cl = new zeek::FuncType::CaptureList; + BuildInvokeElems(); + } + +void WhenInfo::BuildProfile() + { ProfileFunc cond_pf(cond.get()); - when_expr_locals = cond_pf.Locals(); + auto when_expr_locals_set = cond_pf.Locals(); when_expr_globals = cond_pf.AllGlobals(); when_new_locals = cond_pf.WhenLocals(); @@ -2032,41 +2091,20 @@ WhenInfo::WhenInfo(ExprPtr arg_cond, FuncType::CaptureList* arg_cl, bool arg_is_ // In addition, don't treat them as external locals that // existed at the onset. - when_expr_locals.erase(wl); + when_expr_locals_set.erase(wl); } - // Create the internal lambda we'll use to manage the captures. - static int num_params = 0; // to ensure each is distinct - lambda_param_id = util::fmt("when-param-%d", ++num_params); - - auto param_list = new type_decl_list(); - auto count_t = base_type(TYPE_COUNT); - param_list->push_back(new TypeDecl(util::copy_string(lambda_param_id.c_str()), count_t)); - auto params = make_intrusive(param_list); - - lambda_ft = make_intrusive(params, base_type(TYPE_ANY), FUNC_FLAVOR_FUNCTION); - - if ( ! is_return ) - lambda_ft->SetExpressionlessReturnOkay(true); - - auto id = current_scope()->GenerateTemporary("when-internal"); - id->SetType(lambda_ft); - push_scope(std::move(id), nullptr); - - auto arg_id = install_ID(lambda_param_id.c_str(), current_module.c_str(), false, false); - arg_id->SetType(count_t); - } - -WhenInfo::WhenInfo(bool arg_is_return) : is_return(arg_is_return) - { - cl = new zeek::FuncType::CaptureList; - BuildInvokeElems(); + for ( auto& w : when_expr_locals_set ) + { + // We need IDPtr versions of the locals so we can manipulate + // them during script optimization. + auto non_const_w = const_cast(w); + when_expr_locals.push_back({NewRef{}, non_const_w}); + } } void WhenInfo::Build(StmtPtr ws) { - lambda_ft->SetCaptures(*cl); - // Our general strategy is to construct a single lambda (so that // the values of captures are shared across all of its elements) // that's used for all three of the "when" components: condition, @@ -2086,10 +2124,13 @@ void WhenInfo::Build(StmtPtr ws) // First, the constants we'll need. BuildInvokeElems(); + if ( lambda ) + // No need to build the lambda. + return; + auto true_const = make_intrusive(val_mgr->True()); // Access to the parameter that selects which action we're doing. - auto param_id = lookup_ID(lambda_param_id.c_str(), current_module.c_str()); ASSERT(param_id); auto param = make_intrusive(param_id); @@ -2109,11 +2150,14 @@ void WhenInfo::Build(StmtPtr ws) auto shebang = make_intrusive(do_test, do_bodies, dummy_return); - auto ingredients = std::make_unique(current_scope(), shebang, + auto ingredients = std::make_shared(current_scope(), shebang, current_module); auto outer_ids = gather_outer_ids(pop_scope(), ingredients->Body()); - lambda = make_intrusive(std::move(ingredients), std::move(outer_ids), ws); + lambda = make_intrusive(std::move(ingredients), std::move(outer_ids), "", ws); + lambda->SetPrivateCaptures(when_new_locals); + + analyze_when_lambda(lambda.get()); } void WhenInfo::Instantiate(Frame* f) @@ -2205,8 +2249,7 @@ ValPtr WhenStmt::Exec(Frame* f, StmtFlowType& flow) std::vector local_aggrs; for ( auto& l : wi->WhenExprLocals() ) { - IDPtr l_ptr = {NewRef{}, const_cast(l)}; - auto v = f->GetElementByID(l_ptr); + auto v = f->GetElementByID(l); if ( v && v->Modifiable() ) local_aggrs.emplace_back(std::move(v)); } @@ -2226,6 +2269,23 @@ void WhenStmt::StmtDescribe(ODesc* d) const { Stmt::StmtDescribe(d); + auto cl = wi->Captures(); + if ( d->IsReadable() && ! cl->empty() ) + { + d->Add("["); + for ( auto& c : *cl ) + { + if ( &c != &(*cl)[0] ) + d->AddSP(","); + + if ( c.IsDeepCopy() ) + d->Add("copy "); + + d->Add(c.Id()->Name()); + } + d->Add("]"); + } + if ( d->IsReadable() ) d->Add("("); @@ -2267,32 +2327,13 @@ TraversalCode WhenStmt::Traverse(TraversalCallback* cb) const TraversalCode tc = cb->PreStmt(this); HANDLE_TC_STMT_PRE(tc); - auto wl = wi->Lambda(); + tc = wi->Lambda()->Traverse(cb); + HANDLE_TC_STMT_PRE(tc); - if ( wl ) + auto e = wi->TimeoutExpr(); + if ( e ) { - tc = wl->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); - } - - else - { - tc = wi->OrigCond()->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); - - tc = wi->OrigBody()->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); - - if ( wi->OrigTimeoutStmt() ) - { - tc = wi->OrigTimeoutStmt()->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); - } - } - - if ( wi->OrigTimeout() ) - { - tc = wi->OrigTimeout()->Traverse(cb); + tc = e->Traverse(cb); HANDLE_TC_STMT_PRE(tc); } diff --git a/src/Stmt.h b/src/Stmt.h index 0e69a2e086..614d39d5bc 100644 --- a/src/Stmt.h +++ b/src/Stmt.h @@ -576,6 +576,9 @@ public: // Takes ownership of the CaptureList. WhenInfo(ExprPtr cond, FuncType::CaptureList* cl, bool is_return); + // Used for duplication to support inlining. + WhenInfo(const WhenInfo* orig); + // Constructor used by script optimization to create a stub. WhenInfo(bool is_return); @@ -616,6 +619,7 @@ public: StmtPtr TimeoutStmt(); ExprPtr TimeoutExpr() const { return timeout; } + void SetTimeoutExpr(ExprPtr e) { timeout = std::move(e); } double TimeoutVal(Frame* f); FuncType::CaptureList* Captures() { return cl; } @@ -625,10 +629,22 @@ public: // The locals and globals used in the conditional expression // (other than newly introduced locals), necessary for registering // the associated triggers for when their values change. - const IDSet& WhenExprLocals() const { return when_expr_locals; } - const IDSet& WhenExprGlobals() const { return when_expr_globals; } + const auto& WhenExprLocals() const { return when_expr_locals; } + const auto& WhenExprGlobals() const { return when_expr_globals; } + + // The locals introduced in the conditional expression. + const auto& WhenNewLocals() const { return when_new_locals; } + + // Used for script optimization when in-lining needs to revise + // identifiers. + bool HasUnreducedIDs(Reducer* c) const; + void UpdateIDs(Reducer* c); private: + // Profile the original AST elements to extract things like + // globals and locals used. + void BuildProfile(); + // Build those elements we'll need for invoking our lambda. void BuildInvokeElems(); @@ -640,8 +656,10 @@ private: bool is_return = false; - // The name of parameter passed to the lambda. + // The name of parameter passed to the lambda, and the corresponding + // identifier. std::string lambda_param_id; + IDPtr param_id; // The expression for constructing the lambda, and its type. LambdaExprPtr lambda; @@ -662,7 +680,7 @@ private: ConstExprPtr two_const; ConstExprPtr three_const; - IDSet when_expr_locals; + std::vector when_expr_locals; IDSet when_expr_globals; // Locals introduced via "local" in the "when" clause itself. @@ -685,7 +703,7 @@ public: StmtPtr TimeoutBody() const { return wi->TimeoutStmt(); } bool IsReturn() const { return wi->IsReturn(); } - const WhenInfo* Info() const { return wi; } + WhenInfo* Info() const { return wi; } void StmtDescribe(ODesc* d) const override; @@ -693,9 +711,9 @@ public: // Optimization-related: StmtPtr Duplicate() override; - void Inline(Inliner* inl) override; bool IsReduced(Reducer* c) const override; + StmtPtr DoReduce(Reducer* c) override; private: WhenInfo* wi; diff --git a/src/Var.cc b/src/Var.cc index 5fd1d38cf0..f125540cac 100644 --- a/src/Var.cc +++ b/src/Var.cc @@ -768,11 +768,10 @@ TraversalCode OuterIDBindingFinder::PreStmt(const Stmt* stmt) if ( stmt->Tag() != STMT_WHEN ) return TC_CONTINUE; - // The semantics of identifiers for the "when" statement are those - // of the lambda it's transformed into. - auto ws = static_cast(stmt); - ws->Info()->Lambda()->Traverse(this); + + for ( auto& cl : ws->Info()->WhenExprLocals() ) + outer_id_references.insert(const_cast(cl.get())); return TC_ABORTSTMT; } @@ -789,18 +788,19 @@ TraversalCode OuterIDBindingFinder::PreExpr(const Expr* expr) if ( expr->Tag() != EXPR_NAME ) return TC_CONTINUE; - auto* e = static_cast(expr); + auto e = static_cast(expr); + auto id = e->Id(); - if ( e->Id()->IsGlobal() ) + if ( id->IsGlobal() ) return TC_CONTINUE; for ( const auto& scope : scopes ) - if ( scope->Find(e->Id()->Name()) ) + if ( scope->Find(id->Name()) ) // Shadowing is not allowed, so if it's found at inner scope, it's // not something we have to worry about also being at outer scope. return TC_CONTINUE; - outer_id_references.insert(e->Id()); + outer_id_references.insert(id); return TC_CONTINUE; } diff --git a/src/script_opt/Expr.cc b/src/script_opt/Expr.cc index 8e8a3d49f9..dbf40a1991 100644 --- a/src/script_opt/Expr.cc +++ b/src/script_opt/Expr.cc @@ -2342,7 +2342,7 @@ ExprPtr CallExpr::Duplicate() auto func_type = func->GetType(); auto in_hook = func_type->AsFuncType()->Flavor() == FUNC_FLAVOR_HOOK; - return SetSucc(new CallExpr(func_d, args_d, in_hook)); + return SetSucc(new CallExpr(func_d, args_d, in_hook, in_when)); } ExprPtr CallExpr::Inline(Inliner* inl) diff --git a/src/script_opt/Stmt.cc b/src/script_opt/Stmt.cc index 9d49adf316..65570c8e31 100644 --- a/src/script_opt/Stmt.cc +++ b/src/script_opt/Stmt.cc @@ -939,34 +939,85 @@ StmtPtr AssertStmt::DoReduce(Reducer* c) return make_intrusive(); } -StmtPtr WhenStmt::Duplicate() +bool WhenInfo::HasUnreducedIDs(Reducer* c) const { - FuncType::CaptureList* cl_dup = nullptr; - - if ( wi->Captures() ) + for ( auto& cp : *cl ) { - cl_dup = new FuncType::CaptureList; - *cl_dup = *wi->Captures(); + auto cid = cp.Id(); + + if ( when_new_locals.count(cid.get()) == 0 && ! c->ID_IsReduced(cp.Id()) ) + return true; } - auto new_wi = new WhenInfo(Cond(), cl_dup, IsReturn()); - new_wi->AddBody(Body()); - new_wi->AddTimeout(TimeoutExpr(), TimeoutBody()); + for ( auto& l : when_expr_locals ) + if ( ! c->ID_IsReduced(l) ) + return true; - return SetSucc(new WhenStmt(wi)); + return false; } -void WhenStmt::Inline(Inliner* inl) +void WhenInfo::UpdateIDs(Reducer* c) { - // Don't inline, since we currently don't correctly capture - // the frames of closures. + for ( auto& cp : *cl ) + { + auto& cid = cp.Id(); + if ( when_new_locals.count(cid.get()) == 0 ) + cp.SetID(c->UpdateID(cid)); + } + + for ( auto& l : when_expr_locals ) + l = c->UpdateID(l); + } + +StmtPtr WhenStmt::Duplicate() + { + return SetSucc(new WhenStmt(new WhenInfo(wi))); } bool WhenStmt::IsReduced(Reducer* c) const { - // We consider these always reduced because they're not - // candidates for any further optimization. - return true; + if ( wi->HasUnreducedIDs(c) ) + return false; + + if ( ! wi->Lambda()->IsReduced(c) ) + return false; + + if ( ! wi->TimeoutExpr() ) + return true; + + return wi->TimeoutExpr()->IsReduced(c); + } + +StmtPtr WhenStmt::DoReduce(Reducer* c) + { + if ( ! c->Optimizing() ) + { + wi->UpdateIDs(c); + (void)wi->Lambda()->ReduceToSingletons(c); + } + + auto e = wi->TimeoutExpr(); + + if ( ! e ) + return ThisPtr(); + + if ( c->Optimizing() ) + wi->SetTimeoutExpr(c->OptExpr(e)); + + else if ( ! e->IsSingleton(c) ) + { + StmtPtr red_e_stmt; + auto new_e = e->ReduceToSingleton(c, red_e_stmt); + wi->SetTimeoutExpr(new_e); + + if ( red_e_stmt ) + { + auto s = make_intrusive(red_e_stmt, ThisPtr()); + return TransformMe(s, c); + } + } + + return ThisPtr(); } CatchReturnStmt::CatchReturnStmt(StmtPtr _block, NameExprPtr _ret_var) : Stmt(STMT_CATCH_RETURN)