From a85d92b2eedba2340caa84e70fef48d03aae2347 Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Fri, 7 Jan 2022 11:49:03 -0800 Subject: [PATCH 01/10] minor commenting clarifications --- doc | 2 +- src/Frame.h | 5 +++-- src/Var.cc | 4 ++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/doc b/doc index 1263619ce3..0cebe39463 160000 --- a/doc +++ b/doc @@ -1 +1 @@ -Subproject commit 1263619ce3279415314355032b855206f7e3e632 +Subproject commit 0cebe394636728788b2f880d7d8ec661bf22824e diff --git a/src/Frame.h b/src/Frame.h index 395cb8009e..d67929fb2b 100644 --- a/src/Frame.h +++ b/src/Frame.h @@ -54,8 +54,9 @@ public: Frame(int size, const ScriptFunc* func, const zeek::Args* fn_args); /** - * Deletes the frame. Unrefs its trigger, the values that it - * contains and its closure if applicable. + * Deletes the frame. Unrefs its trigger (implicitly, since it's an + * IntrusivePtr), and the values that the frame contains and its + * closure if applicable. */ virtual ~Frame() override; diff --git a/src/Var.cc b/src/Var.cc index 5437b11c33..2cf367228e 100644 --- a/src/Var.cc +++ b/src/Var.cc @@ -722,6 +722,10 @@ TraversalCode OuterIDBindingFinder::PostExpr(const Expr* expr) return TC_CONTINUE; } +// The following is only used for debugging AST duplication. If activated, +// each AST is replaced with its duplicate. In the absence of a duplication +// error, this shouldn't change any semantics, so running the test suite +// with this variable set can find flaws in the duplication machinery. 
static bool duplicate_ASTs = getenv("ZEEK_DUPLICATE_ASTS"); void end_func(StmtPtr body, bool free_of_conditionals) From 6cb5ea6835c070499bf1fa72b272ca294626a7ec Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Fri, 7 Jan 2022 11:50:01 -0800 Subject: [PATCH 02/10] removed some now-obsolete profiling functionality --- src/script_opt/ProfileFunc.cc | 12 ------------ src/script_opt/ProfileFunc.h | 4 ---- 2 files changed, 16 deletions(-) diff --git a/src/script_opt/ProfileFunc.cc b/src/script_opt/ProfileFunc.cc index b29bf068d4..096604df3f 100644 --- a/src/script_opt/ProfileFunc.cc +++ b/src/script_opt/ProfileFunc.cc @@ -130,15 +130,6 @@ TraversalCode ProfileFunc::PreStmt(const Stmt* s) case STMT_WHEN: ++num_when_stmts; - - in_when = true; - s->AsWhenStmt()->Cond()->Traverse(this); - in_when = false; - - // It doesn't do any harm for us to re-traverse the - // conditional, so we don't bother hand-traversing the - // rest of the "when", but just let the usual processing - // do it. break; case STMT_FOR: @@ -320,9 +311,6 @@ TraversalCode ProfileFunc::PreExpr(const Expr* e) { auto bf = static_cast(func_vf); script_calls.insert(bf); - - if ( in_when ) - when_calls.insert(bf); } else BiF_globals.insert(func); diff --git a/src/script_opt/ProfileFunc.h b/src/script_opt/ProfileFunc.h index 36de4ee5ca..1b0a08e54c 100644 --- a/src/script_opt/ProfileFunc.h +++ b/src/script_opt/ProfileFunc.h @@ -260,10 +260,6 @@ protected: // Whether we should treat record field accesses as absolute // (integer offset) or relative (name-based). bool abs_rec_fields; - - // Whether we're separately processing a "when" condition to - // mine out its script calls. 
- bool in_when = false; }; // Function pointer for a predicate that determines whether a given From e22d279fdf0498eec6fd76879d44eb8ced9a866e Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Fri, 7 Jan 2022 11:50:40 -0800 Subject: [PATCH 03/10] option for internal use to mark a function type as allowing non-expression returns --- src/Type.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/Type.h b/src/Type.h index b9514739bd..f29180fb0b 100644 --- a/src/Type.h +++ b/src/Type.h @@ -515,6 +515,20 @@ public: */ const std::optional& GetCaptures() const { return captures; } + /** + * Returns whether it's acceptable for a "return" inside the function + * to not have an expression (even though the function has a return + * type). Used internally for lambdas built for "when" statements. + */ + bool ExpressionlessReturnOkay() const { return expressionless_return_okay; } + + /** + * Sets whether it's acceptable for a "return" inside the function + * to not have an expression (even though the function has a return + * type). Used internally for lambdas built for "when" statements. 
+ */ + void SetExpressionlessReturnOkay(bool is_ok) { expressionless_return_okay = is_ok; } + protected: friend FuncTypePtr make_intrusive(); @@ -526,6 +540,8 @@ protected: std::vector prototypes; std::optional captures; // if nil then no captures specified + // Used for internal lambdas built for "when" statements: + bool expressionless_return_okay = false; }; class TypeType final : public Type From fa142438fe7912a7d55e01dfef003c838c592737 Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Fri, 7 Jan 2022 14:18:42 -0800 Subject: [PATCH 04/10] logic (other than in profiling) for assignments that yield separate values --- src/Expr.h | 9 +++++++++ src/script_opt/CPP/Exprs.cc | 11 ++++++++++- src/script_opt/Expr.cc | 21 ++++++++++++++++++++- 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/Expr.h b/src/Expr.h index 83c34616f5..e8863a59b3 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -959,6 +959,15 @@ public: bool IsTemp() const { return is_temp; } void SetIsTemp() { is_temp = true; } + // The following is a hack that's used in "when" expressions to support + // assignments to new locals, like "when ( (local l = foo()) && ...". + // These methods return the value to use when evaluating such + // assignments. That would normally be the RHS of the assignment, + // but to get when's to work in a convenient fashion, for them it's + // instead boolean T. + ValPtr AssignVal() { return val; } + const ValPtr& AssignVal() const { return val; } + protected: bool TypeCheck(const AttributesPtr& attrs = nullptr); bool TypeCheckArithmetics(TypeTag bt1, TypeTag bt2); diff --git a/src/script_opt/CPP/Exprs.cc b/src/script_opt/CPP/Exprs.cc index f65be5eb84..a9c41a8377 100644 --- a/src/script_opt/CPP/Exprs.cc +++ b/src/script_opt/CPP/Exprs.cc @@ -461,7 +461,16 @@ string CPPCompile::GenAssignExpr(const Expr* e, GenType gt, bool top_level) if ( rhs_is_any && ! 
lhs_is_any && t1->Tag() != TYPE_LIST ) rhs_native = rhs_val_ptr = GenericValPtrToGT(rhs_val_ptr, t1, GEN_NATIVE); - return GenAssign(op1, op2, rhs_native, rhs_val_ptr, gt, top_level); + auto gen = GenAssign(op1, op2, rhs_native, rhs_val_ptr, gt, top_level); + auto av = e->AsAssignExpr()->AssignVal(); + if ( av ) + { + auto av_e = make_intrusive(av); + auto av_gen = GenExpr(av_e, gt, false); + return string("(") + gen + ", " + av_gen + ")"; + } + else + return gen; } string CPPCompile::GenAddToExpr(const Expr* e, GenType gt, bool top_level) diff --git a/src/script_opt/Expr.cc b/src/script_opt/Expr.cc index b69ee6d30d..b4db48a6ec 100644 --- a/src/script_opt/Expr.cc +++ b/src/script_opt/Expr.cc @@ -1556,6 +1556,11 @@ bool AssignExpr::IsReduced(Reducer* c) const // Cascaded assignments are never reduced. return false; + if ( val ) + // Initializations of "local" variables in "when" statements + // are never reduced. + return false; + const auto& t1 = op1->GetType(); const auto& t2 = op2->GetType(); @@ -1619,6 +1624,16 @@ ExprPtr AssignExpr::Reduce(Reducer* c, StmtPtr& red_stmt) // These are generated for reduced expressions. return ThisPtr(); + if ( val ) + { + // These are reduced to the assignment followed by + // the assignment value. + auto assign_val = make_intrusive(val); + val = nullptr; + red_stmt = make_intrusive(ThisPtr()); + return assign_val; + } + auto& t1 = op1->GetType(); auto& t2 = op2->GetType(); @@ -1757,7 +1772,8 @@ ExprPtr AssignExpr::Reduce(Reducer* c, StmtPtr& red_stmt) ExprPtr AssignExpr::ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) { // Yields a statement performing the assignment and for the - // expression the LHS (but turned into an RHS). + // expression the LHS (but turned into an RHS), or the assignment + // value if present. 
if ( op1->Tag() != EXPR_REF ) Internal("Confusion in AssignExpr::ReduceToSingleton"); @@ -1765,6 +1781,9 @@ ExprPtr AssignExpr::ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) auto ae_stmt = make_intrusive(assign_expr); red_stmt = ae_stmt->Reduce(c); + if ( val ) + return make_intrusive(val); + return op1->AsRefExprPtr()->GetOp1(); } From f895008c3439744832cea267f6db2d7937e127d7 Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Fri, 7 Jan 2022 14:50:35 -0800 Subject: [PATCH 05/10] captures for "when" statements update Triggers to IntrusivePtr's and simpler AST traversal introduce IDSet type, migrate associated "ID*" types to "const ID*" --- src/Expr.cc | 31 +++- src/Expr.h | 6 +- src/ID.h | 2 + src/Stmt.cc | 271 ++++++++++++++++++++++++++++++---- src/Stmt.h | 110 ++++++++++++-- src/Trigger.cc | 97 ++++++++---- src/Trigger.h | 50 +++++-- src/Val.h | 1 + src/Var.cc | 19 +++ src/parse.y | 113 ++++++++------ src/script_opt/CPP/Compile.h | 4 +- src/script_opt/CPP/GenFunc.cc | 2 +- src/script_opt/GenIDDefs.h | 2 +- src/script_opt/ProfileFunc.cc | 23 ++- src/script_opt/ProfileFunc.h | 47 +++--- src/script_opt/Reduce.h | 4 +- src/script_opt/Stmt.cc | 17 ++- src/script_opt/UseDefs.h | 4 +- src/script_opt/ZAM/AM-Opt.cc | 4 +- src/script_opt/ZAM/Compile.h | 10 +- src/script_opt/ZAM/Ops.in | 19 ++- src/script_opt/ZAM/Stmt.cc | 2 +- src/script_opt/ZAM/Vars.cc | 6 +- src/script_opt/ZAM/ZInst.h | 6 +- 24 files changed, 648 insertions(+), 202 deletions(-) diff --git a/src/Expr.cc b/src/Expr.cc index c291f288ae..b857ac10de 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -4588,7 +4588,8 @@ void CallExpr::ExprDescribe(ODesc* d) const args->Describe(d); } -LambdaExpr::LambdaExpr(std::unique_ptr arg_ing, IDPList arg_outer_ids) +LambdaExpr::LambdaExpr(std::unique_ptr arg_ing, IDPList arg_outer_ids, + StmtPtr when_parent) : Expr(EXPR_LAMBDA) { ingredients = std::move(arg_ing); @@ -4596,7 +4597,7 @@ LambdaExpr::LambdaExpr(std::unique_ptr arg_ing, IDPList ar 
SetType(ingredients->id->GetType()); - CheckCaptures(); + CheckCaptures(when_parent); // Install a dummy version of the function globally for use only // when broker provides a closure. @@ -4641,7 +4642,7 @@ LambdaExpr::LambdaExpr(std::unique_ptr arg_ing, IDPList ar id->SetConst(); } -void LambdaExpr::CheckCaptures() +void LambdaExpr::CheckCaptures(StmtPtr when_parent) { auto ft = type->AsFuncType(); const auto& captures = ft->GetCaptures(); @@ -4665,6 +4666,8 @@ void LambdaExpr::CheckCaptures() std::set outer_is_matched; std::set capture_is_matched; + auto desc = when_parent ? "\"when\" statement" : "lambda"; + for ( const auto& c : *captures ) { auto cid = c.id.get(); @@ -4677,7 +4680,11 @@ void LambdaExpr::CheckCaptures() if ( capture_is_matched.count(cid) > 0 ) { - ExprError(util::fmt("%s listed multiple times in capture", cid->Name())); + auto msg = util::fmt("%s listed multiple times in capture", cid->Name()); + if ( when_parent ) + when_parent->Error(msg); + else + ExprError(msg); continue; } @@ -4692,13 +4699,25 @@ void LambdaExpr::CheckCaptures() for ( auto id : outer_ids ) if ( outer_is_matched.count(id) == 0 ) - ExprError(util::fmt("%s is used inside lambda but not captured", id->Name())); + { + auto msg = util::fmt("%s is used inside %s but not captured", id->Name(), desc); + if ( when_parent ) + when_parent->Error(msg); + else + ExprError(msg); + } for ( const auto& c : *captures ) { auto cid = c.id.get(); if ( cid && capture_is_matched.count(cid) == 0 ) - ExprError(util::fmt("%s is captured but not used inside lambda", cid->Name())); + { + auto msg = util::fmt("%s is captured but not used inside %s", cid->Name(), desc); + if ( when_parent ) + when_parent->Error(msg); + else + ExprError(msg); + } } } diff --git a/src/Expr.h b/src/Expr.h index e8863a59b3..24d63a7a36 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -137,6 +137,7 @@ using EventExprPtr = IntrusivePtr; using ExprPtr = IntrusivePtr; using NameExprPtr = IntrusivePtr; using RefExprPtr = 
IntrusivePtr; +using LambdaExprPtr = IntrusivePtr; class Stmt; using StmtPtr = IntrusivePtr; @@ -1428,7 +1429,8 @@ protected: class LambdaExpr final : public Expr { public: - LambdaExpr(std::unique_ptr ingredients, IDPList outer_ids); + LambdaExpr(std::unique_ptr ingredients, IDPList outer_ids, + StmtPtr when_parent = nullptr); const std::string& Name() const { return my_name; } const IDPList& OuterIDs() const { return outer_ids; } @@ -1449,7 +1451,7 @@ protected: void ExprDescribe(ODesc* d) const override; private: - void CheckCaptures(); + void CheckCaptures(StmtPtr when_parent); std::unique_ptr ingredients; diff --git a/src/ID.h b/src/ID.h index cb9f68997c..28b0a5689d 100644 --- a/src/ID.h +++ b/src/ID.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "zeek/Attr.h" @@ -55,6 +56,7 @@ enum IDScope class ID; using IDPtr = IntrusivePtr; +using IDSet = std::unordered_set; class IDOptInfo; diff --git a/src/Stmt.cc b/src/Stmt.cc index 87c47d0adc..7cbb13d7da 100644 --- a/src/Stmt.cc +++ b/src/Stmt.cc @@ -20,6 +20,7 @@ #include "zeek/Var.h" #include "zeek/logging/Manager.h" #include "zeek/logging/logging.bif.h" +#include "zeek/script_opt/ProfileFunc.h" #include "zeek/script_opt/StmtOptInfo.h" namespace zeek::detail @@ -1558,7 +1559,7 @@ ReturnStmt::ReturnStmt(ExprPtr arg_e) : ExprStmt(STMT_RETURN, std::move(arg_e)) else if ( ! e ) { - if ( ft->Flavor() != FUNC_FLAVOR_HOOK ) + if ( ft->Flavor() != FUNC_FLAVOR_HOOK && ! 
ft->ExpressionlessReturnOkay() ) Error("return statement needs expression"); } @@ -1794,45 +1795,257 @@ TraversalCode NullStmt::Traverse(TraversalCallback* cb) const HANDLE_TC_STMT_POST(tc); } -WhenStmt::WhenStmt(ExprPtr arg_cond, StmtPtr arg_s1, StmtPtr arg_s2, ExprPtr arg_timeout, - bool arg_is_return) - : Stmt(STMT_WHEN), cond(std::move(arg_cond)), s1(std::move(arg_s1)), s2(std::move(arg_s2)), - timeout(std::move(arg_timeout)), is_return(arg_is_return) +WhenInfo::WhenInfo(ExprPtr _cond, FuncType::CaptureList* _cl, bool _is_return) + : cond(std::move(_cond)), cl(_cl), is_return(_is_return) { - assert(cond); - assert(s1); + prior_vars = current_scope()->Vars(); + + ProfileFunc cond_pf(cond.get()); + + if ( ! cl ) + { + for ( auto& wl : cond_pf.WhenLocals() ) + prior_vars.erase(wl->Name()); + return; + } + + when_expr_locals = cond_pf.Locals(); + when_expr_globals = cond_pf.Globals(); + + // Make any when-locals part of our captures, if not already present, + // to enable sharing between the condition and the body/timeout code. + for ( auto& wl : cond_pf.WhenLocals() ) + { + bool is_present = false; + + for ( auto& c : *cl ) + if ( c.id == wl ) + { + is_present = true; + break; + } + + if ( ! is_present ) + { + IDPtr wl_ptr = {NewRef{}, const_cast(wl)}; + cl->emplace_back(FuncType::Capture{wl_ptr, false}); + } + + // In addition, don't treat them as external locals that + // existed at the onset. + when_expr_locals.erase(wl); + } + + // Create the internal lambda we'll use to manage the captures. + static int num_params = 0; // to ensure each is distinct + lambda_param_id = util::fmt("when-param-%d", ++num_params); + + auto param_list = new type_decl_list(); + auto count_t = base_type(TYPE_COUNT); + param_list->push_back(new TypeDecl(util::copy_string(lambda_param_id.c_str()), count_t)); + auto params = make_intrusive(param_list); + + auto ft = make_intrusive(params, base_type(TYPE_ANY), FUNC_FLAVOR_FUNCTION); + ft->SetCaptures(*cl); + + if ( ! 
is_return ) + ft->SetExpressionlessReturnOkay(true); + + auto id = current_scope()->GenerateTemporary("when-internal"); + + // This begin_func will be completed by WhenInfo::Build(). + begin_func(id, current_module.c_str(), FUNC_FLAVOR_FUNCTION, false, ft); + } + +void WhenInfo::Build(StmtPtr ws) + { + if ( ! cl ) + { + // Old-style semantics. + auto locals = when_expr_locals; + + ProfileFunc cond_pf(cond.get()); + for ( auto& bl : cond_pf.Locals() ) + if ( prior_vars.count(bl->Name()) > 0 ) + locals.insert(bl); + + ProfileFunc body_pf(s.get()); + for ( auto& bl : body_pf.Locals() ) + if ( prior_vars.count(bl->Name()) > 0 ) + locals.insert(bl); + + if ( timeout_s ) + { + ProfileFunc to_pf(timeout_s.get()); + for ( auto& tl : to_pf.Locals() ) + if ( prior_vars.count(tl->Name()) > 0 ) + locals.insert(tl); + } + + if ( ! locals.empty() ) + { + std::string vars; + for ( auto& l : locals ) + { + if ( ! vars.empty() ) + vars += ", "; + vars += l->Name(); + } + + std::string msg = util::fmt("\"when\" statement referring to locals without an " + "explicit [] capture is deprecated: %s", + vars.c_str()); + ws->Warn(msg.c_str()); + } + + return; + } + + // Our general strategy is to construct a single lambda (so that + // the values of captures are shared across all of its elements) + // that's used for all three of the "when" components: condition, + // body, and timeout body. The idea is that the lambda is passed + // a single argument that specifies the particular functionality + // to execute (1 = condition, 2 = body, 3 = timeout). It gets tricky + // in that the condition needs to return a boolean, whereas the body + // and timeout *might* return a value (for "return when") constructs, + // or might not (for vanilla "when"). 
We address that issue by + // (1) making the return type be "any", and (2) introducing elsewhere + // the notion of functions marked as being allowed to have bare + // returns (no associated expression) even though they have a return + // type (to deal with the vanilla "when" case). + + // Build the AST elements of the lambda. + + // First, the constants we'll need. + auto true_const = make_intrusive(val_mgr->True()); + auto one_const = make_intrusive(val_mgr->Count(1)); + auto two_const = make_intrusive(val_mgr->Count(2)); + auto three_const = make_intrusive(val_mgr->Count(3)); + + invoke_cond = make_intrusive(one_const); + invoke_s = make_intrusive(two_const); + invoke_timeout = make_intrusive(three_const); + + // Access to the parameter that selects which action we're doing. + auto param_id = lookup_ID(lambda_param_id.c_str(), current_module.c_str()); + ASSERT(param_id); + auto param = make_intrusive(param_id); + + // Expressions for testing for the latter constants. + auto one_test = make_intrusive(EXPR_EQ, param, one_const); + auto two_test = make_intrusive(EXPR_EQ, param, two_const); + + auto empty = make_intrusive(); + + auto test_cond = make_intrusive(cond); + auto do_test = make_intrusive(one_test, test_cond, empty); + + auto else_branch = timeout_s ? timeout_s : empty; + + auto do_bodies = make_intrusive(two_test, s, else_branch); + auto dummy_return = make_intrusive(true_const); + + auto shebang = make_intrusive(do_test, do_bodies, dummy_return); + + auto ingredients = std::make_unique(current_scope(), shebang); + auto outer_ids = gather_outer_ids(pop_scope(), ingredients->body); + + lambda = make_intrusive(std::move(ingredients), std::move(outer_ids), ws); + } + +void WhenInfo::Instantiate(Frame* f) + { + if ( cl ) + curr_lambda = make_intrusive(lambda->Eval(f)); + } + +ExprPtr WhenInfo::Cond() + { + if ( ! curr_lambda ) + return cond; + + return make_intrusive(curr_lambda, invoke_cond); + } + +StmtPtr WhenInfo::WhenBody() + { + if ( ! 
curr_lambda ) + return s; + + auto invoke = make_intrusive(curr_lambda, invoke_s); + return make_intrusive(invoke, true); + } + +StmtPtr WhenInfo::TimeoutStmt() + { + if ( ! curr_lambda ) + return timeout_s; + + auto invoke = make_intrusive(curr_lambda, invoke_timeout); + return make_intrusive(invoke, true); + } + +WhenStmt::WhenStmt(WhenInfo* _wi) : Stmt(STMT_WHEN), wi(_wi) + { + wi->Build(ThisPtr()); + + auto cond = wi->Cond(); if ( ! cond->IsError() && ! IsBool(cond->GetType()->Tag()) ) cond->Error("conditional in test must be boolean"); - if ( timeout ) + auto te = wi->TimeoutExpr(); + + if ( te ) { - if ( timeout->IsError() ) + if ( te->IsError() ) return; - TypeTag bt = timeout->GetType()->Tag(); + TypeTag bt = te->GetType()->Tag(); if ( bt != TYPE_TIME && bt != TYPE_INTERVAL ) - cond->Error("when timeout requires a time or time interval"); + te->Error("when timeout requires a time or time interval"); } } -WhenStmt::~WhenStmt() = default; +WhenStmt::~WhenStmt() + { + delete wi; + } ValPtr WhenStmt::Exec(Frame* f, StmtFlowType& flow) { RegisterAccess(); flow = FLOW_NEXT; - // The new trigger object will take care of its own deletion. - new trigger::Trigger(IntrusivePtr{cond}.release(), IntrusivePtr{s1}.release(), - IntrusivePtr{s2}.release(), IntrusivePtr{timeout}.release(), f, is_return, - location); + wi->Instantiate(f); + + if ( wi->Captures() ) + { + std::vector local_aggrs; + for ( auto& l : wi->WhenExprLocals() ) + { + IDPtr l_ptr = {NewRef{}, const_cast(l)}; + auto v = f->GetElementByID(l_ptr); + if ( v && v->Modifiable() ) + local_aggrs.emplace_back(std::move(v)); + } + + new trigger::Trigger(wi, wi->WhenExprGlobals(), local_aggrs, f, location); + } + + else + // The new trigger object will take care of its own deletion. + new trigger::Trigger(wi->Cond(), wi->WhenBody(), wi->TimeoutStmt(), wi->TimeoutExpr(), f, + wi->IsReturn(), location); + return nullptr; } bool WhenStmt::IsPure() const { - return cond->IsPure() && s1->IsPure() && (! 
s2 || s2->IsPure()); + return wi->Cond()->IsPure() && wi->WhenBody()->IsPure() && + (! wi->TimeoutStmt() || wi->TimeoutStmt()->IsPure()); } void WhenStmt::StmtDescribe(ODesc* d) const @@ -1842,33 +2055,33 @@ void WhenStmt::StmtDescribe(ODesc* d) const if ( d->IsReadable() ) d->Add("("); - cond->Describe(d); + wi->Cond()->Describe(d); if ( d->IsReadable() ) d->Add(")"); d->SP(); d->PushIndent(); - s1->AccessStats(d); - s1->Describe(d); + wi->WhenBody()->AccessStats(d); + wi->WhenBody()->Describe(d); d->PopIndent(); - if ( s2 ) + if ( wi->TimeoutStmt() ) { if ( d->IsReadable() ) { d->SP(); d->Add("timeout"); d->SP(); - timeout->Describe(d); + wi->TimeoutExpr()->Describe(d); d->SP(); d->PushIndent(); - s2->AccessStats(d); - s2->Describe(d); + wi->TimeoutStmt()->AccessStats(d); + wi->TimeoutStmt()->Describe(d); d->PopIndent(); } else - s2->Describe(d); + wi->TimeoutStmt()->Describe(d); } } @@ -1877,15 +2090,15 @@ TraversalCode WhenStmt::Traverse(TraversalCallback* cb) const TraversalCode tc = cb->PreStmt(this); HANDLE_TC_STMT_PRE(tc); - tc = cond->Traverse(cb); + tc = wi->Cond()->Traverse(cb); HANDLE_TC_STMT_PRE(tc); - tc = s1->Traverse(cb); + tc = wi->WhenBody()->Traverse(cb); HANDLE_TC_STMT_PRE(tc); - if ( s2 ) + if ( wi->TimeoutStmt() ) { - tc = s2->Traverse(cb); + tc = wi->TimeoutStmt()->Traverse(cb); HANDLE_TC_STMT_PRE(tc); } diff --git a/src/Stmt.h b/src/Stmt.h index b6161f2fd3..97038a2229 100644 --- a/src/Stmt.h +++ b/src/Stmt.h @@ -5,8 +5,10 @@ // Zeek statements. #include "zeek/Dict.h" +#include "zeek/Expr.h" #include "zeek/ID.h" #include "zeek/StmtBase.h" +#include "zeek/Type.h" #include "zeek/ZeekList.h" namespace zeek::detail @@ -442,7 +444,7 @@ public: // Optimization-related: StmtPtr Duplicate() override; - // Constructor used for duplication, when we've already done + // Constructor used internally, for when we've already done // all of the type-checking. 
ReturnStmt(ExprPtr e, bool ignored); @@ -535,21 +537,105 @@ public: StmtPtr Duplicate() override { return SetSucc(new NullStmt()); } }; +// A helper class for tracking all of the information associated with +// a "when" statement, and constructing the necessary components in support +// of lambda-style captures. +class WhenInfo + { +public: + // Takes ownership of the CaptureList, which if nil signifies + // old-style frame semantics. + WhenInfo(ExprPtr _cond, FuncType::CaptureList* _cl, bool _is_return); + ~WhenInfo() { delete cl; } + + void AddBody(StmtPtr _s) { s = std::move(_s); } + + void AddTimeout(ExprPtr _timeout, StmtPtr _timeout_s) + { + timeout = std::move(_timeout); + timeout_s = std::move(_timeout_s); + } + + // Complete construction of the associated internals, including + // the (complex) lambda used to access the different elements of + // the statement. + void Build(StmtPtr ws); + + // Instantiate a new instance. + void Instantiate(Frame* f); + + // For old-style semantics, the following simply return the + // individual "when" components. For capture semantics, however, + // these instead return different invocations of a lambda that + // manages the captures. + ExprPtr Cond(); + StmtPtr WhenBody(); + + ExprPtr TimeoutExpr() { return timeout; } + StmtPtr TimeoutStmt(); + + FuncType::CaptureList* Captures() { return cl; } + + bool IsReturn() const { return is_return; } + + const LambdaExprPtr& Lambda() const { return lambda; } + + // The locals and globals used in the conditional expression + // (other than newly introduced locals), necessary for registering + // the associated triggers for when their values change. + const IDSet& WhenExprLocals() const { return when_expr_locals; } + const IDSet& WhenExprGlobals() const { return when_expr_globals; } + +private: + ExprPtr cond; + StmtPtr s; + ExprPtr timeout; + StmtPtr timeout_s; + FuncType::CaptureList* cl; + + bool is_return = false; + + // The name of the parameter passed to the lambda. 
+ std::string lambda_param_id; + + // The expression for constructing the lambda. + LambdaExprPtr lambda; + + // The current instance of the lambda. Created by Instantiate(), + // for immediate use via calls to Cond() etc. + ConstExprPtr curr_lambda; + + // Arguments to use when calling the lambda to either evaluate + // the conditional, or execute the body or the timeout statement. + ListExprPtr invoke_cond; + ListExprPtr invoke_s; + ListExprPtr invoke_timeout; + + IDSet when_expr_locals; + IDSet when_expr_globals; + + // Used for identifying deprecated instances. Holds all of the local + // variables in the scope prior to parsing the "when" statement. + std::map> prior_vars; + }; + class WhenStmt final : public Stmt { public: - // s2 is null if no timeout block given. - WhenStmt(ExprPtr cond, StmtPtr s1, StmtPtr s2, ExprPtr timeout, bool is_return); + // The constructor takes ownership of the WhenInfo object. + WhenStmt(WhenInfo* wi); ~WhenStmt() override; ValPtr Exec(Frame* f, StmtFlowType& flow) override; bool IsPure() const override; - const Expr* Cond() const { return cond.get(); } - const Stmt* Body() const { return s1.get(); } - const Expr* TimeoutExpr() const { return timeout.get(); } - const Stmt* TimeoutBody() const { return s2.get(); } - bool IsReturn() const { return is_return; } + ExprPtr Cond() const { return wi->Cond(); } + StmtPtr Body() const { return wi->WhenBody(); } + ExprPtr TimeoutExpr() const { return wi->TimeoutExpr(); } + StmtPtr TimeoutBody() const { return wi->TimeoutStmt(); } + bool IsReturn() const { return wi->IsReturn(); } + + const WhenInfo* Info() const { return wi; } void StmtDescribe(ODesc* d) const override; @@ -561,12 +647,8 @@ public: bool IsReduced(Reducer* c) const override; -protected: - ExprPtr cond; - StmtPtr s1; - StmtPtr s2; - ExprPtr timeout; - bool is_return; +private: + WhenInfo* wi; }; // Internal statement used for inlining. 
Executes a block and stops diff --git a/src/Trigger.cc b/src/Trigger.cc index ba9825071f..9d4e2ea121 100644 --- a/src/Trigger.cc +++ b/src/Trigger.cc @@ -23,21 +23,19 @@ using namespace zeek::detail::trigger; namespace zeek::detail::trigger { +// Used to extract the globals and locals seen in a trigger expression. class TriggerTraversalCallback : public TraversalCallback { public: - TriggerTraversalCallback(Trigger* arg_trigger) + TriggerTraversalCallback(IDSet& _globals, IDSet& _locals) : globals(_globals), locals(_locals) { - Ref(arg_trigger); - trigger = arg_trigger; } - ~TriggerTraversalCallback() { Unref(trigger); } - virtual TraversalCode PreExpr(const Expr*) override; private: - Trigger* trigger; + IDSet& globals; + IDSet& locals; }; TraversalCode trigger::TriggerTraversalCallback::PreExpr(const Expr* expr) @@ -50,14 +48,12 @@ TraversalCode trigger::TriggerTraversalCallback::PreExpr(const Expr* expr) case EXPR_NAME: { const auto* e = static_cast(expr); - if ( e->Id()->IsGlobal() ) - trigger->Register(e->Id()); + auto id = e->Id(); - Val* v = e->Id()->GetVal().get(); - - if ( v && v->Modifiable() ) - trigger->Register(v); - break; + if ( id->IsGlobal() ) + globals.insert(id); + else + locals.insert(id); }; default: @@ -103,10 +99,35 @@ protected: double time; }; -Trigger::Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, Expr* timeout_expr, +Trigger::Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, ExprPtr timeout_expr, Frame* frame, bool is_return, const Location* location) { - timeout_value = -1; + GetTimeout(timeout_expr); + Init(cond, body, timeout_stmts, frame, is_return, location); + } + +Trigger::Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, double timeout, Frame* frame, + bool is_return, const Location* location) + { + timeout_value = timeout; + Init(cond, body, timeout_stmts, frame, is_return, location); + } + +Trigger::Trigger(WhenInfo* wi, const IDSet& _globals, std::vector _local_aggrs, Frame* f, + const Location* 
loc) + { + globals = _globals; + local_aggrs = std::move(_local_aggrs); + have_trigger_elems = true; + + GetTimeout(wi->TimeoutExpr()); + + Init(wi->Cond(), wi->WhenBody(), wi->TimeoutStmt(), f, wi->IsReturn(), loc); + } + +void Trigger::GetTimeout(const ExprPtr& timeout_expr) + { + timeout_value = -1.0; if ( timeout_expr ) { @@ -123,24 +144,14 @@ Trigger::Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, Expr* timeou if ( timeout_val ) timeout_value = timeout_val->AsInterval(); } - - Init(cond, body, timeout_stmts, frame, is_return, location); } -Trigger::Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, double timeout, Frame* frame, - bool is_return, const Location* location) - { - timeout_value = timeout; - Init(cond, body, timeout_stmts, frame, is_return, location); - } - -void Trigger::Init(const Expr* arg_cond, Stmt* arg_body, Stmt* arg_timeout_stmts, Frame* arg_frame, +void Trigger::Init(ExprPtr arg_cond, StmtPtr arg_body, StmtPtr arg_timeout_stmts, Frame* arg_frame, bool arg_is_return, const Location* arg_location) { cond = arg_cond; body = arg_body; timeout_stmts = arg_timeout_stmts; - frame = arg_frame->Clone(); timer = nullptr; delayed = false; disabled = false; @@ -148,6 +159,11 @@ void Trigger::Init(const Expr* arg_cond, Stmt* arg_body, Stmt* arg_timeout_stmts is_return = arg_is_return; location = arg_location; + if ( arg_frame ) + frame = arg_frame->Clone(); + else + frame = nullptr; + DBG_LOG(DBG_NOTIFIERS, "%s: instantiating", Name()); if ( is_return ) @@ -217,8 +233,30 @@ void Trigger::ReInit(std::vector index_expr_results) { assert(! disabled); UnregisterAll(); - TriggerTraversalCallback cb(this); - cond->Traverse(&cb); + + if ( ! have_trigger_elems ) + { + TriggerTraversalCallback cb(globals, locals); + cond->Traverse(&cb); + have_trigger_elems = true; + } + + for ( auto g : globals ) + { + Register(g); + + auto& v = g->GetVal(); + if ( v && v->Modifiable() ) + Register(v.get()); + } + + for ( auto l : locals ) + { + ASSERT(! 
l->GetVal()); + } + + for ( auto& av : local_aggrs ) + Register(av.get()); for ( const auto& v : index_expr_results ) Register(v.get()); @@ -390,9 +428,10 @@ void Trigger::Timeout() Unref(this); } -void Trigger::Register(ID* id) +void Trigger::Register(const ID* const_id) { assert(! disabled); + ID* id = const_cast(const_id); notifier::detail::registry.Register(id, this); Ref(id); diff --git a/src/Trigger.h b/src/Trigger.h index da928c6a89..6810a7fccd 100644 --- a/src/Trigger.h +++ b/src/Trigger.h @@ -4,6 +4,7 @@ #include #include +#include "zeek/ID.h" #include "zeek/IntrusivePtr.h" #include "zeek/Notifier.h" #include "zeek/Obj.h" @@ -16,6 +17,8 @@ namespace zeek class ODesc; class Val; +using ValPtr = IntrusivePtr; + namespace detail { @@ -24,6 +27,9 @@ class Stmt; class Expr; class CallExpr; class ID; +class WhenInfo; + +using StmtPtr = IntrusivePtr; namespace trigger { @@ -41,10 +47,18 @@ public: // instantiation. Note that if the condition is already true, the // statements are executed immediately and the object is deleted // right away. - Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, Expr* timeout, Frame* f, + + // These first two constructors are for the deprecated deep-copy + // semantics. + Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, ExprPtr timeout, Frame* f, bool is_return, const Location* loc); - Trigger(const Expr* cond, Stmt* body, Stmt* timeout_stmts, double timeout, Frame* f, + Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, double timeout, Frame* f, bool is_return, const Location* loc); + + // Used for capture-list semantics. + Trigger(WhenInfo* wi, const IDSet& globals, std::vector local_aggrs, Frame* f, + const Location* loc); + ~Trigger() override; // Evaluates the condition. 
If true, executes the body and deletes @@ -96,22 +110,23 @@ public: const char* Name() const; private: - friend class TriggerTraversalCallback; friend class TriggerTimer; - void Init(const Expr* cond, Stmt* body, Stmt* timeout_stmts, Frame* frame, bool is_return, + void GetTimeout(const ExprPtr& timeout_expr); + + void Init(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, Frame* frame, bool is_return, const Location* location); - void ReInit(std::vector> index_expr_results); + void ReInit(std::vector index_expr_results); - void Register(ID* id); + void Register(const ID* id); void Register(Val* val); void UnregisterAll(); - const Expr* cond; - Stmt* body; - Stmt* timeout_stmts; - Expr* timeout; + ExprPtr cond; + StmtPtr body; + StmtPtr timeout_stmts; + ExprPtr timeout; double timeout_value; Frame* frame; bool is_return; @@ -123,14 +138,25 @@ private: bool delayed; // true if a function call is currently being delayed bool disabled; + // Globals and locals present in the when expression. + IDSet globals; + IDSet locals; // not needed, present only for matching deprecated logic + + // Tracks whether we've found the globals/locals, as the work only + // has to be done once. + bool have_trigger_elems = false; + + // Aggregate values seen in locals used in the trigger condition, + // so we can detect changes in them that affect whether the condition + // holds. 
+ std::vector local_aggrs; + std::vector> objs; using ValCache = std::map; ValCache cache; }; -using TriggerPtr = IntrusivePtr; - class Manager final : public iosource::IOSource { public: diff --git a/src/Val.h b/src/Val.h index 6ece8c90de..d40bf6db77 100644 --- a/src/Val.h +++ b/src/Val.h @@ -82,6 +82,7 @@ class TypeVal; using AddrValPtr = IntrusivePtr; using EnumValPtr = IntrusivePtr; +using FuncValPtr = IntrusivePtr; using ListValPtr = IntrusivePtr; using PortValPtr = IntrusivePtr; using RecordValPtr = IntrusivePtr; diff --git a/src/Var.cc b/src/Var.cc index 2cf367228e..39e03be262 100644 --- a/src/Var.cc +++ b/src/Var.cc @@ -680,6 +680,7 @@ class OuterIDBindingFinder : public TraversalCallback public: OuterIDBindingFinder(ScopePtr s) { scopes.emplace_back(s); } + TraversalCode PreStmt(const Stmt*) override; TraversalCode PreExpr(const Expr*) override; TraversalCode PostExpr(const Expr*) override; @@ -687,6 +688,24 @@ public: std::unordered_set outer_id_references; }; +TraversalCode OuterIDBindingFinder::PreStmt(const Stmt* stmt) + { + if ( stmt->Tag() != STMT_WHEN ) + return TC_CONTINUE; + + auto ws = static_cast(stmt); + auto lambda = ws->Info()->Lambda(); + + if ( ! lambda ) + // Old-style semantics. + return TC_CONTINUE; + + // The semantics of identifiers for the "when" statement are those + // of the lambda it's transformed into. + lambda->Traverse(this); + return TC_ABORTSTMT; + } + TraversalCode OuterIDBindingFinder::PreExpr(const Expr* expr) { if ( expr->Tag() == EXPR_LAMBDA ) diff --git a/src/parse.y b/src/parse.y index f86da73ccc..b52cf95a93 100644 --- a/src/parse.y +++ b/src/parse.y @@ -5,7 +5,7 @@ // Switching parser table type fixes ambiguity problems. 
%define lr.type ielr -%expect 141 +%expect 140 %token TOK_ADD TOK_ADD_TO TOK_ADDR TOK_ANY %token TOK_ATENDIF TOK_ATELSE TOK_ATIF TOK_ATIFDEF TOK_ATIFNDEF @@ -52,7 +52,7 @@ %left '$' '[' ']' '(' ')' TOK_HAS_FIELD TOK_HAS_ATTR %nonassoc TOK_AS TOK_IS -%type opt_no_test opt_no_test_block TOK_PATTERN_END opt_deep +%type opt_no_test opt_no_test_block TOK_PATTERN_END opt_deep when_flavor %type TOK_ID TOK_PATTERN_TEXT %type local_id global_id def_global_id event_id global_or_event_id resolve_id begin_lambda case_type %type local_id_list case_type_list @@ -73,7 +73,8 @@ %type attr %type attr_list opt_attr %type capture -%type capture_list opt_captures +%type capture_list opt_captures when_captures +%type when_head when_start when_clause %{ #include @@ -309,7 +310,8 @@ static StmtPtr build_local(ID* id, Type* t, InitClass ic, Expr* e, std::vector* attr_l; zeek::detail::AttrTag attrtag; zeek::FuncType::Capture* capture; - std::vector* captures; + zeek::FuncType::CaptureList* captures; + zeek::detail::WhenInfo* when_clause; } %% @@ -348,6 +350,54 @@ opt_expr: { $$ = 0; } ; +when_clause: + when_head TOK_TIMEOUT expr '{' opt_no_test_block stmt_list '}' + { + set_location(@1, @7); + $1->AddTimeout({AdoptRef{}, $3}, {AdoptRef{}, $6}); + if ( $5 ) + script_coverage_mgr.DecIgnoreDepth(); + } + | + when_head + ; + +when_head: + when_start stmt + { + set_location(@1, @2); + $1->AddBody({AdoptRef{}, $2}); + } + ; + +when_start: + when_flavor '[' when_captures ']' '(' when_condition ')' + { + set_location(@1, @7); + $$ = new WhenInfo({AdoptRef{}, $6}, $3, $1); + } + + | when_flavor '(' when_condition ')' + { + set_location(@1, @4); + $$ = new WhenInfo({AdoptRef{}, $3}, nullptr, $1); + } + ; + +when_flavor: + TOK_RETURN TOK_WHEN + { $$ = true; } + | + TOK_WHEN + { $$ = false; } + ; + +when_captures: + capture_list + | + { $$ = new zeek::FuncType::CaptureList; } + ; + when_condition: { ++in_when_cond; } expr { --in_when_cond; } { $$ = $2; } @@ -1385,15 +1435,7 @@ begin_lambda: if ( $1 ) 
{ - captures = FuncType::CaptureList{}; - captures->reserve($1->size()); - - for ( auto c : *$1 ) - { - captures->emplace_back(*c); - delete c; - } - + captures = *$1; delete $1; } @@ -1411,11 +1453,15 @@ opt_captures: capture_list: capture_list ',' capture - { $1->push_back($3); } + { + $1->push_back(*$3); + delete $3; + } | capture { - $$ = new std::vector; - $$->push_back($1); + $$ = new zeek::FuncType::CaptureList; + $$->push_back(*$1); + delete $1; } ; @@ -1423,8 +1469,7 @@ capture: opt_deep TOK_ID { set_location(@2); - auto id = lookup_ID($2, - current_module.c_str()); + auto id = lookup_ID($2, current_module.c_str()); if ( ! id ) reporter->Error("no such local identifier: %s", $2); @@ -1722,37 +1767,9 @@ stmt: $$ = build_local($2, $3, $4, $5, $6, VAR_CONST, ! $8).release(); } - | TOK_WHEN '(' when_condition ')' stmt + | when_clause { - set_location(@3, @5); - $$ = new WhenStmt({AdoptRef{}, $3}, {AdoptRef{}, $5}, - nullptr, nullptr, false); - } - - | TOK_WHEN '(' when_condition ')' stmt TOK_TIMEOUT expr '{' opt_no_test_block stmt_list '}' - { - set_location(@3, @9); - $$ = new WhenStmt({AdoptRef{}, $3}, {AdoptRef{}, $5}, - {AdoptRef{}, $10}, {AdoptRef{}, $7}, false); - if ( $9 ) - script_coverage_mgr.DecIgnoreDepth(); - } - - - | TOK_RETURN TOK_WHEN '(' when_condition ')' stmt - { - set_location(@4, @6); - $$ = new WhenStmt({AdoptRef{}, $4}, {AdoptRef{}, $6}, nullptr, - nullptr, true); - } - - | TOK_RETURN TOK_WHEN '(' when_condition ')' stmt TOK_TIMEOUT expr '{' opt_no_test_block stmt_list '}' - { - set_location(@4, @10); - $$ = new WhenStmt({AdoptRef{}, $4}, {AdoptRef{}, $6}, - {AdoptRef{}, $11}, {AdoptRef{}, $8}, true); - if ( $10 ) - script_coverage_mgr.DecIgnoreDepth(); + $$ = new WhenStmt($1); } | index_slice '=' expr ';' opt_no_test diff --git a/src/script_opt/CPP/Compile.h b/src/script_opt/CPP/Compile.h index 11d8808a75..f2fa2f3874 100644 --- a/src/script_opt/CPP/Compile.h +++ b/src/script_opt/CPP/Compile.h @@ -435,7 +435,7 @@ private: 
std::unordered_map events; // Globals that correspond to variables, not functions. - std::unordered_set global_vars; + IDSet global_vars; // // End of methods related to script/C++ variables. @@ -539,7 +539,7 @@ private: std::unordered_map lambda_names; // The function's parameters. Tracked so we don't re-declare them. - std::unordered_set params; + IDSet params; // Whether we're compiling a hook. bool in_hook = false; diff --git a/src/script_opt/CPP/GenFunc.cc b/src/script_opt/CPP/GenFunc.cc index 673e6def4a..f4071c7deb 100644 --- a/src/script_opt/CPP/GenFunc.cc +++ b/src/script_opt/CPP/GenFunc.cc @@ -149,7 +149,7 @@ void CPPCompile::InitializeEvents(const ProfileFunc* pf) void CPPCompile::DeclareLocals(const ProfileFunc* pf, const IDPList* lambda_ids) { // It's handy to have a set of the lambda captures rather than a list. - unordered_set lambda_set; + IDSet lambda_set; if ( lambda_ids ) for ( auto li : *lambda_ids ) lambda_set.insert(li); diff --git a/src/script_opt/GenIDDefs.h b/src/script_opt/GenIDDefs.h index 41322f498f..a18c01928e 100644 --- a/src/script_opt/GenIDDefs.h +++ b/src/script_opt/GenIDDefs.h @@ -105,7 +105,7 @@ private: // the front entry tracks identifiers at the outermost // (non-confluence) scope. Thus, to index it for a given // confluence block i, we need to use i+1. - std::vector> modified_IDs; + std::vector modified_IDs; // If non-zero, indicates we should suspend any generation // of usage errors. 
A counter rather than a boolean because diff --git a/src/script_opt/ProfileFunc.cc b/src/script_opt/ProfileFunc.cc index 096604df3f..234f09af4f 100644 --- a/src/script_opt/ProfileFunc.cc +++ b/src/script_opt/ProfileFunc.cc @@ -270,13 +270,26 @@ TraversalCode ProfileFunc::PreExpr(const Expr* e) case EXPR_REMOVE_FROM: case EXPR_ASSIGN: { - if ( e->GetOp1()->Tag() == EXPR_REF ) + if ( e->GetOp1()->Tag() != EXPR_REF ) + // this isn't a direct assignment + break; + + auto lhs = e->GetOp1()->GetOp1(); + if ( lhs->Tag() != EXPR_NAME ) + break; + + auto id = lhs->AsNameExpr()->Id(); + TrackAssignment(id); + + if ( e->Tag() == EXPR_ASSIGN ) { - auto lhs = e->GetOp1()->GetOp1(); - if ( lhs->Tag() == EXPR_NAME ) - TrackAssignment(lhs->AsNameExpr()->Id()); + auto a_e = static_cast(e); + auto& av = a_e->AssignVal(); + if ( av ) + // This is a funky "local" assignment + // inside a when clause. + when_locals.insert(id); } - // else this isn't a direct assignment. break; } diff --git a/src/script_opt/ProfileFunc.h b/src/script_opt/ProfileFunc.h index 1b0a08e54c..ed839167b4 100644 --- a/src/script_opt/ProfileFunc.h +++ b/src/script_opt/ProfileFunc.h @@ -97,23 +97,23 @@ public: // See the comments for the associated member variables for each // of these accessors. 
- const std::unordered_set& Globals() const { return globals; } - const std::unordered_set& AllGlobals() const { return all_globals; } - const std::unordered_set& Locals() const { return locals; } - const std::unordered_set& Params() const { return params; } + const IDSet& Globals() const { return globals; } + const IDSet& AllGlobals() const { return all_globals; } + const IDSet& Locals() const { return locals; } + const IDSet& WhenLocals() const { return when_locals; } + const IDSet& Params() const { return params; } const std::unordered_map& Assignees() const { return assignees; } - const std::unordered_set& Inits() const { return inits; } + const IDSet& Inits() const { return inits; } const std::vector& Stmts() const { return stmts; } const std::vector& Exprs() const { return exprs; } const std::vector& Lambdas() const { return lambdas; } const std::vector& Constants() const { return constants; } - const std::unordered_set& UnorderedIdentifiers() const { return ids; } + const IDSet& UnorderedIdentifiers() const { return ids; } const std::vector& OrderedIdentifiers() const { return ordered_ids; } const std::unordered_set& UnorderedTypes() const { return types; } const std::vector& OrderedTypes() const { return ordered_types; } const std::unordered_set& ScriptCalls() const { return script_calls; } - const std::unordered_set& BiFGlobals() const { return BiF_globals; } - const std::unordered_set& WhenCalls() const { return when_calls; } + const IDSet& BiFGlobals() const { return BiF_globals; } const std::unordered_set& Events() const { return events; } const std::unordered_set& ConstructorAttrs() const { @@ -162,17 +162,20 @@ protected: // // Does *not* include globals solely seen as the function being // called in a call. - std::unordered_set globals; + IDSet globals; // Same, but also includes globals only seen as called functions. - std::unordered_set all_globals; + IDSet all_globals; // Locals seen in the function. 
- std::unordered_set locals; + IDSet locals; + + // Same, but for those declared in "when" expressions. + IDSet when_locals; // The function's parameters. Empty if our starting point was // profiling an expression. - std::unordered_set params; + IDSet params; // How many parameters the function has. The default value flags // that we started the profile with an expression rather than a @@ -187,7 +190,7 @@ protected: // Same for locals seen in initializations, so we can find, // for example, unused aggregates. - std::unordered_set inits; + IDSet inits; // Statements seen in the function. Does not include indirect // statements, such as those in lambda bodies. @@ -203,13 +206,13 @@ protected: std::vector lambdas; // If we're profiling a lambda function, this holds the captures. - std::unordered_set captures; + IDSet captures; // Constants seen in the function. std::vector constants; // Identifiers seen in the function. - std::unordered_set ids; + IDSet ids; // The same, but in a deterministic order. std::vector ordered_ids; @@ -226,7 +229,7 @@ protected: // Same for BiF's, though for them we record the corresponding global // rather than the BuiltinFunc*. - std::unordered_set BiF_globals; + IDSet BiF_globals; // Script functions appearing in "when" clauses. std::unordered_set when_calls; @@ -283,13 +286,13 @@ public: // The following accessors provide a global profile across all of // the (non-skipped) functions in "funcs". See the comments for // the associated member variables for documentation. 
- const std::unordered_set& Globals() const { return globals; } - const std::unordered_set& AllGlobals() const { return all_globals; } + const IDSet& Globals() const { return globals; } + const IDSet& AllGlobals() const { return all_globals; } const std::unordered_set& Constants() const { return constants; } const std::vector& MainTypes() const { return main_types; } const std::vector& RepTypes() const { return rep_types; } const std::unordered_set& ScriptCalls() const { return script_calls; } - const std::unordered_set& BiFGlobals() const { return BiF_globals; } + const IDSet& BiFGlobals() const { return BiF_globals; } const std::unordered_set& Lambdas() const { return lambdas; } const std::unordered_set& Events() const { return events; } @@ -345,10 +348,10 @@ protected: // Globals seen across the functions, other than those solely seen // as the function being called in a call. - std::unordered_set globals; + IDSet globals; // Same, but also includes globals only seen as called functions. - std::unordered_set all_globals; + IDSet all_globals; // Constants seen across the functions. std::unordered_set constants; @@ -369,7 +372,7 @@ protected: std::unordered_set script_calls; // Same for BiF's. - std::unordered_set BiF_globals; + IDSet BiF_globals; // And for lambda's. std::unordered_set lambdas; diff --git a/src/script_opt/Reduce.h b/src/script_opt/Reduce.h index 853ab5e7ab..931bf745fd 100644 --- a/src/script_opt/Reduce.h +++ b/src/script_opt/Reduce.h @@ -229,7 +229,7 @@ protected: std::unordered_map> ids_to_temps; // Local variables created during reduction/optimization. - std::unordered_set new_locals; + IDSet new_locals; // Mapping of original identifiers to new locals. Used to // rename local variables when inlining. @@ -262,7 +262,7 @@ protected: // Tracks which (non-temporary) variables had constant // values used for constant propagation. - std::unordered_set constant_vars; + IDSet constant_vars; // Statement at which the current reduction started. 
StmtPtr reduction_root = nullptr; diff --git a/src/script_opt/Stmt.cc b/src/script_opt/Stmt.cc index 68732ffdd0..27a55c5658 100644 --- a/src/script_opt/Stmt.cc +++ b/src/script_opt/Stmt.cc @@ -910,12 +910,19 @@ StmtPtr InitStmt::DoReduce(Reducer* c) StmtPtr WhenStmt::Duplicate() { - auto cond_d = cond->Duplicate(); - auto s1_d = s1->Duplicate(); - auto s2_d = s2 ? s2->Duplicate() : nullptr; - auto timeout_d = timeout ? timeout->Duplicate() : nullptr; + FuncType::CaptureList* cl_dup = nullptr; - return SetSucc(new WhenStmt(cond_d, s1_d, s2_d, timeout_d, is_return)); + if ( wi->Captures() ) + { + cl_dup = new FuncType::CaptureList; + *cl_dup = *wi->Captures(); + } + + auto new_wi = new WhenInfo(Cond(), cl_dup, IsReturn()); + new_wi->AddBody(Body()); + new_wi->AddTimeout(TimeoutExpr(), TimeoutBody()); + + return SetSucc(new WhenStmt(wi)); } void WhenStmt::Inline(Inliner* inl) diff --git a/src/script_opt/UseDefs.h b/src/script_opt/UseDefs.h index ac30afde46..f453a6f721 100644 --- a/src/script_opt/UseDefs.h +++ b/src/script_opt/UseDefs.h @@ -33,7 +33,7 @@ public: void Add(const ID* id) { use_defs.insert(id); } void Remove(const ID* id) { use_defs.erase(id); } - const std::unordered_set& IterateOver() const { return use_defs; } + const IDSet& IterateOver() const { return use_defs; } void Dump() const; void DumpNL() const @@ -43,7 +43,7 @@ public: } protected: - std::unordered_set use_defs; + IDSet use_defs; }; class Reducer; diff --git a/src/script_opt/ZAM/AM-Opt.cc b/src/script_opt/ZAM/AM-Opt.cc index 5af22bee01..fa8b81ac57 100644 --- a/src/script_opt/ZAM/AM-Opt.cc +++ b/src/script_opt/ZAM/AM-Opt.cc @@ -656,7 +656,7 @@ void ZAMCompiler::ReMapInterpreterFrame() remapped_intrp_frame_sizes[func] = next_interp_slot; } -void ZAMCompiler::ReMapVar(ID* id, int slot, bro_uint_t inst) +void ZAMCompiler::ReMapVar(const ID* id, int slot, bro_uint_t inst) { // A greedy algorithm for this is to simply find the first suitable // frame slot. 
We do that with one twist: we also look for a @@ -832,7 +832,7 @@ void ZAMCompiler::ExtendLifetime(int slot, const ZInstI* inst) if ( inst_endings.count(inst) == 0 ) { - std::unordered_set denizens; + IDSet denizens; inst_endings[inst] = denizens; } diff --git a/src/script_opt/ZAM/Compile.h b/src/script_opt/ZAM/Compile.h index a62eaae08f..ec6cad51bb 100644 --- a/src/script_opt/ZAM/Compile.h +++ b/src/script_opt/ZAM/Compile.h @@ -349,10 +349,10 @@ private: bool IsUnused(const IDPtr& id, const Stmt* where) const; - void LoadParam(ID* id); - const ZAMStmt LoadGlobal(ID* id); + void LoadParam(const ID* id); + const ZAMStmt LoadGlobal(const ID* id); - int AddToFrame(ID*); + int AddToFrame(const ID*); int FrameSlot(const IDPtr& id) { return FrameSlot(id.get()); } int FrameSlot(const ID* id); @@ -420,7 +420,7 @@ private: // Computes the remapping for a variable currently in the given slot, // whose scope begins at the given instruction. - void ReMapVar(ID* id, int slot, bro_uint_t inst); + void ReMapVar(const ID* id, int slot, bro_uint_t inst); // Look to initialize the beginning of local lifetime based on slot // assignment at instruction inst. @@ -541,7 +541,7 @@ private: // A type for mapping an instruction to a set of locals associated // with it. - using AssociatedLocals = std::unordered_map>; + using AssociatedLocals = std::unordered_map; // Maps (live) instructions to which frame denizens begin their // lifetime via an initialization at that instruction, if any ... 
diff --git a/src/script_opt/ZAM/Ops.in b/src/script_opt/ZAM/Ops.in index 2ba97d526d..3a191aded3 100644 --- a/src/script_opt/ZAM/Ops.in +++ b/src/script_opt/ZAM/Ops.in @@ -1740,20 +1740,23 @@ eval (*tiv_ptr)[z.v1].Clear(); op When op1-read type VVVV -eval auto when_body = new ZAMResumption(this, z.v2); - auto timeout_body = new ZAMResumption(this, z.v3); - new trigger::Trigger(z.e, when_body, timeout_body, frame[z.v1].double_val, f, z.v4, z.loc); +eval auto when_body = make_intrusive(this, z.v2); + auto timeout_body = make_intrusive(this, z.v3); + ExprPtr when_cond = {NewRef{}, const_cast(z.e)}; + new trigger::Trigger(when_cond, when_body, timeout_body, frame[z.v1].double_val, f, z.v4, z.loc); op When type VVVC -eval auto when_body = new ZAMResumption(this, z.v1); - auto timeout_body = new ZAMResumption(this, z.v2); - new trigger::Trigger(z.e, when_body, timeout_body, z.c.double_val, f, z.v3, z.loc); +eval auto when_body = make_intrusive(this, z.v1); + auto timeout_body = make_intrusive(this, z.v2); + ExprPtr when_cond = {NewRef{}, const_cast(z.e)}; + new trigger::Trigger(when_cond, when_body, timeout_body, z.c.double_val, f, z.v3, z.loc); op When type VV -eval auto when_body = new ZAMResumption(this, z.v2); - new trigger::Trigger(z.e, when_body, nullptr, -1.0, f, z.v1, z.loc); +eval auto when_body = make_intrusive(this, z.v2); + ExprPtr when_cond = {NewRef{}, const_cast(z.e)}; + new trigger::Trigger(when_cond, when_body, nullptr, -1.0, f, z.v1, z.loc); op CheckAnyLen op1-read diff --git a/src/script_opt/ZAM/Stmt.cc b/src/script_opt/ZAM/Stmt.cc index f6af240c95..7a4a46b4dd 100644 --- a/src/script_opt/ZAM/Stmt.cc +++ b/src/script_opt/ZAM/Stmt.cc @@ -1115,7 +1115,7 @@ const ZAMStmt ZAMCompiler::CompileWhen(const WhenStmt* ws) z.v1 = is_return; } - z.e = cond; + z.e = cond.get(); auto when_eval = AddInst(z); diff --git a/src/script_opt/ZAM/Vars.cc b/src/script_opt/ZAM/Vars.cc index 56ca59318d..543b32e32d 100644 --- a/src/script_opt/ZAM/Vars.cc +++ 
b/src/script_opt/ZAM/Vars.cc @@ -24,7 +24,7 @@ bool ZAMCompiler::IsUnused(const IDPtr& id, const Stmt* where) const return ! usage || ! usage->HasID(id.get()); } -void ZAMCompiler::LoadParam(ID* id) +void ZAMCompiler::LoadParam(const ID* id) { if ( id->IsType() ) reporter->InternalError( @@ -45,7 +45,7 @@ void ZAMCompiler::LoadParam(ID* id) (void)AddInst(z); } -const ZAMStmt ZAMCompiler::LoadGlobal(ID* id) +const ZAMStmt ZAMCompiler::LoadGlobal(const ID* id) { ZOp op; @@ -69,7 +69,7 @@ const ZAMStmt ZAMCompiler::LoadGlobal(ID* id) return AddInst(z); } -int ZAMCompiler::AddToFrame(ID* id) +int ZAMCompiler::AddToFrame(const ID* id) { frame_layout1[id] = frame_sizeI; frame_denizens.push_back(id); diff --git a/src/script_opt/ZAM/ZInst.h b/src/script_opt/ZAM/ZInst.h index b1dc751ea1..f9fc548804 100644 --- a/src/script_opt/ZAM/ZInst.h +++ b/src/script_opt/ZAM/ZInst.h @@ -18,7 +18,7 @@ class Stmt; using AttributesPtr = IntrusivePtr; // Maps ZAM frame slots to associated identifiers. -using FrameMap = std::vector; +using FrameMap = std::vector; // Maps ZAM frame slots to information for sharing the slot across // multiple script variables. @@ -28,7 +28,7 @@ public: // The variables sharing the slot. ID's need to be non-const so we // can manipulate them, for example by changing their interpreter // frame offset. - std::vector ids; + std::vector ids; // A parallel vector, only used for fully compiled code, which // gives the names of the identifiers. When in use, the above @@ -402,7 +402,7 @@ public: TypePtr* types = nullptr; // Used for accessing function names. - ID* id_val = nullptr; + const ID* id_val = nullptr; // Whether the instruction can lead to globals changing. 
// Currently only needed by the optimizer, but convenient From 98cd3f221326c79fad890e7eb67eedb353fa08ce Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Fri, 7 Jan 2022 14:53:33 -0800 Subject: [PATCH 06/10] update uses of "when" in base scripts to include captures --- scripts/base/frameworks/notice/actions/pp-alarms.zeek | 6 +++--- scripts/base/frameworks/openflow/plugins/ryu.zeek | 6 +++--- scripts/base/frameworks/sumstats/non-cluster.zeek | 4 ++-- scripts/base/protocols/ssl/main.zeek | 2 +- scripts/base/utils/active-http.zeek | 2 +- scripts/base/utils/dir.zeek | 2 +- scripts/base/utils/exec.zeek | 2 +- scripts/policy/frameworks/files/detect-MHR.zeek | 2 +- .../policy/frameworks/notice/extend-email/hostnames.zeek | 4 ++-- scripts/policy/frameworks/software/vulnerable.zeek | 2 +- scripts/policy/protocols/conn/known-hosts.zeek | 2 +- scripts/policy/protocols/conn/known-services.zeek | 2 +- scripts/policy/protocols/ssh/interesting-hostnames.zeek | 2 +- scripts/policy/protocols/ssl/known-certs.zeek | 2 +- scripts/policy/protocols/ssl/notary.zeek | 2 +- 15 files changed, 21 insertions(+), 21 deletions(-) diff --git a/scripts/base/frameworks/notice/actions/pp-alarms.zeek b/scripts/base/frameworks/notice/actions/pp-alarms.zeek index 8a26d57ec1..450dbc6872 100644 --- a/scripts/base/frameworks/notice/actions/pp-alarms.zeek +++ b/scripts/base/frameworks/notice/actions/pp-alarms.zeek @@ -212,7 +212,7 @@ function pretty_print_alarm(out: file, n: Info) return; } - when ( local h1name = lookup_addr(h1) ) + when [out, n, h1, h2, line1, line2, line3] ( local h1name = lookup_addr(h1) ) { if ( h2 == 0.0.0.0 ) { @@ -220,7 +220,7 @@ function pretty_print_alarm(out: file, n: Info) return; } - when ( local h2name = lookup_addr(h2) ) + when [out, n, h1, h2, line1, line2, line3, h1name] ( local h2name = lookup_addr(h2) ) { do_msg(out, n, line1, line2, line3, h1, h1name, h2, h2name); return; @@ -240,7 +240,7 @@ function pretty_print_alarm(out: file, n: Info) return; } - when ( local 
h2name_ = lookup_addr(h2) ) + when [out, n, h1, h2, line1, line2, line3] ( local h2name_ = lookup_addr(h2) ) { do_msg(out, n, line1, line2, line3, h1, "(dns timeout)", h2, h2name_); return; diff --git a/scripts/base/frameworks/openflow/plugins/ryu.zeek b/scripts/base/frameworks/openflow/plugins/ryu.zeek index 08e8c8d022..ef108e0a6a 100644 --- a/scripts/base/frameworks/openflow/plugins/ryu.zeek +++ b/scripts/base/frameworks/openflow/plugins/ryu.zeek @@ -135,9 +135,9 @@ function ryu_flow_mod(state: OpenFlow::ControllerState, match: ofp_match, flow_m ); # Execute call to Ryu's ReST API - when ( local result = ActiveHTTP::request(request) ) + when [state, match, flow_mod, request] ( local result = ActiveHTTP::request(request) ) { - if(result$code == 200) + if (result$code == 200) event OpenFlow::flow_mod_success(state$_name, match, flow_mod, result$body); else { @@ -165,7 +165,7 @@ function ryu_flow_clear(state: OpenFlow::ControllerState): bool $method="DELETE" ); - when ( local result = ActiveHTTP::request(request) ) + when [request] ( local result = ActiveHTTP::request(request) ) { } diff --git a/scripts/base/frameworks/sumstats/non-cluster.zeek b/scripts/base/frameworks/sumstats/non-cluster.zeek index c905d56e37..3059d78f26 100644 --- a/scripts/base/frameworks/sumstats/non-cluster.zeek +++ b/scripts/base/frameworks/sumstats/non-cluster.zeek @@ -74,7 +74,7 @@ function data_added(ss: SumStat, key: Key, result: Result) function request(ss_name: string): ResultTable { # This only needs to be implemented this way for cluster compatibility. - return when ( T ) + return when [ss_name] ( T ) { if ( ss_name in result_store ) return result_store[ss_name]; @@ -86,7 +86,7 @@ function request(ss_name: string): ResultTable function request_key(ss_name: string, key: Key): Result { # This only needs to be implemented this way for cluster compatibility. 
- return when ( T ) + return when [ss_name, key] ( T ) { if ( ss_name in result_store && key in result_store[ss_name] ) return result_store[ss_name][key]; diff --git a/scripts/base/protocols/ssl/main.zeek b/scripts/base/protocols/ssl/main.zeek index 37a60a1aff..2b610707e3 100644 --- a/scripts/base/protocols/ssl/main.zeek +++ b/scripts/base/protocols/ssl/main.zeek @@ -225,7 +225,7 @@ function log_record(info: Info) } else { - when ( |info$delay_tokens| == 0 ) + when [info] ( |info$delay_tokens| == 0 ) { log_record(info); } diff --git a/scripts/base/utils/active-http.zeek b/scripts/base/utils/active-http.zeek index 5d820b2f82..ed0210ccb6 100644 --- a/scripts/base/utils/active-http.zeek +++ b/scripts/base/utils/active-http.zeek @@ -98,7 +98,7 @@ function request(req: Request): ActiveHTTP::Response local cmd = request2curl(req, bodyfile, headersfile); local stdin_data = req?$client_data ? req$client_data : ""; - return when ( local result = Exec::run([$cmd=cmd, $stdin=stdin_data, $read_files=set(bodyfile, headersfile)]) ) + return when [req, resp, cmd, stdin_data, bodyfile, headersfile] ( local result = Exec::run([$cmd=cmd, $stdin=stdin_data, $read_files=set(bodyfile, headersfile)]) ) { # If there is no response line then nothing else will work either. if ( ! 
(result?$files && headersfile in result$files) ) diff --git a/scripts/base/utils/dir.zeek b/scripts/base/utils/dir.zeek index 678e81d7ed..dacba1ca2a 100644 --- a/scripts/base/utils/dir.zeek +++ b/scripts/base/utils/dir.zeek @@ -28,7 +28,7 @@ event Dir::monitor_ev(dir: string, last_files: set[string], callback: function(fname: string), poll_interval: interval) { - when ( local result = Exec::run([$cmd=fmt("ls -1 %s/", safe_shell_quote(dir))]) ) + when [dir, last_files, callback, poll_interval] ( local result = Exec::run([$cmd=fmt("ls -1 %s/", safe_shell_quote(dir))]) ) { if ( result$exit_code != 0 ) { diff --git a/scripts/base/utils/exec.zeek b/scripts/base/utils/exec.zeek index 85500bf9c2..7f87bb7bb4 100644 --- a/scripts/base/utils/exec.zeek +++ b/scripts/base/utils/exec.zeek @@ -178,7 +178,7 @@ function run(cmd: Command): Result $want_record=F, $config=config_strings]); - return when ( cmd$uid !in pending_commands ) + return when [cmd] ( cmd$uid !in pending_commands ) { local result = results[cmd$uid]; delete results[cmd$uid]; diff --git a/scripts/policy/frameworks/files/detect-MHR.zeek b/scripts/policy/frameworks/files/detect-MHR.zeek index 52f8dd7355..aa632a778d 100644 --- a/scripts/policy/frameworks/files/detect-MHR.zeek +++ b/scripts/policy/frameworks/files/detect-MHR.zeek @@ -39,7 +39,7 @@ function do_mhr_lookup(hash: string, fi: Notice::FileInfo) { local hash_domain = fmt("%s.malware.hash.cymru.com", hash); - when ( local MHR_result = lookup_hostname_txt(hash_domain) ) + when [hash, fi, hash_domain] ( local MHR_result = lookup_hostname_txt(hash_domain) ) { # Data is returned as " " local MHR_answer = split_string1(MHR_result, / /); diff --git a/scripts/policy/frameworks/notice/extend-email/hostnames.zeek b/scripts/policy/frameworks/notice/extend-email/hostnames.zeek index f6ed1a58be..f27477cb2d 100644 --- a/scripts/policy/frameworks/notice/extend-email/hostnames.zeek +++ b/scripts/policy/frameworks/notice/extend-email/hostnames.zeek @@ -33,7 +33,7 @@ hook 
notice(n: Notice::Info) &priority=-1 if ( n?$src ) { add n$email_delay_tokens["hostnames-src"]; - when ( local src_name = lookup_addr(n$src) ) + when [n, uid, output] ( local src_name = lookup_addr(n$src) ) { output = string_cat("orig/src hostname: ", src_name, "\n"); tmp_notice_storage[uid]$email_body_sections += output; @@ -43,7 +43,7 @@ hook notice(n: Notice::Info) &priority=-1 if ( n?$dst ) { add n$email_delay_tokens["hostnames-dst"]; - when ( local dst_name = lookup_addr(n$dst) ) + when [n, uid, output] ( local dst_name = lookup_addr(n$dst) ) { output = string_cat("resp/dst hostname: ", dst_name, "\n"); tmp_notice_storage[uid]$email_body_sections += output; diff --git a/scripts/policy/frameworks/software/vulnerable.zeek b/scripts/policy/frameworks/software/vulnerable.zeek index b8d8c43a12..40e48ffc40 100644 --- a/scripts/policy/frameworks/software/vulnerable.zeek +++ b/scripts/policy/frameworks/software/vulnerable.zeek @@ -82,7 +82,7 @@ event grab_vulnerable_versions(i: count) return; } - when ( local result = lookup_hostname_txt(cat(i,".",vulnerable_versions_update_endpoint)) ) + when [i] ( local result = lookup_hostname_txt(cat(i,".",vulnerable_versions_update_endpoint)) ) { local parts = split_string1(result, /\x09/); if ( |parts| != 2 ) #failure or end of list! diff --git a/scripts/policy/protocols/conn/known-hosts.zeek b/scripts/policy/protocols/conn/known-hosts.zeek index 279fa11917..4bd123abdd 100644 --- a/scripts/policy/protocols/conn/known-hosts.zeek +++ b/scripts/policy/protocols/conn/known-hosts.zeek @@ -77,7 +77,7 @@ event Known::host_found(info: HostsInfo) if ( ! 
Known::use_host_store ) return; - when ( local r = Broker::put_unique(Known::host_store$store, info$host, + when [info] ( local r = Broker::put_unique(Known::host_store$store, info$host, T, Known::host_store_expiry) ) { if ( r$status == Broker::SUCCESS ) diff --git a/scripts/policy/protocols/conn/known-services.zeek b/scripts/policy/protocols/conn/known-services.zeek index 313c49b940..a073d4d92a 100644 --- a/scripts/policy/protocols/conn/known-services.zeek +++ b/scripts/policy/protocols/conn/known-services.zeek @@ -123,7 +123,7 @@ event service_info_commit(info: ServicesInfo) { local key = AddrPortServTriplet($host = info$host, $p = info$port_num, $serv = s); - when ( local r = Broker::put_unique(Known::service_store$store, key, + when [info, s, key] ( local r = Broker::put_unique(Known::service_store$store, key, T, Known::service_store_expiry) ) { if ( r$status == Broker::SUCCESS ) diff --git a/scripts/policy/protocols/ssh/interesting-hostnames.zeek b/scripts/policy/protocols/ssh/interesting-hostnames.zeek index db80d7c6ac..2270b049e5 100644 --- a/scripts/policy/protocols/ssh/interesting-hostnames.zeek +++ b/scripts/policy/protocols/ssh/interesting-hostnames.zeek @@ -29,7 +29,7 @@ export { function check_ssh_hostname(id: conn_id, uid: string, host: addr) { - when ( local hostname = lookup_addr(host) ) + when [id, uid, host] ( local hostname = lookup_addr(host) ) { if ( interesting_hostnames in hostname ) { diff --git a/scripts/policy/protocols/ssl/known-certs.zeek b/scripts/policy/protocols/ssl/known-certs.zeek index 35fbcf0f7b..cd4fa23ccd 100644 --- a/scripts/policy/protocols/ssl/known-certs.zeek +++ b/scripts/policy/protocols/ssl/known-certs.zeek @@ -89,7 +89,7 @@ event Known::cert_found(info: CertsInfo, hash: string) local key = AddrCertHashPair($host = info$host, $hash = hash); - when ( local r = Broker::put_unique(Known::cert_store$store, key, + when [info, key] ( local r = Broker::put_unique(Known::cert_store$store, key, T, Known::cert_store_expiry) ) { if 
( r$status == Broker::SUCCESS ) diff --git a/scripts/policy/protocols/ssl/notary.zeek b/scripts/policy/protocols/ssl/notary.zeek index 67f8734d41..0fc7f07c03 100644 --- a/scripts/policy/protocols/ssl/notary.zeek +++ b/scripts/policy/protocols/ssl/notary.zeek @@ -63,7 +63,7 @@ event ssl_established(c: connection) &priority=3 if ( waits_already ) return; - when ( local str = lookup_hostname_txt(fmt("%s.%s", digest, domain)) ) + when [digest] ( local str = lookup_hostname_txt(fmt("%s.%s", digest, domain)) ) { notary_cache[digest] = []; From b59ee83979ce76c3a6c72bf9cb3ebebf1be30d02 Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Fri, 7 Jan 2022 14:54:06 -0800 Subject: [PATCH 07/10] update existing test suite usage of "when" statements to include captures --- testing/btest/broker/store/create-failure.zeek | 2 +- testing/btest/broker/store/invalid-handle.zeek | 2 +- testing/btest/core/when-interpreter-exceptions.zeek | 8 ++++---- .../btest/scripts/base/frameworks/sumstats/on-demand.zeek | 2 +- testing/btest/scripts/base/utils/active-http.test | 2 +- testing/btest/scripts/base/utils/exec.test | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/testing/btest/broker/store/create-failure.zeek b/testing/btest/broker/store/create-failure.zeek index 57f56b815c..0b0e2fbca0 100644 --- a/testing/btest/broker/store/create-failure.zeek +++ b/testing/btest/broker/store/create-failure.zeek @@ -46,7 +46,7 @@ function check_terminate_conditions() function check_it(name: string, s: opaque of Broker::Store) { - when ( local r = Broker::keys(s) ) + when [name, s] ( local r = Broker::keys(s) ) { check_terminate_conditions(); print fmt("%s keys result: %s", name, r); diff --git a/testing/btest/broker/store/invalid-handle.zeek b/testing/btest/broker/store/invalid-handle.zeek index 3b270fa945..736962ba7b 100644 --- a/testing/btest/broker/store/invalid-handle.zeek +++ b/testing/btest/broker/store/invalid-handle.zeek @@ -8,7 +8,7 @@ function print_keys(a: any) { - when ( local 
s = Broker::keys(a) ) + when [a] ( local s = Broker::keys(a) ) { print "keys", s; } diff --git a/testing/btest/core/when-interpreter-exceptions.zeek b/testing/btest/core/when-interpreter-exceptions.zeek index d9d37e7318..36d63396bb 100644 --- a/testing/btest/core/when-interpreter-exceptions.zeek +++ b/testing/btest/core/when-interpreter-exceptions.zeek @@ -37,7 +37,7 @@ function f(do_exception: bool): bool local cmd = Exec::Command($cmd=fmt("echo 'f(%s)'", do_exception)); - return when ( local result = Exec::run(cmd) ) + return when [cmd, do_exception] ( local result = Exec::run(cmd) ) { print result$stdout; @@ -58,7 +58,7 @@ function g(do_exception: bool): bool { local stall = Exec::Command($cmd="sleep 30"); - return when ( local result = Exec::run(stall) ) + return when [do_exception, stall] ( local result = Exec::run(stall) ) { print "shouldn't get here, g()", do_exception, result; } @@ -84,14 +84,14 @@ event zeek_init() local cmd = Exec::Command($cmd="echo 'zeek_init()'"); local stall = Exec::Command($cmd="sleep 30"); - when ( local result = Exec::run(cmd) ) + when [cmd] ( local result = Exec::run(cmd) ) { print result$stdout; event termination_check(); print myrecord$notset; } - when ( local result2 = Exec::run(stall) ) + when [stall] ( local result2 = Exec::run(stall) ) { print "shouldn't get here", result2; check_term_condition(); diff --git a/testing/btest/scripts/base/frameworks/sumstats/on-demand.zeek b/testing/btest/scripts/base/frameworks/sumstats/on-demand.zeek index 208d9248f2..060552c0bd 100644 --- a/testing/btest/scripts/base/frameworks/sumstats/on-demand.zeek +++ b/testing/btest/scripts/base/frameworks/sumstats/on-demand.zeek @@ -22,7 +22,7 @@ redef exit_only_after_terminate=T; event on_demand_key() { local host = 1.2.3.4; - when ( local result = SumStats::request_key("test", [$host=host]) ) + when [host] ( local result = SumStats::request_key("test", [$host=host]) ) { print fmt("Key request for %s", host); print fmt(" Host: %s -> %.0f", host, 
result["test.reducer"]$sum); diff --git a/testing/btest/scripts/base/utils/active-http.test b/testing/btest/scripts/base/utils/active-http.test index b325bb40cc..d3fe3ebeac 100644 --- a/testing/btest/scripts/base/utils/active-http.test +++ b/testing/btest/scripts/base/utils/active-http.test @@ -23,7 +23,7 @@ function check_exit_condition() function test_request(label: string, req: ActiveHTTP::Request) { - when ( local response = ActiveHTTP::request(req) ) + when [label, req] ( local response = ActiveHTTP::request(req) ) { print label, response; check_exit_condition(); diff --git a/testing/btest/scripts/base/utils/exec.test b/testing/btest/scripts/base/utils/exec.test index 80a98c8285..7cf64908ae 100644 --- a/testing/btest/scripts/base/utils/exec.test +++ b/testing/btest/scripts/base/utils/exec.test @@ -19,7 +19,7 @@ function check_exit_condition() function test_cmd(label: string, cmd: Exec::Command) { - when ( local result = Exec::run(cmd) ) + when [label, cmd] ( local result = Exec::run(cmd) ) { local file_content = ""; From c5ab91671022ed4f8fe6d8bfa3842666fe668fb3 Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Fri, 7 Jan 2022 14:54:47 -0800 Subject: [PATCH 08/10] tests for new "when" semantics/errors --- .../language.when-aggregates/zeek..stdout | 5 + .../Baseline/language.when-capture-errors/out | 8 ++ testing/btest/language/when-aggregates.zeek | 62 ++++++++++++ .../btest/language/when-capture-errors.zeek | 94 +++++++++++++++++++ 4 files changed, 169 insertions(+) create mode 100644 testing/btest/Baseline/language.when-aggregates/zeek..stdout create mode 100644 testing/btest/Baseline/language.when-capture-errors/out create mode 100644 testing/btest/language/when-aggregates.zeek create mode 100644 testing/btest/language/when-capture-errors.zeek diff --git a/testing/btest/Baseline/language.when-aggregates/zeek..stdout b/testing/btest/Baseline/language.when-aggregates/zeek..stdout new file mode 100644 index 0000000000..06f927b0a8 --- /dev/null +++ 
b/testing/btest/Baseline/language.when-aggregates/zeek..stdout @@ -0,0 +1,5 @@ +### BTest baseline data generated by btest-diff. Do not edit. Use "btest -U/-u" to update. Requires BTest >= 0.63. +[x=1, y=2], [x=11, y=12] +[x=1, y=55], [x=11, y=12] +[x=111, y=222], [x=13, y=14] +[x=99, y=222], [x=13, y=14] diff --git a/testing/btest/Baseline/language.when-capture-errors/out b/testing/btest/Baseline/language.when-capture-errors/out new file mode 100644 index 0000000000..aaa6da7d71 --- /dev/null +++ b/testing/btest/Baseline/language.when-capture-errors/out @@ -0,0 +1,8 @@ +### BTest baseline data generated by btest-diff. Do not edit. Use "btest -U/-u" to update. Requires BTest >= 0.63. +warning in <...>/when-capture-errors.zeek, lines 19-22: "when" statement referring to locals without an explicit [] capture is deprecated: orig1 (when (0 < g) { print orig1}) +warning in <...>/when-capture-errors.zeek, lines 25-28: "when" statement referring to locals without an explicit [] capture is deprecated: orig3 (when (0 < g || orig3) { print g}) +warning in <...>/when-capture-errors.zeek, lines 34-38: "when" statement referring to locals without an explicit [] capture is deprecated: orig1 (when (0 < g) { print g} timeout 1.0 sec { print orig1}) +error in <...>/when-capture-errors.zeek, lines 66-70: orig2 is used inside "when" statement but not captured (when (0 < g) { print orig1} timeout 1.0 sec { print orig2}) +error in <...>/when-capture-errors.zeek, lines 76-80: orig3 is captured but not used inside "when" statement (when (0 < g) { print orig1} timeout 1.0 sec { print orig2}) +error in <...>/when-capture-errors.zeek, line 83: no such local identifier: l1 +error in <...>/when-capture-errors.zeek, line 89: no such local identifier: l2 diff --git a/testing/btest/language/when-aggregates.zeek b/testing/btest/language/when-aggregates.zeek new file mode 100644 index 0000000000..7462edd476 --- /dev/null +++ b/testing/btest/language/when-aggregates.zeek @@ -0,0 +1,62 @@ +# 
@TEST-EXEC: btest-bg-run zeek zeek -b %INPUT +# @TEST-EXEC: btest-bg-wait 15 +# @TEST-EXEC: btest-diff zeek/.stdout + +type r: record { x: int; y: int; }; + +global g = 0; + +function async_foo1(arg: r) : r + { + return when ( g > 0 ) + { + arg$x = 99; + return r($x = 11, $y = 12); + } + } + +function async_foo2(arg: r) : r + { + return when [arg] ( g > 0 ) + { + arg$x = 99; + return r($x = 13, $y = 14); + } + } + +event zeek_init() + { + local orig1 = r($x = 1, $y = 2); + local orig2 = copy(orig1); + + when ( local resp1 = async_foo1(orig1) ) + { + print orig1, resp1; + } + + when [orig2] ( local resp2 = async_foo1(orig2) ) + { + print orig2, resp2; + } + + local orig3 = r($x = 111, $y = 222); + local orig4 = copy(orig3); + + when ( local resp4 = async_foo2(orig3) ) + { + print orig3, resp4; + } + + when [orig4] ( local resp5 = async_foo2(orig4) ) + { + print orig4, resp5; + } + + orig1$y = 44; + orig2$y = 55; + } + +event zeek_init() &priority=-10 + { + g = 1; + } diff --git a/testing/btest/language/when-capture-errors.zeek b/testing/btest/language/when-capture-errors.zeek new file mode 100644 index 0000000000..85b0abb393 --- /dev/null +++ b/testing/btest/language/when-capture-errors.zeek @@ -0,0 +1,94 @@ +# @TEST-EXEC-FAIL: zeek -b %INPUT >out 2>&1 +# @TEST-EXEC: TEST_DIFF_CANONIFIER=$SCRIPTS/diff-remove-abspath btest-diff out + +global g = 0; + +event zeek_init() + { + local orig1 = "hello"; + local orig2 = 3.5; + local orig3 = F; + + # Should be okay since no local captures. + when ( g > 0 ) + { + print g; + } + + # Should generate a deprecation warning. + when ( g > 0 ) + { + print orig1; + } + + # Same. + when ( g > 0 || orig3 ) + { + print g; + } + + # Same. + when ( g > 0 ) + { + print g; + } + timeout 1 sec + { + print orig1; + } + + # Should be okay. + when [orig2] ( g > 0 && orig2 < 10.0 ) + { + print g; + } + + # Should be okay. + when [orig1] ( g > 0 ) + { + print orig1; + } + + # Should be okay. 
+ when [orig1] ( g > 0 ) + { + print g; + } + timeout 1 sec + { + print orig1; + } + + # Mismatch: missing a local. + when [orig1] ( g > 0 ) + { + print orig1; + } + timeout 1 sec + { + print orig2; + } + + # Mismatch: overspecifies a local. + when [orig1, orig2, orig3] ( g > 0 ) + { + print orig1; + } + timeout 1 sec + { + print orig2; + } + + # Should generate a "no such identifier" error. + when [l1] ( local l1 = network_time() ) + { + print l1; + } + + # As should this. + when [l2] ( g > 0 ) + { + local l2 = network_time(); + print l2; + } + } From fa848167bb261a799ccabc59bc4e41a76212f17a Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Fri, 7 Jan 2022 21:52:11 -0800 Subject: [PATCH 09/10] attempt to make "when" btest deterministic --- testing/btest/language/when-aggregates.zeek | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/testing/btest/language/when-aggregates.zeek b/testing/btest/language/when-aggregates.zeek index 7462edd476..6e9a434aee 100644 --- a/testing/btest/language/when-aggregates.zeek +++ b/testing/btest/language/when-aggregates.zeek @@ -29,25 +29,28 @@ event zeek_init() local orig1 = r($x = 1, $y = 2); local orig2 = copy(orig1); - when ( local resp1 = async_foo1(orig1) ) + when ( g == 1 && local resp1 = async_foo1(orig1) ) { + ++g; print orig1, resp1; } - when [orig2] ( local resp2 = async_foo1(orig2) ) + when [orig2] ( g == 2 && local resp2 = async_foo1(orig2) ) { + ++g; print orig2, resp2; } local orig3 = r($x = 111, $y = 222); local orig4 = copy(orig3); - when ( local resp4 = async_foo2(orig3) ) + when ( g == 3 && local resp4 = async_foo2(orig3) ) { + ++g; print orig3, resp4; } - when [orig4] ( local resp5 = async_foo2(orig4) ) + when [orig4] ( g == 4 && local resp5 = async_foo2(orig4) ) { print orig4, resp5; } From 98a05538b743a97ba69aed975b31aa1ddf09c580 Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Sat, 8 Jan 2022 08:24:15 -0800 Subject: [PATCH 10/10] explicitly provide the frame for evaluating a "when" 
timeout expression --- src/Trigger.cc | 8 ++++---- src/Trigger.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Trigger.cc b/src/Trigger.cc index 9d4e2ea121..0ce780ef38 100644 --- a/src/Trigger.cc +++ b/src/Trigger.cc @@ -102,7 +102,7 @@ protected: Trigger::Trigger(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, ExprPtr timeout_expr, Frame* frame, bool is_return, const Location* location) { - GetTimeout(timeout_expr); + GetTimeout(timeout_expr, frame); Init(cond, body, timeout_stmts, frame, is_return, location); } @@ -120,12 +120,12 @@ Trigger::Trigger(WhenInfo* wi, const IDSet& _globals, std::vector<ValPtr> _local local_aggrs = std::move(_local_aggrs); have_trigger_elems = true; - GetTimeout(wi->TimeoutExpr()); + GetTimeout(wi->TimeoutExpr(), f); Init(wi->Cond(), wi->WhenBody(), wi->TimeoutStmt(), f, wi->IsReturn(), loc); } -void Trigger::GetTimeout(const ExprPtr& timeout_expr) +void Trigger::GetTimeout(const ExprPtr& timeout_expr, Frame* f) { timeout_value = -1.0; @@ -135,7 +135,7 @@ void Trigger::GetTimeout(const ExprPtr& timeout_expr) try { - timeout_val = timeout_expr->Eval(frame); + timeout_val = timeout_expr->Eval(f); } catch ( InterpreterException& ) { /* Already reported */ diff --git a/src/Trigger.h b/src/Trigger.h index 6810a7fccd..833b3f05bc 100644 --- a/src/Trigger.h +++ b/src/Trigger.h @@ -112,7 +112,7 @@ public: private: friend class TriggerTimer; - void GetTimeout(const ExprPtr& timeout_expr); + void GetTimeout(const ExprPtr& timeout_expr, Frame* f); void Init(ExprPtr cond, StmtPtr body, StmtPtr timeout_stmts, Frame* frame, bool is_return, const Location* location);