From f895008c3439744832cea267f6db2d7937e127d7 Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Fri, 7 Jan 2022 14:50:35 -0800 Subject: [PATCH] captures for "when" statements update Triggers to IntrusivePtr's and simpler AST traversal introduce IDSet type, migrate associated "ID*" types to "const ID*" --- src/Expr.cc | 31 +++- src/Expr.h | 6 +- src/ID.h | 2 + src/Stmt.cc | 271 ++++++++++++++++++++++++++++++---- src/Stmt.h | 110 ++++++++++++-- src/Trigger.cc | 97 ++++++++---- src/Trigger.h | 50 +++++-- src/Val.h | 1 + src/Var.cc | 19 +++ src/parse.y | 113 ++++++++------ src/script_opt/CPP/Compile.h | 4 +- src/script_opt/CPP/GenFunc.cc | 2 +- src/script_opt/GenIDDefs.h | 2 +- src/script_opt/ProfileFunc.cc | 23 ++- src/script_opt/ProfileFunc.h | 47 +++--- src/script_opt/Reduce.h | 4 +- src/script_opt/Stmt.cc | 17 ++- src/script_opt/UseDefs.h | 4 +- src/script_opt/ZAM/AM-Opt.cc | 4 +- src/script_opt/ZAM/Compile.h | 10 +- src/script_opt/ZAM/Ops.in | 19 ++- src/script_opt/ZAM/Stmt.cc | 2 +- src/script_opt/ZAM/Vars.cc | 6 +- src/script_opt/ZAM/ZInst.h | 6 +- 24 files changed, 648 insertions(+), 202 deletions(-) diff --git a/src/Expr.cc b/src/Expr.cc index c291f288ae..b857ac10de 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -4588,7 +4588,8 @@ void CallExpr::ExprDescribe(ODesc* d) const args->Describe(d); } -LambdaExpr::LambdaExpr(std::unique_ptr arg_ing, IDPList arg_outer_ids) +LambdaExpr::LambdaExpr(std::unique_ptr arg_ing, IDPList arg_outer_ids, + StmtPtr when_parent) : Expr(EXPR_LAMBDA) { ingredients = std::move(arg_ing); @@ -4596,7 +4597,7 @@ LambdaExpr::LambdaExpr(std::unique_ptr arg_ing, IDPList ar SetType(ingredients->id->GetType()); - CheckCaptures(); + CheckCaptures(when_parent); // Install a dummy version of the function globally for use only // when broker provides a closure. @@ -4641,7 +4642,7 @@ LambdaExpr::LambdaExpr(std::unique_ptr arg_ing, IDPList ar id->SetConst(); } -void LambdaExpr::CheckCaptures() +void LambdaExpr::CheckCaptures(StmtPtr when_parent) { auto ft = type->AsFuncType(); const auto& captures = ft->GetCaptures(); @@ -4665,6 +4666,8 @@ void LambdaExpr::CheckCaptures() std::set outer_is_matched; std::set capture_is_matched; + auto desc = when_parent ? "\"when\" statement" : "lambda"; + for ( const auto& c : *captures ) { auto cid = c.id.get(); @@ -4677,7 +4680,11 @@ void LambdaExpr::CheckCaptures() if ( capture_is_matched.count(cid) > 0 ) { - ExprError(util::fmt("%s listed multiple times in capture", cid->Name())); + auto msg = util::fmt("%s listed multiple times in capture", cid->Name()); + if ( when_parent ) + when_parent->Error(msg); + else + ExprError(msg); continue; } @@ -4692,13 +4699,25 @@ void LambdaExpr::CheckCaptures() for ( auto id : outer_ids ) if ( outer_is_matched.count(id) == 0 ) - ExprError(util::fmt("%s is used inside lambda but not captured", id->Name())); + { + auto msg = util::fmt("%s is used inside %s but not captured", id->Name(), desc); + if ( when_parent ) + when_parent->Error(msg); + else + ExprError(msg); + } for ( const auto& c : *captures ) { auto cid = c.id.get(); if ( cid && capture_is_matched.count(cid) == 0 ) - ExprError(util::fmt("%s is captured but not used inside lambda", cid->Name())); + { + auto msg = util::fmt("%s is captured but not used inside %s", cid->Name(), desc); + if ( when_parent ) + when_parent->Error(msg); + else + ExprError(msg); + } } } diff --git a/src/Expr.h b/src/Expr.h index e8863a59b3..24d63a7a36 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -137,6 +137,7 @@ using EventExprPtr = IntrusivePtr; using ExprPtr = IntrusivePtr; using NameExprPtr = IntrusivePtr; using RefExprPtr = IntrusivePtr; +using LambdaExprPtr = IntrusivePtr; class Stmt; using StmtPtr = IntrusivePtr; @@ -1428,7 +1429,8 @@ protected: class LambdaExpr final : public Expr { public: - LambdaExpr(std::unique_ptr ingredients, IDPList outer_ids); + LambdaExpr(std::unique_ptr ingredients, IDPList outer_ids, + StmtPtr when_parent = nullptr); const std::string& Name() const { return my_name; } const IDPList& OuterIDs() const { return outer_ids; } @@ -1449,7 +1451,7 @@ protected: void ExprDescribe(ODesc* d) const override; private: - void CheckCaptures(); + void CheckCaptures(StmtPtr when_parent); std::unique_ptr ingredients; diff --git a/src/ID.h b/src/ID.h index cb9f68997c..28b0a5689d 100644 --- a/src/ID.h +++ b/src/ID.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "zeek/Attr.h" @@ -55,6 +56,7 @@ enum IDScope class ID; using IDPtr = IntrusivePtr; +using IDSet = std::unordered_set; class IDOptInfo; diff --git a/src/Stmt.cc b/src/Stmt.cc index 87c47d0adc..7cbb13d7da 100644 --- a/src/Stmt.cc +++ b/src/Stmt.cc @@ -20,6 +20,7 @@ #include "zeek/Var.h" #include "zeek/logging/Manager.h" #include "zeek/logging/logging.bif.h" +#include "zeek/script_opt/ProfileFunc.h" #include "zeek/script_opt/StmtOptInfo.h" namespace zeek::detail @@ -1558,7 +1559,7 @@ ReturnStmt::ReturnStmt(ExprPtr arg_e) : ExprStmt(STMT_RETURN, std::move(arg_e)) else if ( ! e ) { - if ( ft->Flavor() != FUNC_FLAVOR_HOOK ) + if ( ft->Flavor() != FUNC_FLAVOR_HOOK && ! ft->ExpressionlessReturnOkay() ) Error("return statement needs expression"); } @@ -1794,45 +1795,257 @@ TraversalCode NullStmt::Traverse(TraversalCallback* cb) const HANDLE_TC_STMT_POST(tc); } -WhenStmt::WhenStmt(ExprPtr arg_cond, StmtPtr arg_s1, StmtPtr arg_s2, ExprPtr arg_timeout, - bool arg_is_return) - : Stmt(STMT_WHEN), cond(std::move(arg_cond)), s1(std::move(arg_s1)), s2(std::move(arg_s2)), - timeout(std::move(arg_timeout)), is_return(arg_is_return) +WhenInfo::WhenInfo(ExprPtr _cond, FuncType::CaptureList* _cl, bool _is_return) + : cond(std::move(_cond)), cl(_cl), is_return(_is_return) { - assert(cond); - assert(s1); + prior_vars = current_scope()->Vars(); + + ProfileFunc cond_pf(cond.get()); + + if ( ! cl ) + { + for ( auto& wl : cond_pf.WhenLocals() ) + prior_vars.erase(wl->Name()); + return; + } + + when_expr_locals = cond_pf.Locals(); + when_expr_globals = cond_pf.Globals(); + + // Make any when-locals part of our captures, if not already present, + // to enable sharing between the condition and the body/timeout code. + for ( auto& wl : cond_pf.WhenLocals() ) + { + bool is_present = false; + + for ( auto& c : *cl ) + if ( c.id == wl ) + { + is_present = true; + break; + } + + if ( ! is_present ) + { + IDPtr wl_ptr = {NewRef{}, const_cast(wl)}; + cl->emplace_back(FuncType::Capture{wl_ptr, false}); + } + + // In addition, don't treat them as external locals that + // existed at the onset. + when_expr_locals.erase(wl); + } + + // Create the internal lambda we'll use to manage the captures. + static int num_params = 0; // to ensure each is distinct + lambda_param_id = util::fmt("when-param-%d", ++num_params); + + auto param_list = new type_decl_list(); + auto count_t = base_type(TYPE_COUNT); + param_list->push_back(new TypeDecl(util::copy_string(lambda_param_id.c_str()), count_t)); + auto params = make_intrusive(param_list); + + auto ft = make_intrusive(params, base_type(TYPE_ANY), FUNC_FLAVOR_FUNCTION); + ft->SetCaptures(*cl); + + if ( ! is_return ) + ft->SetExpressionlessReturnOkay(true); + + auto id = current_scope()->GenerateTemporary("when-internal"); + + // This begin_func will be completed by WhenInfo::Build(). + begin_func(id, current_module.c_str(), FUNC_FLAVOR_FUNCTION, false, ft); + } + +void WhenInfo::Build(StmtPtr ws) + { + if ( ! cl ) + { + // Old-style semantics. + auto locals = when_expr_locals; + + ProfileFunc cond_pf(cond.get()); + for ( auto& bl : cond_pf.Locals() ) + if ( prior_vars.count(bl->Name()) > 0 ) + locals.insert(bl); + + ProfileFunc body_pf(s.get()); + for ( auto& bl : body_pf.Locals() ) + if ( prior_vars.count(bl->Name()) > 0 ) + locals.insert(bl); + + if ( timeout_s ) + { + ProfileFunc to_pf(timeout_s.get()); + for ( auto& tl : to_pf.Locals() ) + if ( prior_vars.count(tl->Name()) > 0 ) + locals.insert(tl); + } + + if ( ! locals.empty() ) + { + std::string vars; + for ( auto& l : locals ) + { + if ( ! vars.empty() ) + vars += ", "; + vars += l->Name(); + } + + std::string msg = util::fmt("\"when\" statement referring to locals without an " + "explicit [] capture is deprecated: %s", + vars.c_str()); + ws->Warn(msg.c_str()); + } + + return; + } + + // Our general strategy is to construct a single lambda (so that + // the values of captures are shared across all of its elements) + // that's used for all three of the "when" components: condition, + // body, and timeout body. The idea is that the lambda is passed + // a single argument that specifies the particular functionality + // to execute (1 = condition, 2 = body, 3 = timeout). It gets tricky + // in that the condition needs to return a boolean, whereas the body + // and timeout *might* return a value (for "return when") constructs, + // or might not (for vanilla "when"). We address that issue by + // (1) making the return type be "any", and (2) introducing elsehwere + // the notion of functions marked as being allowed to have bare + // returns (no associated expression) even though they have a return + // type (to deal with the vanilla "when" case). + + // Build the AST elements of the lambda. + + // First, the constants we'll need. + auto true_const = make_intrusive(val_mgr->True()); + auto one_const = make_intrusive(val_mgr->Count(1)); + auto two_const = make_intrusive(val_mgr->Count(2)); + auto three_const = make_intrusive(val_mgr->Count(3)); + + invoke_cond = make_intrusive(one_const); + invoke_s = make_intrusive(two_const); + invoke_timeout = make_intrusive(three_const); + + // Access to the parameter that selects which action we're doing. + auto param_id = lookup_ID(lambda_param_id.c_str(), current_module.c_str()); + ASSERT(param_id); + auto param = make_intrusive(param_id); + + // Expressions for testing for the latter constants. + auto one_test = make_intrusive(EXPR_EQ, param, one_const); + auto two_test = make_intrusive(EXPR_EQ, param, two_const); + + auto empty = make_intrusive(); + + auto test_cond = make_intrusive(cond); + auto do_test = make_intrusive(one_test, test_cond, empty); + + auto else_branch = timeout_s ? timeout_s : empty; + + auto do_bodies = make_intrusive(two_test, s, else_branch); + auto dummy_return = make_intrusive(true_const); + + auto shebang = make_intrusive(do_test, do_bodies, dummy_return); + + auto ingredients = std::make_unique(current_scope(), shebang); + auto outer_ids = gather_outer_ids(pop_scope(), ingredients->body); + + lambda = make_intrusive(std::move(ingredients), std::move(outer_ids), ws); + } + +void WhenInfo::Instantiate(Frame* f) + { + if ( cl ) + curr_lambda = make_intrusive(lambda->Eval(f)); + } + +ExprPtr WhenInfo::Cond() + { + if ( ! curr_lambda ) + return cond; + + return make_intrusive(curr_lambda, invoke_cond); + } + +StmtPtr WhenInfo::WhenBody() + { + if ( ! curr_lambda ) + return s; + + auto invoke = make_intrusive(curr_lambda, invoke_s); + return make_intrusive(invoke, true); + } + +StmtPtr WhenInfo::TimeoutStmt() + { + if ( ! curr_lambda ) + return timeout_s; + + auto invoke = make_intrusive(curr_lambda, invoke_timeout); + return make_intrusive(invoke, true); + } + +WhenStmt::WhenStmt(WhenInfo* _wi) : Stmt(STMT_WHEN), wi(_wi) + { + wi->Build(ThisPtr()); + + auto cond = wi->Cond(); if ( ! cond->IsError() && ! IsBool(cond->GetType()->Tag()) ) cond->Error("conditional in test must be boolean"); - if ( timeout ) + auto te = wi->TimeoutExpr(); + + if ( te ) { - if ( timeout->IsError() ) + if ( te->IsError() ) return; - TypeTag bt = timeout->GetType()->Tag(); + TypeTag bt = te->GetType()->Tag(); if ( bt != TYPE_TIME && bt != TYPE_INTERVAL ) - cond->Error("when timeout requires a time or time interval"); + te->Error("when timeout requires a time or time interval"); } } -WhenStmt::~WhenStmt() = default; +WhenStmt::~WhenStmt() + { + delete wi; + } ValPtr WhenStmt::Exec(Frame* f, StmtFlowType& flow) { RegisterAccess(); flow = FLOW_NEXT; - // The new trigger object will take care of its own deletion. - new trigger::Trigger(IntrusivePtr{cond}.release(), IntrusivePtr{s1}.release(), - IntrusivePtr{s2}.release(), IntrusivePtr{timeout}.release(), f, is_return, - location); + wi->Instantiate(f); + + if ( wi->Captures() ) + { + std::vector local_aggrs; + for ( auto& l : wi->WhenExprLocals() ) + { + IDPtr l_ptr = {NewRef{}, const_cast(l)}; + auto v = f->GetElementByID(l_ptr); + if ( v && v->Modifiable() ) + local_aggrs.emplace_back(std::move(v)); + } + + new trigger::Trigger(wi, wi->WhenExprGlobals(), local_aggrs, f, location); + } + + else + // The new trigger object will take care of its own deletion. + new trigger::Trigger(wi->Cond(), wi->WhenBody(), wi->TimeoutStmt(), wi->TimeoutExpr(), f, + wi->IsReturn(), location); + return nullptr; } bool WhenStmt::IsPure() const { - return cond->IsPure() && s1->IsPure() && (! s2 || s2->IsPure()); + return wi->Cond()->IsPure() && wi->WhenBody()->IsPure() && + (! wi->TimeoutStmt() || wi->TimeoutStmt()->IsPure()); } void WhenStmt::StmtDescribe(ODesc* d) const @@ -1842,33 +2055,33 @@ void WhenStmt::StmtDescribe(ODesc* d) const if ( d->IsReadable() ) d->Add("("); - cond->Describe(d); + wi->Cond()->Describe(d); if ( d->IsReadable() ) d->Add(")"); d->SP(); d->PushIndent(); - s1->AccessStats(d); - s1->Describe(d); + wi->WhenBody()->AccessStats(d); + wi->WhenBody()->Describe(d); d->PopIndent(); - if ( s2 ) + if ( wi->TimeoutStmt() ) { if ( d->IsReadable() ) { d->SP(); d->Add("timeout"); d->SP(); - timeout->Describe(d); + wi->TimeoutExpr()->Describe(d); d->SP(); d->PushIndent(); - s2->AccessStats(d); - s2->Describe(d); + wi->TimeoutStmt()->AccessStats(d); + wi->TimeoutStmt()->Describe(d); d->PopIndent(); } else - s2->Describe(d); + wi->TimeoutStmt()->Describe(d); } } @@ -1877,15 +2090,15 @@ TraversalCode WhenStmt::Traverse(TraversalCallback* cb) const TraversalCode tc = cb->PreStmt(this); HANDLE_TC_STMT_PRE(tc); - tc = cond->Traverse(cb); + tc = wi->Cond()->Traverse(cb); HANDLE_TC_STMT_PRE(tc); - tc = s1->Traverse(cb); + tc = wi->WhenBody()->Traverse(cb); HANDLE_TC_STMT_PRE(tc); - if ( s2 ) + if ( wi->TimeoutStmt() ) { - tc = s2->Traverse(cb); + tc = wi->TimeoutStmt()->Traverse(cb); HANDLE_TC_STMT_PRE(tc); } diff --git a/src/Stmt.h b/src/Stmt.h index b6161f2fd3..97038a2229 100644 --- a/src/Stmt.h +++ b/src/Stmt.h @@ -5,8 +5,10 @@ // Zeek statements. #include "zeek/Dict.h" +#include "zeek/Expr.h" #include "zeek/ID.h" #include "zeek/StmtBase.h" +#include "zeek/Type.h" #include "zeek/ZeekList.h" namespace zeek::detail @@ -442,7 +444,7 @@ public: // Optimization-related: StmtPtr Duplicate() override; - // Constructor used for duplication, when we've already done + // Constructor used internally, for when we've already done // all of the type-checking. ReturnStmt(ExprPtr e, bool ignored); @@ -535,21 +537,105 @@ public: StmtPtr Duplicate() override { return SetSucc(new NullStmt()); } }; +// A helper class for tracking all of the information associated with +// a "when" statement, and constructing the necessary components in support +// of lambda-style captures. +class WhenInfo + { +public: + // Takes ownership of the CaptureList, which if nil signifies + // old-style frame semantics. + WhenInfo(ExprPtr _cond, FuncType::CaptureList* _cl, bool _is_return); + ~WhenInfo() { delete cl; } + + void AddBody(StmtPtr _s) { s = std::move(_s); } + + void AddTimeout(ExprPtr _timeout, StmtPtr _timeout_s) + { + timeout = std::move(_timeout); + timeout_s = std::move(_timeout_s); + } + + // Complete construction of the associated internals, including + // the (complex) lambda used to access the different elements of + // the statement. + void Build(StmtPtr ws); + + // Instantiate a new instance. + void Instantiate(Frame* f); + + // For old-style semantics, the following simply return the + // individual "when" components. For capture semantics, however, + // these instead return different invocations of a lambda that + // manages the captures. + ExprPtr Cond(); + StmtPtr WhenBody(); + + ExprPtr TimeoutExpr() { return timeout; } + StmtPtr TimeoutStmt(); + + FuncType::CaptureList* Captures() { return cl; } + + bool IsReturn() const { return is_return; } + + const LambdaExprPtr& Lambda() const { return lambda; } + + // The locals and globals used in the conditional expression + // (other than newly introduced locals), necessary for registering + // the associated triggers for when their values change. + const IDSet& WhenExprLocals() const { return when_expr_locals; } + const IDSet& WhenExprGlobals() const { return when_expr_globals; } + +private: + ExprPtr cond; + StmtPtr s; + ExprPtr timeout; + StmtPtr timeout_s; + FuncType::CaptureList* cl; + + bool is_return = false; + + // The name of parameter passed ot the lambda. + std::string lambda_param_id; + + // The expression for constructing the lambda. + LambdaExprPtr lambda; + + // The current instance of the lambda. Created by Instantiate(), + // for immediate use via calls to Cond() etc. + ConstExprPtr curr_lambda; + + // Arguments to use when calling the lambda to either evaluate + // the conditional, or execute the body or the timeout statement. + ListExprPtr invoke_cond; + ListExprPtr invoke_s; + ListExprPtr invoke_timeout; + + IDSet when_expr_locals; + IDSet when_expr_globals; + + // Used for identifying deprecated instances. Holds all of the local + // variables in the scope prior to parsing the "when" statement. + std::map> prior_vars; + }; + class WhenStmt final : public Stmt { public: - // s2 is null if no timeout block given. - WhenStmt(ExprPtr cond, StmtPtr s1, StmtPtr s2, ExprPtr timeout, bool is_return); + // The constructor takes ownership of the WhenInfo object. + WhenStmt(WhenInfo* wi); ~WhenStmt() override; ValPtr Exec(Frame* f, StmtFlowType& flow) override; bool IsPure() const override; - const Expr* Cond() const { return cond.get(); } - const Stmt* Body() const { return s1.get(); } - const Expr* TimeoutExpr() const { return timeout.get(); } - const Stmt* TimeoutBody() const { return s2.get(); } - bool IsReturn() const { return is_return; } + ExprPtr Cond() const { return wi->Cond(); } + StmtPtr Body() const { return wi->WhenBody(); } + ExprPtr TimeoutExpr() const { return wi->TimeoutExpr(); } + StmtPtr TimeoutBody() const { return wi->TimeoutStmt(); } + bool IsReturn() const { return wi->IsReturn(); } + + const WhenInfo* Info() const { return wi; } void StmtDescribe(ODesc* d) const override; @@ -561,12 +647,8 @@ public: bool IsReduced(Reducer* c) const override; -protected: - ExprPtr cond; - StmtPtr s1; - StmtPtr s2; - ExprPtr timeout; - bool is_return; +private: + WhenInfo* wi; }; // Internal statement used for inlining. Executes a block and stops diff --git a/src/Trigger.cc b/src/Trigger.cc index ba9825071f..9d4e2ea121 100644 --- a/src/Trigger.cc +++ b/src/Trigger.cc @@ -23,21 +23,19 @@ using namespace zeek::detail::trigger; namespace zeek::detail::trigger { +// Used to extract the globals and locals seen in a trigger expression. class TriggerTraversalCallback : public TraversalCallback { public: - TriggerTraversalCallback(Trigger* arg_trigger) + TriggerTraversalCallback(IDSet& _globals, IDSet& _locals) : globals(_globals), locals(_locals) { - Ref(arg_trigger); - trigger = arg_trigger; } - ~TriggerTraversalCallback() { Unref(trigger); } - virtual TraversalCode PreExpr(const Expr*) override; private: - Trigger* trigger; + IDSet& globals; + IDSet& locals; }; TraversalCode trigger::TriggerTraversalCallback::PreExpr(const Expr* expr) @@ -50,14 +48,12 @@ TraversalCode trigger::TriggerTraversalCallback::PreExpr(const Expr* expr) case EXPR_NAME: { const auto* e = static_cast(expr); - if ( e->Id()->IsGlobal() ) - trigger->Register(e->Id()); + auto id = e->Id(); - Val* v = e->Id()->GetVal().get(); - - if ( v && v->Modifiable() ) - trigger->Register(v); - break; + if ( id->IsGlobal() ) + globals.insert(id); + else + locals.insert(id); }; default: @@ -103,10 +99,35 @@ protected: double time; }; -Trigger::Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, Expr* timeout_expr, +Trigger::Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, ExprPtr timeout_expr, Frame* frame, bool is_return, const Location* location) { - timeout_value = -1; + GetTimeout(timeout_expr); + Init(cond, body, timeout_stmts, frame, is_return, location); + } + +Trigger::Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, double timeout, Frame* frame, + bool is_return, const Location* location) + { + timeout_value = timeout; + Init(cond, body, timeout_stmts, frame, is_return, location); + } + +Trigger::Trigger(WhenInfo* wi, const IDSet& _globals, std::vector _local_aggrs, Frame* f, + const Location* loc) + { + globals = _globals; + local_aggrs = std::move(_local_aggrs); + have_trigger_elems = true; + + GetTimeout(wi->TimeoutExpr()); + + Init(wi->Cond(), wi->WhenBody(), wi->TimeoutStmt(), f, wi->IsReturn(), loc); + } + +void Trigger::GetTimeout(const ExprPtr& timeout_expr) + { + timeout_value = -1.0; if ( timeout_expr ) { @@ -123,24 +144,14 @@ Trigger::Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, Expr* timeou if ( timeout_val ) timeout_value = timeout_val->AsInterval(); } - - Init(cond, body, timeout_stmts, frame, is_return, location); } -Trigger::Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, double timeout, Frame* frame, - bool is_return, const Location* location) - { - timeout_value = timeout; - Init(cond, body, timeout_stmts, frame, is_return, location); - } - -void Trigger::Init(const Expr* arg_cond, Stmt* arg_body, Stmt* arg_timeout_stmts, Frame* arg_frame, +void Trigger::Init(ExprPtr arg_cond, StmtPtr arg_body, StmtPtr arg_timeout_stmts, Frame* arg_frame, bool arg_is_return, const Location* arg_location) { cond = arg_cond; body = arg_body; timeout_stmts = arg_timeout_stmts; - frame = arg_frame->Clone(); timer = nullptr; delayed = false; disabled = false; @@ -148,6 +159,11 @@ void Trigger::Init(const Expr* arg_cond, Stmt* arg_body, Stmt* arg_timeout_stmts is_return = arg_is_return; location = arg_location; + if ( arg_frame ) + frame = arg_frame->Clone(); + else + frame = nullptr; + DBG_LOG(DBG_NOTIFIERS, "%s: instantiating", Name()); if ( is_return ) @@ -217,8 +233,30 @@ void Trigger::ReInit(std::vector index_expr_results) { assert(! disabled); UnregisterAll(); - TriggerTraversalCallback cb(this); - cond->Traverse(&cb); + + if ( ! have_trigger_elems ) + { + TriggerTraversalCallback cb(globals, locals); + cond->Traverse(&cb); + have_trigger_elems = true; + } + + for ( auto g : globals ) + { + Register(g); + + auto& v = g->GetVal(); + if ( v && v->Modifiable() ) + Register(v.get()); + } + + for ( auto l : locals ) + { + ASSERT(! l->GetVal()); + } + + for ( auto& av : local_aggrs ) + Register(av.get()); for ( const auto& v : index_expr_results ) Register(v.get()); @@ -390,9 +428,10 @@ void Trigger::Timeout() Unref(this); } -void Trigger::Register(ID* id) +void Trigger::Register(const ID* const_id) { assert(! disabled); + ID* id = const_cast(const_id); notifier::detail::registry.Register(id, this); Ref(id); diff --git a/src/Trigger.h b/src/Trigger.h index da928c6a89..6810a7fccd 100644 --- a/src/Trigger.h +++ b/src/Trigger.h @@ -4,6 +4,7 @@ #include #include +#include "zeek/ID.h" #include "zeek/IntrusivePtr.h" #include "zeek/Notifier.h" #include "zeek/Obj.h" @@ -16,6 +17,8 @@ namespace zeek class ODesc; class Val; +using ValPtr = IntrusivePtr; + namespace detail { @@ -24,6 +27,9 @@ class Stmt; class Expr; class CallExpr; class ID; +class WhenInfo; + +using StmtPtr = IntrusivePtr; namespace trigger { @@ -41,10 +47,18 @@ public: // instantiation. Note that if the condition is already true, the // statements are executed immediately and the object is deleted // right away. - Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, Expr* timeout, Frame* f, + + // These first two constructors are for the deprecated deep-copy + // semantics. + Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, ExprPtr timeout, Frame* f, bool is_return, const Location* loc); - Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, double timeout, Frame* f, + Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, double timeout, Frame* f, bool is_return, const Location* loc); + + // Used for capture-list semantics. + Trigger(WhenInfo* wi, const IDSet& globals, std::vector local_aggrs, Frame* f, + const Location* loc); + ~Trigger() override; // Evaluates the condition. If true, executes the body and deletes @@ -96,22 +110,23 @@ public: const char* Name() const; private: - friend class TriggerTraversalCallback; friend class TriggerTimer; - void Init(const Expr* cond, Stmt* body, Stmt* timeout_stmts, Frame* frame, bool is_return, + void GetTimeout(const ExprPtr& timeout_expr); + + void Init(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, Frame* frame, bool is_return, const Location* location); - void ReInit(std::vector> index_expr_results); + void ReInit(std::vector index_expr_results); - void Register(ID* id); + void Register(const ID* id); void Register(Val* val); void UnregisterAll(); - const Expr* cond; - Stmt* body; - Stmt* timeout_stmts; - Expr* timeout; + ExprPtr cond; + StmtPtr body; + StmtPtr timeout_stmts; + ExprPtr timeout; double timeout_value; Frame* frame; bool is_return; @@ -123,14 +138,25 @@ private: bool delayed; // true if a function call is currently being delayed bool disabled; + // Globals and locals present in the when expression. + IDSet globals; + IDSet locals; // not needed, present only for matching deprecated logic + + // Tracks whether we've found the globals/locals, as the work only + // has to be done once. + bool have_trigger_elems = false; + + // Aggregate values seen in locals used in the trigger condition, + // so we can detect changes in them that affect whether the condition + // holds. + std::vector local_aggrs; + std::vector> objs; using ValCache = std::map; ValCache cache; }; -using TriggerPtr = IntrusivePtr; - class Manager final : public iosource::IOSource { public: diff --git a/src/Val.h b/src/Val.h index 6ece8c90de..d40bf6db77 100644 --- a/src/Val.h +++ b/src/Val.h @@ -82,6 +82,7 @@ class TypeVal; using AddrValPtr = IntrusivePtr; using EnumValPtr = IntrusivePtr; +using FuncValPtr = IntrusivePtr; using ListValPtr = IntrusivePtr; using PortValPtr = IntrusivePtr; using RecordValPtr = IntrusivePtr; diff --git a/src/Var.cc b/src/Var.cc index 2cf367228e..39e03be262 100644 --- a/src/Var.cc +++ b/src/Var.cc @@ -680,6 +680,7 @@ class OuterIDBindingFinder : public TraversalCallback public: OuterIDBindingFinder(ScopePtr s) { scopes.emplace_back(s); } + TraversalCode PreStmt(const Stmt*) override; TraversalCode PreExpr(const Expr*) override; TraversalCode PostExpr(const Expr*) override; @@ -687,6 +688,24 @@ public: std::unordered_set outer_id_references; }; +TraversalCode OuterIDBindingFinder::PreStmt(const Stmt* stmt) + { + if ( stmt->Tag() != STMT_WHEN ) + return TC_CONTINUE; + + auto ws = static_cast(stmt); + auto lambda = ws->Info()->Lambda(); + + if ( ! lambda ) + // Old-style semantics. + return TC_CONTINUE; + + // The semantics of identifiers for the "when" statement are those + // of the lambda it's transformed into. + lambda->Traverse(this); + return TC_ABORTSTMT; + } + TraversalCode OuterIDBindingFinder::PreExpr(const Expr* expr) { if ( expr->Tag() == EXPR_LAMBDA ) diff --git a/src/parse.y b/src/parse.y index f86da73ccc..b52cf95a93 100644 --- a/src/parse.y +++ b/src/parse.y @@ -5,7 +5,7 @@ // Switching parser table type fixes ambiguity problems. %define lr.type ielr -%expect 141 +%expect 140 %token TOK_ADD TOK_ADD_TO TOK_ADDR TOK_ANY %token TOK_ATENDIF TOK_ATELSE TOK_ATIF TOK_ATIFDEF TOK_ATIFNDEF @@ -52,7 +52,7 @@ %left '$' '[' ']' '(' ')' TOK_HAS_FIELD TOK_HAS_ATTR %nonassoc TOK_AS TOK_IS -%type opt_no_test opt_no_test_block TOK_PATTERN_END opt_deep +%type opt_no_test opt_no_test_block TOK_PATTERN_END opt_deep when_flavor %type TOK_ID TOK_PATTERN_TEXT %type local_id global_id def_global_id event_id global_or_event_id resolve_id begin_lambda case_type %type local_id_list case_type_list @@ -73,7 +73,8 @@ %type attr %type attr_list opt_attr %type capture -%type capture_list opt_captures +%type capture_list opt_captures when_captures +%type when_head when_start when_clause %{ #include @@ -309,7 +310,8 @@ static StmtPtr build_local(ID* id, Type* t, InitClass ic, Expr* e, std::vector* attr_l; zeek::detail::AttrTag attrtag; zeek::FuncType::Capture* capture; - std::vector* captures; + zeek::FuncType::CaptureList* captures; + zeek::detail::WhenInfo* when_clause; } %% @@ -348,6 +350,54 @@ opt_expr: { $$ = 0; } ; +when_clause: + when_head TOK_TIMEOUT expr '{' opt_no_test_block stmt_list '}' + { + set_location(@1, @7); + $1->AddTimeout({AdoptRef{}, $3}, {AdoptRef{}, $6}); + if ( $5 ) + script_coverage_mgr.DecIgnoreDepth(); + } + | + when_head + ; + +when_head: + when_start stmt + { + set_location(@1, @2); + $1->AddBody({AdoptRef{}, $2}); + } + ; + +when_start: + when_flavor '[' when_captures ']' '(' when_condition ')' + { + set_location(@1, @7); + $$ = new WhenInfo({AdoptRef{}, $6}, $3, $1); + } + + | when_flavor '(' when_condition ')' + { + set_location(@1, @4); + $$ = new WhenInfo({AdoptRef{}, $3}, nullptr, $1); + } + ; + +when_flavor: + TOK_RETURN TOK_WHEN + { $$ = true; } + | + TOK_WHEN + { $$ = false; } + ; + +when_captures: + capture_list + | + { $$ = new zeek::FuncType::CaptureList; } + ; + when_condition: { ++in_when_cond; } expr { --in_when_cond; } { $$ = $2; } @@ -1385,15 +1435,7 @@ begin_lambda: if ( $1 ) { - captures = FuncType::CaptureList{}; - captures->reserve($1->size()); - - for ( auto c : *$1 ) - { - captures->emplace_back(*c); - delete c; - } - + captures = *$1; delete $1; } @@ -1411,11 +1453,15 @@ opt_captures: capture_list: capture_list ',' capture - { $1->push_back($3); } + { + $1->push_back(*$3); + delete $3; + } | capture { - $$ = new std::vector; - $$->push_back($1); + $$ = new zeek::FuncType::CaptureList; + $$->push_back(*$1); + delete $1; } ; @@ -1423,8 +1469,7 @@ capture: opt_deep TOK_ID { set_location(@2); - auto id = lookup_ID($2, - current_module.c_str()); + auto id = lookup_ID($2, current_module.c_str()); if ( ! id ) reporter->Error("no such local identifier: %s", $2); @@ -1722,37 +1767,9 @@ stmt: $$ = build_local($2, $3, $4, $5, $6, VAR_CONST, ! $8).release(); } - | TOK_WHEN '(' when_condition ')' stmt + | when_clause { - set_location(@3, @5); - $$ = new WhenStmt({AdoptRef{}, $3}, {AdoptRef{}, $5}, - nullptr, nullptr, false); - } - - | TOK_WHEN '(' when_condition ')' stmt TOK_TIMEOUT expr '{' opt_no_test_block stmt_list '}' - { - set_location(@3, @9); - $$ = new WhenStmt({AdoptRef{}, $3}, {AdoptRef{}, $5}, - {AdoptRef{}, $10}, {AdoptRef{}, $7}, false); - if ( $9 ) - script_coverage_mgr.DecIgnoreDepth(); - } - - - | TOK_RETURN TOK_WHEN '(' when_condition ')' stmt - { - set_location(@4, @6); - $$ = new WhenStmt({AdoptRef{}, $4}, {AdoptRef{}, $6}, nullptr, - nullptr, true); - } - - | TOK_RETURN TOK_WHEN '(' when_condition ')' stmt TOK_TIMEOUT expr '{' opt_no_test_block stmt_list '}' - { - set_location(@4, @10); - $$ = new WhenStmt({AdoptRef{}, $4}, {AdoptRef{}, $6}, - {AdoptRef{}, $11}, {AdoptRef{}, $8}, true); - if ( $10 ) - script_coverage_mgr.DecIgnoreDepth(); + $$ = new WhenStmt($1); } | index_slice '=' expr ';' opt_no_test diff --git a/src/script_opt/CPP/Compile.h b/src/script_opt/CPP/Compile.h index 11d8808a75..f2fa2f3874 100644 --- a/src/script_opt/CPP/Compile.h +++ b/src/script_opt/CPP/Compile.h @@ -435,7 +435,7 @@ private: std::unordered_map events; // Globals that correspond to variables, not functions. - std::unordered_set global_vars; + IDSet global_vars; // // End of methods related to script/C++ variables. @@ -539,7 +539,7 @@ private: std::unordered_map lambda_names; // The function's parameters. Tracked so we don't re-declare them. - std::unordered_set params; + IDSet params; // Whether we're compiling a hook. bool in_hook = false; diff --git a/src/script_opt/CPP/GenFunc.cc b/src/script_opt/CPP/GenFunc.cc index 673e6def4a..f4071c7deb 100644 --- a/src/script_opt/CPP/GenFunc.cc +++ b/src/script_opt/CPP/GenFunc.cc @@ -149,7 +149,7 @@ void CPPCompile::InitializeEvents(const ProfileFunc* pf) void CPPCompile::DeclareLocals(const ProfileFunc* pf, const IDPList* lambda_ids) { // It's handy to have a set of the lambda captures rather than a list. - unordered_set lambda_set; + IDSet lambda_set; if ( lambda_ids ) for ( auto li : *lambda_ids ) lambda_set.insert(li); diff --git a/src/script_opt/GenIDDefs.h b/src/script_opt/GenIDDefs.h index 41322f498f..a18c01928e 100644 --- a/src/script_opt/GenIDDefs.h +++ b/src/script_opt/GenIDDefs.h @@ -105,7 +105,7 @@ private: // the front entry tracks identifiers at the outermost // (non-confluence) scope. Thus, to index it for a given // confluence block i, we need to use i+1. - std::vector> modified_IDs; + std::vector modified_IDs; // If non-zero, indicates we should suspend any generation // of usage errors. A counter rather than a boolean because diff --git a/src/script_opt/ProfileFunc.cc b/src/script_opt/ProfileFunc.cc index 096604df3f..234f09af4f 100644 --- a/src/script_opt/ProfileFunc.cc +++ b/src/script_opt/ProfileFunc.cc @@ -270,13 +270,26 @@ TraversalCode ProfileFunc::PreExpr(const Expr* e) case EXPR_REMOVE_FROM: case EXPR_ASSIGN: { - if ( e->GetOp1()->Tag() == EXPR_REF ) + if ( e->GetOp1()->Tag() != EXPR_REF ) + // this isn't a direct assignment + break; + + auto lhs = e->GetOp1()->GetOp1(); + if ( lhs->Tag() != EXPR_NAME ) + break; + + auto id = lhs->AsNameExpr()->Id(); + TrackAssignment(id); + + if ( e->Tag() == EXPR_ASSIGN ) { - auto lhs = e->GetOp1()->GetOp1(); - if ( lhs->Tag() == EXPR_NAME ) - TrackAssignment(lhs->AsNameExpr()->Id()); + auto a_e = static_cast(e); + auto& av = a_e->AssignVal(); + if ( av ) + // This is a funky "local" assignment + // inside a when clause. + when_locals.insert(id); } - // else this isn't a direct assignment. break; } diff --git a/src/script_opt/ProfileFunc.h b/src/script_opt/ProfileFunc.h index 1b0a08e54c..ed839167b4 100644 --- a/src/script_opt/ProfileFunc.h +++ b/src/script_opt/ProfileFunc.h @@ -97,23 +97,23 @@ public: // See the comments for the associated member variables for each // of these accessors. - const std::unordered_set& Globals() const { return globals; } - const std::unordered_set& AllGlobals() const { return all_globals; } - const std::unordered_set& Locals() const { return locals; } - const std::unordered_set& Params() const { return params; } + const IDSet& Globals() const { return globals; } + const IDSet& AllGlobals() const { return all_globals; } + const IDSet& Locals() const { return locals; } + const IDSet& WhenLocals() const { return when_locals; } + const IDSet& Params() const { return params; } const std::unordered_map& Assignees() const { return assignees; } - const std::unordered_set& Inits() const { return inits; } + const IDSet& Inits() const { return inits; } const std::vector& Stmts() const { return stmts; } const std::vector& Exprs() const { return exprs; } const std::vector& Lambdas() const { return lambdas; } const std::vector& Constants() const { return constants; } - const std::unordered_set& UnorderedIdentifiers() const { return ids; } + const IDSet& UnorderedIdentifiers() const { return ids; } const std::vector& OrderedIdentifiers() const { return ordered_ids; } const std::unordered_set& UnorderedTypes() const { return types; } const std::vector& OrderedTypes() const { return ordered_types; } const std::unordered_set& ScriptCalls() const { return script_calls; } - const std::unordered_set& BiFGlobals() const { return BiF_globals; } - const std::unordered_set& WhenCalls() const { return when_calls; } + const IDSet& BiFGlobals() const { return BiF_globals; } const std::unordered_set& Events() const { return events; } const std::unordered_set& ConstructorAttrs() const { @@ -162,17 +162,20 @@ protected: // // Does *not* include globals solely seen as the function being // called in a call. - std::unordered_set globals; + IDSet globals; // Same, but also includes globals only seen as called functions. - std::unordered_set all_globals; + IDSet all_globals; // Locals seen in the function. - std::unordered_set locals; + IDSet locals; + + // Same, but for those declared in "when" expressions. + IDSet when_locals; // The function's parameters. Empty if our starting point was // profiling an expression. - std::unordered_set params; + IDSet params; // How many parameters the function has. The default value flags // that we started the profile with an expression rather than a @@ -187,7 +190,7 @@ protected: // Same for locals seen in initializations, so we can find, // for example, unused aggregates. - std::unordered_set inits; + IDSet inits; // Statements seen in the function. Does not include indirect // statements, such as those in lambda bodies. @@ -203,13 +206,13 @@ protected: std::vector lambdas; // If we're profiling a lambda function, this holds the captures. - std::unordered_set captures; + IDSet captures; // Constants seen in the function. std::vector constants; // Identifiers seen in the function. - std::unordered_set ids; + IDSet ids; // The same, but in a deterministic order. std::vector ordered_ids; @@ -226,7 +229,7 @@ protected: // Same for BiF's, though for them we record the corresponding global // rather than the BuiltinFunc*. - std::unordered_set BiF_globals; + IDSet BiF_globals; // Script functions appearing in "when" clauses. std::unordered_set when_calls; @@ -283,13 +286,13 @@ public: // The following accessors provide a global profile across all of // the (non-skipped) functions in "funcs". See the comments for // the associated member variables for documentation. - const std::unordered_set& Globals() const { return globals; } - const std::unordered_set& AllGlobals() const { return all_globals; } + const IDSet& Globals() const { return globals; } + const IDSet& AllGlobals() const { return all_globals; } const std::unordered_set& Constants() const { return constants; } const std::vector& MainTypes() const { return main_types; } const std::vector& RepTypes() const { return rep_types; } const std::unordered_set& ScriptCalls() const { return script_calls; } - const std::unordered_set& BiFGlobals() const { return BiF_globals; } + const IDSet& BiFGlobals() const { return BiF_globals; } const std::unordered_set& Lambdas() const { return lambdas; } const std::unordered_set& Events() const { return events; } @@ -345,10 +348,10 @@ protected: // Globals seen across the functions, other than those solely seen // as the function being called in a call. - std::unordered_set globals; + IDSet globals; // Same, but also includes globals only seen as called functions. - std::unordered_set all_globals; + IDSet all_globals; // Constants seen across the functions. std::unordered_set constants; @@ -369,7 +372,7 @@ protected: std::unordered_set script_calls; // Same for BiF's. - std::unordered_set BiF_globals; + IDSet BiF_globals; // And for lambda's. std::unordered_set lambdas; diff --git a/src/script_opt/Reduce.h b/src/script_opt/Reduce.h index 853ab5e7ab..931bf745fd 100644 --- a/src/script_opt/Reduce.h +++ b/src/script_opt/Reduce.h @@ -229,7 +229,7 @@ protected: std::unordered_map> ids_to_temps; // Local variables created during reduction/optimization. - std::unordered_set new_locals; + IDSet new_locals; // Mapping of original identifiers to new locals. Used to // rename local variables when inlining. @@ -262,7 +262,7 @@ protected: // Tracks which (non-temporary) variables had constant // values used for constant propagation. - std::unordered_set constant_vars; + IDSet constant_vars; // Statement at which the current reduction started. StmtPtr reduction_root = nullptr; diff --git a/src/script_opt/Stmt.cc b/src/script_opt/Stmt.cc index 68732ffdd0..27a55c5658 100644 --- a/src/script_opt/Stmt.cc +++ b/src/script_opt/Stmt.cc @@ -910,12 +910,19 @@ StmtPtr InitStmt::DoReduce(Reducer* c) StmtPtr WhenStmt::Duplicate() { - auto cond_d = cond->Duplicate(); - auto s1_d = s1->Duplicate(); - auto s2_d = s2 ? s2->Duplicate() : nullptr; - auto timeout_d = timeout ? timeout->Duplicate() : nullptr; + FuncType::CaptureList* cl_dup = nullptr; - return SetSucc(new WhenStmt(cond_d, s1_d, s2_d, timeout_d, is_return)); + if ( wi->Captures() ) + { + cl_dup = new FuncType::CaptureList; + *cl_dup = *wi->Captures(); + } + + auto new_wi = new WhenInfo(Cond(), cl_dup, IsReturn()); + new_wi->AddBody(Body()); + new_wi->AddTimeout(TimeoutExpr(), TimeoutBody()); + + return SetSucc(new WhenStmt(wi)); } void WhenStmt::Inline(Inliner* inl) diff --git a/src/script_opt/UseDefs.h b/src/script_opt/UseDefs.h index ac30afde46..f453a6f721 100644 --- a/src/script_opt/UseDefs.h +++ b/src/script_opt/UseDefs.h @@ -33,7 +33,7 @@ public: void Add(const ID* id) { use_defs.insert(id); } void Remove(const ID* id) { use_defs.erase(id); } - const std::unordered_set& IterateOver() const { return use_defs; } + const IDSet& IterateOver() const { return use_defs; } void Dump() const; void DumpNL() const @@ -43,7 +43,7 @@ public: } protected: - std::unordered_set use_defs; + IDSet use_defs; }; class Reducer; diff --git a/src/script_opt/ZAM/AM-Opt.cc b/src/script_opt/ZAM/AM-Opt.cc index 5af22bee01..fa8b81ac57 100644 --- a/src/script_opt/ZAM/AM-Opt.cc +++ b/src/script_opt/ZAM/AM-Opt.cc @@ -656,7 +656,7 @@ void ZAMCompiler::ReMapInterpreterFrame() remapped_intrp_frame_sizes[func] = next_interp_slot; } -void ZAMCompiler::ReMapVar(ID* id, int slot, bro_uint_t inst) +void ZAMCompiler::ReMapVar(const ID* id, int slot, bro_uint_t inst) { // A greedy algorithm for this is to simply find the first suitable // frame slot. We do that with one twist: we also look for a @@ -832,7 +832,7 @@ void ZAMCompiler::ExtendLifetime(int slot, const ZInstI* inst) if ( inst_endings.count(inst) == 0 ) { - std::unordered_set denizens; + IDSet denizens; inst_endings[inst] = denizens; } diff --git a/src/script_opt/ZAM/Compile.h b/src/script_opt/ZAM/Compile.h index a62eaae08f..ec6cad51bb 100644 --- a/src/script_opt/ZAM/Compile.h +++ b/src/script_opt/ZAM/Compile.h @@ -349,10 +349,10 @@ private: bool IsUnused(const IDPtr& id, const Stmt* where) const; - void LoadParam(ID* id); - const ZAMStmt LoadGlobal(ID* id); + void LoadParam(const ID* id); + const ZAMStmt LoadGlobal(const ID* id); - int AddToFrame(ID*); + int AddToFrame(const ID*); int FrameSlot(const IDPtr& id) { return FrameSlot(id.get()); } int FrameSlot(const ID* id); @@ -420,7 +420,7 @@ private: // Computes the remapping for a variable currently in the given slot, // whose scope begins at the given instruction. - void ReMapVar(ID* id, int slot, bro_uint_t inst); + void ReMapVar(const ID* id, int slot, bro_uint_t inst); // Look to initialize the beginning of local lifetime based on slot // assignment at instruction inst. @@ -541,7 +541,7 @@ private: // A type for mapping an instruction to a set of locals associated // with it. - using AssociatedLocals = std::unordered_map>; + using AssociatedLocals = std::unordered_map; // Maps (live) instructions to which frame denizens begin their // lifetime via an initialization at that instruction, if any ... diff --git a/src/script_opt/ZAM/Ops.in b/src/script_opt/ZAM/Ops.in index 2ba97d526d..3a191aded3 100644 --- a/src/script_opt/ZAM/Ops.in +++ b/src/script_opt/ZAM/Ops.in @@ -1740,20 +1740,23 @@ eval (*tiv_ptr)[z.v1].Clear(); op When op1-read type VVVV -eval auto when_body = new ZAMResumption(this, z.v2); - auto timeout_body = new ZAMResumption(this, z.v3); - new trigger::Trigger(z.e, when_body, timeout_body, frame[z.v1].double_val, f, z.v4, z.loc); +eval auto when_body = make_intrusive(this, z.v2); + auto timeout_body = make_intrusive(this, z.v3); + ExprPtr when_cond = {NewRef{}, const_cast(z.e)}; + new trigger::Trigger(when_cond, when_body, timeout_body, frame[z.v1].double_val, f, z.v4, z.loc); op When type VVVC -eval auto when_body = new ZAMResumption(this, z.v1); - auto timeout_body = new ZAMResumption(this, z.v2); - new trigger::Trigger(z.e, when_body, timeout_body, z.c.double_val, f, z.v3, z.loc); +eval auto when_body = make_intrusive(this, z.v1); + auto timeout_body = make_intrusive(this, z.v2); + ExprPtr when_cond = {NewRef{}, const_cast(z.e)}; + new trigger::Trigger(when_cond, when_body, timeout_body, z.c.double_val, f, z.v3, z.loc); op When type VV -eval auto when_body = new ZAMResumption(this, z.v2); - new trigger::Trigger(z.e, when_body, nullptr, -1.0, f, z.v1, z.loc); +eval auto when_body = make_intrusive(this, z.v2); + ExprPtr when_cond = {NewRef{}, const_cast(z.e)}; + new trigger::Trigger(when_cond, when_body, nullptr, -1.0, f, z.v1, z.loc); op CheckAnyLen op1-read diff --git a/src/script_opt/ZAM/Stmt.cc b/src/script_opt/ZAM/Stmt.cc index f6af240c95..7a4a46b4dd 100644 --- a/src/script_opt/ZAM/Stmt.cc +++ b/src/script_opt/ZAM/Stmt.cc @@ -1115,7 +1115,7 @@ const ZAMStmt ZAMCompiler::CompileWhen(const WhenStmt* ws) z.v1 = is_return; } - z.e = cond; + z.e = cond.get(); auto when_eval = AddInst(z); diff --git a/src/script_opt/ZAM/Vars.cc b/src/script_opt/ZAM/Vars.cc index 56ca59318d..543b32e32d 100644 --- a/src/script_opt/ZAM/Vars.cc +++ b/src/script_opt/ZAM/Vars.cc @@ -24,7 +24,7 @@ bool ZAMCompiler::IsUnused(const IDPtr& id, const Stmt* where) const return ! usage || ! usage->HasID(id.get()); } -void ZAMCompiler::LoadParam(ID* id) +void ZAMCompiler::LoadParam(const ID* id) { if ( id->IsType() ) reporter->InternalError( @@ -45,7 +45,7 @@ void ZAMCompiler::LoadParam(ID* id) (void)AddInst(z); } -const ZAMStmt ZAMCompiler::LoadGlobal(ID* id) +const ZAMStmt ZAMCompiler::LoadGlobal(const ID* id) { ZOp op; @@ -69,7 +69,7 @@ const ZAMStmt ZAMCompiler::LoadGlobal(ID* id) return AddInst(z); } -int ZAMCompiler::AddToFrame(ID* id) +int ZAMCompiler::AddToFrame(const ID* id) { frame_layout1[id] = frame_sizeI; frame_denizens.push_back(id); diff --git a/src/script_opt/ZAM/ZInst.h b/src/script_opt/ZAM/ZInst.h index b1dc751ea1..f9fc548804 100644 --- a/src/script_opt/ZAM/ZInst.h +++ b/src/script_opt/ZAM/ZInst.h @@ -18,7 +18,7 @@ class Stmt; using AttributesPtr = IntrusivePtr; // Maps ZAM frame slots to associated identifiers. -using FrameMap = std::vector; +using FrameMap = std::vector; // Maps ZAM frame slots to information for sharing the slot across // multiple script variables. @@ -28,7 +28,7 @@ public: // The variables sharing the slot. ID's need to be non-const so we // can manipulate them, for example by changing their interpreter // frame offset. - std::vector ids; + std::vector ids; // A parallel vector, only used for fully compiled code, which // gives the names of the identifiers. When in use, the above @@ -402,7 +402,7 @@ public: TypePtr* types = nullptr; // Used for accessing function names. - ID* id_val = nullptr; + const ID* id_val = nullptr; // Whether the instruction can lead to globals changing. // Currently only needed by the optimizer, but convenient