diff --git a/src/script_opt/CSE.cc b/src/script_opt/CSE.cc index a19828c1f3..de2d542a40 100644 --- a/src/script_opt/CSE.cc +++ b/src/script_opt/CSE.cc @@ -12,6 +12,19 @@ CSE_ValidityChecker::CSE_ValidityChecker(std::shared_ptr _pfs, con start_e = _start_e; end_e = _end_e; + // For validity checking, if end_e is inside a loop and start_e is + // outside that loop, then we need to extend the checking beyond end_e + // to the end of the loop, to account for correctness after iterating + // through the loop. We do that as follows. Upon entering an outer + // loop, we set end_s to that loop. (We can tell it's an outer loop if, + // upon entering, end_s is nil.) (1) If we encounter end_e while inside + // that loop (which we can tell because end_s is non-nil), then we clear + // end_e to signal that we're now using end_s to terminate the traversal. + // (2) If we complete the loop without encountering end_e (which we can + // tell because after traversal end_e is non-nil), then we clear end_s + // to mark that the traversal is now not inside a loop. + end_s = nullptr; + // Track whether this is a record assignment, in which case // we're attuned to assignments to the same field for the // same type of record. @@ -38,6 +51,23 @@ TraversalCode CSE_ValidityChecker::PreStmt(const Stmt* s) { return TC_ABORTALL; } + if ( (t == STMT_WHILE || t == STMT_FOR) && have_start_e && ! end_s ) + // We've started the traversal and are entering an outer loop. + end_s = s; + + return TC_CONTINUE; +} + +TraversalCode CSE_ValidityChecker::PostStmt(const Stmt* s) { + if ( end_s == s ) { + if ( ! end_e ) + // We've done the outer loop containing the end expression. + return TC_ABORTALL; + + // We're no longer doing an outer loop. + end_s = nullptr; + } + return TC_CONTINUE; } @@ -60,8 +90,13 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) { ASSERT(! have_end_e); have_end_e = true; - // ... and we're now done. - return TC_ABORTALL; + if ( ! end_s ) + // We're now done. + return TC_ABORTALL; + + // Need to finish the loop before we mark things as done. + // Signal to the statement traversal that we're in that state. + end_e = nullptr; } if ( ! have_start_e ) diff --git a/src/script_opt/CSE.h b/src/script_opt/CSE.h index 60b1ad45bc..a078f059c6 100644 --- a/src/script_opt/CSE.h +++ b/src/script_opt/CSE.h @@ -21,6 +21,7 @@ public: const Expr* end_e); TraversalCode PreStmt(const Stmt*) override; + TraversalCode PostStmt(const Stmt*) override; TraversalCode PreExpr(const Expr*) override; TraversalCode PostExpr(const Expr*) override; @@ -88,9 +89,13 @@ protected: // assignment expression. const Expr* start_e; - // Where in the AST to end our analysis. + // Expression in the AST where we should end our analysis. See discussion + // in the constructor for the interplay between this and end_s. const Expr* end_e; + // Statement in the AST where we should end our analysis. + const Stmt* end_s; + // If what we're analyzing is a record element, then its offset. // -1 if not. int field; diff --git a/testing/btest/Baseline/opt.regress-aggr-change-in-loop/output b/testing/btest/Baseline/opt.regress-aggr-change-in-loop/output new file mode 100644 index 0000000000..0d711d7dd1 --- /dev/null +++ b/testing/btest/Baseline/opt.regress-aggr-change-in-loop/output @@ -0,0 +1,4 @@ +### BTest baseline data generated by btest-diff. Do not edit. Use "btest -U/-u" to update. Requires BTest >= 0.63. +[hash=bletch] +[hash=xyzzy] +done diff --git a/testing/btest/opt/regress-aggr-change-in-loop.zeek b/testing/btest/opt/regress-aggr-change-in-loop.zeek new file mode 100644 index 0000000000..2162128030 --- /dev/null +++ b/testing/btest/opt/regress-aggr-change-in-loop.zeek @@ -0,0 +1,38 @@ +# @TEST-DOC: Regression test for an aggregate in a CSE changing inside a loop +# @TEST-REQUIRES: test "${ZEEK_USE_CPP}" != "1" +# @TEST-EXEC: zeek -b -O ZAM %INPUT >output +# @TEST-EXEC: btest-diff output + +type Data: record { + hash: string; +}; + +global map: table[string] of Data; + +function traverse_map(hash: string) + { + local tmp = map[hash]; + + if ( tmp$hash == "" ) + return; + + while ( tmp$hash in map ) + { + # Prior to the fix, the value of tmp$hash computed in the + # earlier "if" statement was used here, rather than the + # optimizer recognizing that "tmp" can have changed at this + # point due to the loop, and thus that value can be stale. + # That led to an infinite loop here. + tmp = map[tmp$hash]; + print tmp; + } + } + +event zeek_init() + { + map["foo"] = Data($hash="bar"); + map["bar"] = Data($hash="bletch"); + map["bletch"] = Data($hash="xyzzy"); + traverse_map("foo"); + print "done"; + }