diff --git a/src/Traverse.cc b/src/Traverse.cc index c97e3103d4..9d3e9f8fe4 100644 --- a/src/Traverse.cc +++ b/src/Traverse.cc @@ -2,27 +2,70 @@ #include "zeek/Traverse.h" +#include "zeek/Func.h" #include "zeek/Scope.h" -#include "zeek/Stmt.h" #include "zeek/input.h" +#include "zeek/3rdparty/doctest.h" + namespace zeek::detail { TraversalCode traverse_all(TraversalCallback* cb) { if ( ! global_scope() ) return TC_CONTINUE; - if ( ! stmts ) - // May be null when parsing fails. - return TC_CONTINUE; cb->current_scope = global_scope(); TraversalCode tc = global_scope()->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); - tc = stmts->Traverse(cb); + + if ( stmts ) + // May be null when parsing fails. + tc = stmts->Traverse(cb); + HANDLE_TC_STMT_POST(tc); } } // namespace zeek::detail + + +TEST_SUITE_BEGIN("traverser"); + +namespace { +// Helper classes for tests below. +using namespace zeek::detail; + +class SaveRestoreStmts { +public: + SaveRestoreStmts() : orig(zeek::detail::stmts) {} + ~SaveRestoreStmts() { zeek::detail::stmts = orig; } + + Stmt* orig; +}; + +class ZeekInitFinder : public TraversalCallback { +public: + TraversalCode PreFunction(const zeek::Func* f) override { + if ( f->GetName() == "zeek_init" ) + zeek_init_found = true; + + return TC_CONTINUE; + } + + bool zeek_init_found = false; +}; + +} // namespace + +TEST_CASE("traverse_all") { + SUBCASE("ensure zeek_init() is found if stmts == nullptr") { + SaveRestoreStmts save_restore_stmts; + zeek::detail::stmts = nullptr; // force stmts to be a nullptr + ZeekInitFinder cb; + traverse_all(&cb); + CHECK(cb.zeek_init_found); + } +} + +TEST_SUITE_END();