diff --git a/src/script_opt/GenIDDefs.cc b/src/script_opt/GenIDDefs.cc index ef7ca32704..6eaf77576d 100644 --- a/src/script_opt/GenIDDefs.cc +++ b/src/script_opt/GenIDDefs.cc @@ -107,37 +107,8 @@ TraversalCode GenIDDefs::PreStmt(const Stmt* s) } case STMT_SWITCH: - { - auto sw = s->AsSwitchStmt(); - auto e = sw->StmtExpr(); - - e->Traverse(this); - - StartConfluenceBlock(sw); - - for ( const auto& c : *sw->Cases() ) - { - auto body = c->Body(); - - auto exprs = c->ExprCases(); - if ( exprs ) - exprs->Traverse(this); - - auto type_ids = c->TypeCases(); - if ( type_ids ) - { - for ( const auto& id : *type_ids ) - if ( id->Name() ) - TrackID(id); - } - - body->Traverse(this); - } - - EndConfluenceBlock(sw->HasDefault()); - + AnalyzeSwitch(s->AsSwitchStmt()); return TC_ABORTSTMT; - } case STMT_FOR: { @@ -201,6 +172,38 @@ TraversalCode GenIDDefs::PreStmt(const Stmt* s) } } +void GenIDDefs::AnalyzeSwitch(const SwitchStmt* sw) + { + sw->StmtExpr()->Traverse(this); + + for ( const auto& c : *sw->Cases() ) + { + // Important: the confluence block is the switch statement + // itself, not the case body. This is needed so that variable + // assignments made inside case bodies that end with + // "fallthrough" are correctly propagated to the next case + // body. + StartConfluenceBlock(sw); + + auto body = c->Body(); + + auto exprs = c->ExprCases(); + if ( exprs ) + exprs->Traverse(this); + + auto type_ids = c->TypeCases(); + if ( type_ids ) + { + for ( const auto& id : *type_ids ) + if ( id->Name() ) + TrackID(id); + } + + body->Traverse(this); + EndConfluenceBlock(false); + } + } + TraversalCode GenIDDefs::PostStmt(const Stmt* s) { switch ( s->Tag() ) diff --git a/src/script_opt/GenIDDefs.h b/src/script_opt/GenIDDefs.h index 29593e458c..8667394b14 100644 --- a/src/script_opt/GenIDDefs.h +++ b/src/script_opt/GenIDDefs.h @@ -22,6 +22,8 @@ private: void TraverseFunction(const Func* f, ScopePtr scope, StmtPtr body); TraversalCode PreStmt(const Stmt*) override; + void AnalyzeSwitch(const SwitchStmt* sw); + TraversalCode PostStmt(const Stmt*) override; TraversalCode PreExpr(const Expr*) override; TraversalCode PostExpr(const Expr*) override; diff --git a/src/script_opt/Stmt.cc b/src/script_opt/Stmt.cc index 7da1343361..7cefdc63a2 100644 --- a/src/script_opt/Stmt.cc +++ b/src/script_opt/Stmt.cc @@ -320,13 +320,20 @@ IntrusivePtr Case::Duplicate() return make_intrusive(new_exprs, nullptr, s->Duplicate()); } + IDPList* new_type_cases = nullptr; + if ( type_cases ) { + new_type_cases = new IDPList(); + for ( auto tc : *type_cases ) + { zeek::Ref(tc); + new_type_cases->append(tc); + } } - return make_intrusive(nullptr, type_cases, s->Duplicate()); + return make_intrusive(nullptr, new_type_cases, s->Duplicate()); } StmtPtr SwitchStmt::Duplicate() @@ -354,6 +361,9 @@ bool SwitchStmt::IsReduced(Reducer* r) const if ( ! e->IsReduced(r) ) return NonReduced(e.get()); + if ( cases->length() == 0 ) + return false; + for ( const auto& c : *cases ) { if ( c->ExprCases() && ! c->ExprCases()->IsReduced(r) ) @@ -371,6 +381,10 @@ bool SwitchStmt::IsReduced(Reducer* r) const StmtPtr SwitchStmt::DoReduce(Reducer* rc) { + if ( cases->length() == 0 ) + // Degenerate. + return make_intrusive(); + auto s = make_intrusive(); StmtPtr red_e_stmt; @@ -388,7 +402,8 @@ StmtPtr SwitchStmt::DoReduce(Reducer* rc) for ( auto& i : case_label_type_list ) { IDPtr idp = {NewRef{}, i.first}; - i.first = rc->UpdateID(idp).release(); + if ( idp->Name() ) + i.first = rc->UpdateID(idp).release(); } for ( const auto& c : *cases ) @@ -405,7 +420,11 @@ StmtPtr SwitchStmt::DoReduce(Reducer* rc) auto c_t = c->TypeCases(); if ( c_t ) - rc->UpdateIDs(c_t); + { + for ( auto& c_t_i : *c_t ) + if ( c_t_i->Name() ) + c_t_i = rc->UpdateID({NewRef{}, c_t_i}).release(); + } c->UpdateBody(c->Body()->Reduce(rc)); }