diff --git a/src/Expr.h b/src/Expr.h index 6838955ffa..5bb2af6844 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -519,7 +519,7 @@ public: ValPtr Eval(Frame* f) const override; }; -class IndexExpr final : public BinaryExpr { +class IndexExpr : public BinaryExpr { public: IndexExpr(ExprPtr op1, ListExprPtr op2, bool is_slice = false); @@ -549,6 +549,39 @@ protected: bool is_slice; }; +class IndexExprWhen final : public IndexExpr { +public: + static inline std::vector results = {}; + static inline int evaluating = 0; + + static void StartEval() + { ++evaluating; } + + static void EndEval() + { --evaluating; } + + static std::vector TakeAllResults() + { + auto rval = std::move(results); + results = {}; + return rval; + } + + IndexExprWhen(ExprPtr op1, ListExprPtr op2, bool is_slice = false) + : IndexExpr(std::move(op1), std::move(op2), is_slice) + { } + + ValPtr Eval(Frame* f) const override + { + auto v = IndexExpr::Eval(f); + + if ( v && evaluating > 0 ) + results.emplace_back(v); + + return v; + } +}; + class FieldExpr final : public UnaryExpr { public: FieldExpr(ExprPtr op, const char* field_name); diff --git a/src/Trigger.cc b/src/Trigger.cc index 6a1820a40a..697019670f 100644 --- a/src/Trigger.cc +++ b/src/Trigger.cc @@ -56,24 +56,6 @@ TraversalCode trigger::TriggerTraversalCallback::PreExpr(const Expr* expr) break; }; - case EXPR_INDEX: - { - const auto* e = static_cast(expr); - Obj::SuppressErrors no_errors; - - try - { - auto v = e->Eval(trigger->frame); - - if ( v ) - trigger->Register(v.get()); - } - catch ( InterpreterException& ) - { /* Already reported */ } - - break; - } - default: // All others are uninteresting. break; @@ -217,12 +199,15 @@ Trigger::~Trigger() // point. } -void Trigger::Init() +void Trigger::Init(std::vector index_expr_results) { assert(! disabled); UnregisterAll(); TriggerTraversalCallback cb(this); cond->Traverse(&cb); + + for ( const auto& v : index_expr_results ) + Register(v.get()); } bool Trigger::Eval() @@ -265,6 +250,7 @@ bool Trigger::Eval() f->SetTrigger({NewRef{}, this}); ValPtr v; + IndexExprWhen::StartEval(); try { @@ -273,6 +259,9 @@ bool Trigger::Eval() catch ( InterpreterException& ) { /* Already reported */ } + IndexExprWhen::EndEval(); + auto index_expr_results = IndexExprWhen::TakeAllResults(); + f->ClearTrigger(); if ( f->HasDelayed() ) @@ -288,7 +277,7 @@ bool Trigger::Eval() // Not true. Perhaps next time... DBG_LOG(DBG_NOTIFIERS, "%s: trigger condition is false", Name()); Unref(f); - Init(); + Init(std::move(index_expr_results)); return false; } diff --git a/src/Trigger.h b/src/Trigger.h index 4aabedaa73..5a7839a7c4 100644 --- a/src/Trigger.h +++ b/src/Trigger.h @@ -91,7 +91,7 @@ private: friend class TriggerTraversalCallback; friend class TriggerTimer; - void Init(); + void Init(std::vector> index_expr_results); void Register(ID* id); void Register(Val* val); void UnregisterAll(); diff --git a/src/parse.y b/src/parse.y index d0eeb91300..ae940e6dec 100644 --- a/src/parse.y +++ b/src/parse.y @@ -58,7 +58,7 @@ %type init_class %type opt_init %type TOK_CONSTANT -%type expr opt_expr init anonymous_function lambda_body index_slice opt_deprecated +%type expr opt_expr init anonymous_function lambda_body index_slice opt_deprecated when_condition %type event %type stmt stmt_list func_body for_head %type type opt_type enum_body @@ -126,6 +126,7 @@ extern zeek::detail::Expr* g_curr_debug_expr; extern bool in_debug; extern const char* g_curr_debug_error; +static int in_when_cond = 0; static int in_hook = 0; int in_init = 0; int in_record = 0; @@ -289,6 +290,11 @@ opt_expr: { $$ = 0; } ; +when_condition: + { ++in_when_cond; } expr { --in_when_cond; } + { $$ = $2; } + ; + expr: '(' expr ')' { @@ -474,7 +480,10 @@ expr: | expr '[' expr_list ']' { zeek::detail::set_location(@1, @4); - $$ = new zeek::detail::IndexExpr({zeek::AdoptRef{}, $1}, {zeek::AdoptRef{}, $3}); + if ( in_when_cond > 0 ) + $$ = new zeek::detail::IndexExprWhen({zeek::AdoptRef{}, $1}, {zeek::AdoptRef{}, $3}); + else + $$ = new zeek::detail::IndexExpr({zeek::AdoptRef{}, $1}, {zeek::AdoptRef{}, $3}); } | index_slice @@ -1328,7 +1337,11 @@ index_slice: auto le = zeek::make_intrusive(std::move(low)); le->Append(std::move(high)); - $$ = new zeek::detail::IndexExpr({zeek::AdoptRef{}, $1}, std::move(le), true); + + if ( in_when_cond > 0 ) + $$ = new zeek::detail::IndexExprWhen({zeek::AdoptRef{}, $1}, std::move(le), true); + else + $$ = new zeek::detail::IndexExpr({zeek::AdoptRef{}, $1}, std::move(le), true); } opt_attr: @@ -1535,14 +1548,14 @@ stmt: zeek::detail::script_coverage_mgr.AddStmt($$); } - | TOK_WHEN '(' expr ')' stmt + | TOK_WHEN '(' when_condition ')' stmt { zeek::detail::set_location(@3, @5); $$ = new zeek::detail::WhenStmt({zeek::AdoptRef{}, $3}, {zeek::AdoptRef{}, $5}, nullptr, nullptr, false); } - | TOK_WHEN '(' expr ')' stmt TOK_TIMEOUT expr '{' opt_no_test_block stmt_list '}' + | TOK_WHEN '(' when_condition ')' stmt TOK_TIMEOUT expr '{' opt_no_test_block stmt_list '}' { zeek::detail::set_location(@3, @9); $$ = new zeek::detail::WhenStmt({zeek::AdoptRef{}, $3}, {zeek::AdoptRef{}, $5}, @@ -1552,14 +1565,14 @@ stmt: } - | TOK_RETURN TOK_WHEN '(' expr ')' stmt + | TOK_RETURN TOK_WHEN '(' when_condition ')' stmt { zeek::detail::set_location(@4, @6); $$ = new zeek::detail::WhenStmt({zeek::AdoptRef{}, $4}, {zeek::AdoptRef{}, $6}, nullptr, nullptr, true); } - | TOK_RETURN TOK_WHEN '(' expr ')' stmt TOK_TIMEOUT expr '{' opt_no_test_block stmt_list '}' + | TOK_RETURN TOK_WHEN '(' when_condition ')' stmt TOK_TIMEOUT expr '{' opt_no_test_block stmt_list '}' { zeek::detail::set_location(@4, @10); $$ = new zeek::detail::WhenStmt({zeek::AdoptRef{}, $4}, {zeek::AdoptRef{}, $6}, diff --git a/testing/btest/Baseline/language.when-order-of-eval/zeek..stdout b/testing/btest/Baseline/language.when-order-of-eval/zeek..stdout new file mode 100644 index 0000000000..21c9d90da6 --- /dev/null +++ b/testing/btest/Baseline/language.when-order-of-eval/zeek..stdout @@ -0,0 +1,9 @@ +running myevent, 1 +running myevent, 2 +running myevent, 3 +running myevent, 4 +running myevent, 5 +triggered when condition against 'x' +running myevent, 6 +triggered when condition against 'y' +running myevent, 7 diff --git a/testing/btest/language/when-order-of-eval.zeek b/testing/btest/language/when-order-of-eval.zeek new file mode 100644 index 0000000000..05835d6141 --- /dev/null +++ b/testing/btest/language/when-order-of-eval.zeek @@ -0,0 +1,68 @@ +# @TEST-EXEC: btest-bg-run zeek zeek -b %INPUT +# @TEST-EXEC: btest-bg-wait 10 +# @TEST-EXEC: btest-diff zeek/.stdout + +# The 'when' implementation historically performed an AST-traversal to locate +# any index-expressions like `x[9]` and evaluated them so that it could +# register the assocated value as something for which it needs to receive +# "modification" notifications. +# +# Evaluating arbitrary expressions during an AST-traversal like that ignores +# the typical order-of-evaluation/short-circuiting you'd expect if the +# condition was evaluated normally, from its root expression. This test is +# checking that evaluation of 'when' conditions behaves according to those +# usual expectations. + +redef exit_only_after_terminate = T; + +type r: record { + a: count; +}; + +global x: table[count] of count; +global y: table[count] of r; + +const event_interval = 0.05sec; + +function foo() + { + when ( 9 in y && y[9]$a == 3 ) + { + print "triggered when condition against 'y'"; + terminate(); + } + } + +function bar() + { + when ( 9 in x && x[9] > 3 ) + print "triggered when condition against 'x'"; + } + +global ev_count = 0; +event myevent() + { + ++ev_count; + print "running myevent", ev_count; + local init_at = 3; + + if ( ev_count == init_at ) + { + x[9] = 2; + y[9] = r($a = 0); + } + else if ( ev_count > init_at ) + { + ++x[9]; + ++y[9]$a; + } + + schedule event_interval { myevent() }; + } + +event zeek_init() + { + foo(); + bar(); + schedule event_interval { myevent() }; + }