diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 37f6109cd2..27f0834715 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -394,6 +394,7 @@ set(MAIN_SRCS script_opt/CPP/Util.cc script_opt/CPP/Vars.cc ${_gen_zeek_script_cpp} + script_opt/CSE.cc script_opt/Expr.cc script_opt/FuncInfo.cc script_opt/GenIDDefs.cc diff --git a/src/script_opt/CSE.cc b/src/script_opt/CSE.cc new file mode 100644 index 0000000000..56f3a697e1 --- /dev/null +++ b/src/script_opt/CSE.cc @@ -0,0 +1,266 @@ +// See the file "COPYING" in the main distribution directory for copyright. + +#include "zeek/script_opt/CSE.h" + +namespace zeek::detail { + +CSE_ValidityChecker::CSE_ValidityChecker(ProfileFuncs& _pfs, const std::vector& _ids, + const Expr* _start_e, const Expr* _end_e) + : pfs(_pfs), ids(_ids) { + start_e = _start_e; + end_e = _end_e; + + // Track whether this is a record assignment, in which case + // we're attuned to assignments to the same field for the + // same type of record. + if ( start_e->Tag() == EXPR_FIELD ) { + field = start_e->AsFieldExpr()->Field(); + + // Track the type of the record, too, so we don't confuse + // field references to different records that happen to + // have the same offset as potential aliases. + field_type = start_e->GetOp1()->GetType(); + } + + else + field = -1; // flags that there's no relevant field +} + +TraversalCode CSE_ValidityChecker::PreStmt(const Stmt* s) { + auto t = s->Tag(); + + if ( t == STMT_WHEN ) { + // These are too hard to analyze - they result in lambda calls + // that can affect aggregates, etc. + is_valid = false; + return TC_ABORTALL; + } + + if ( t == STMT_ADD || t == STMT_DELETE ) + in_aggr_mod_stmt = true; + + return TC_CONTINUE; +} + +TraversalCode CSE_ValidityChecker::PostStmt(const Stmt* s) { + if ( s->Tag() == STMT_ADD || s->Tag() == STMT_DELETE ) + in_aggr_mod_stmt = false; + + return TC_CONTINUE; +} + +TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) { + if ( e == start_e ) { + ASSERT(! 
have_start_e); + have_start_e = true; + + // Don't analyze the expression, as it's our starting + // point and we don't want to conflate its properties + // with those of any intervening expressions. + return TC_CONTINUE; + } + + if ( e == end_e ) { + if ( ! have_start_e ) + reporter->InternalError("CSE_ValidityChecker: saw end but not start"); + + ASSERT(! have_end_e); + have_end_e = true; + + // ... and we're now done. + return TC_ABORTALL; + } + + if ( ! have_start_e ) + // We don't yet have a starting point. + return TC_CONTINUE; + + // We have a starting point, and not yet an ending point. + auto t = e->Tag(); + + switch ( t ) { + case EXPR_ASSIGN: { + auto lhs_ref = e->GetOp1()->AsRefExprPtr(); + auto lhs = lhs_ref->GetOp1()->AsNameExpr(); + + if ( CheckID(lhs->Id(), false) ) + return TC_ABORTALL; + + // Note, we don't use CheckAggrMod() because this is a plain + // assignment. It might be changing a variable's binding to + // an aggregate ("aggr_var = new_aggr_val"), but we don't + // introduce temporaries that are simply aliases of existing + // variables (e.g., we don't have "::#8 = aggr_var"), + // and so there's no concern that the temporary could now be + // referring to the wrong aggregate. If instead we have + // "::#8 = aggr_var$foo", then a reassignment here + // to "aggr_var" will already be caught by CheckID(). 
+ } break; + + case EXPR_INDEX_ASSIGN: { + auto lhs_aggr = e->GetOp1(); + auto lhs_aggr_id = lhs_aggr->AsNameExpr()->Id(); + + if ( CheckID(lhs_aggr_id, true) || CheckTableMod(lhs_aggr->GetType()) ) + return TC_ABORTALL; + } break; + + case EXPR_FIELD_LHS_ASSIGN: { + auto lhs = e->GetOp1(); + auto lhs_aggr_id = lhs->AsNameExpr()->Id(); + auto lhs_field = e->AsFieldLHSAssignExpr()->Field(); + + if ( CheckID(lhs_aggr_id, true) ) + return TC_ABORTALL; + if ( lhs_field == field && same_type(lhs_aggr_id->GetType(), field_type) ) { + is_valid = false; + return TC_ABORTALL; + } + } break; + + case EXPR_APPEND_TO: + // This doesn't directly change any identifiers, but does + // alter an aggregate. + if ( CheckAggrMod(e->GetType()) ) + return TC_ABORTALL; + break; + + case EXPR_CALL: + if ( CheckCall(e->AsCallExpr()) ) + return TC_ABORTALL; + break; + + case EXPR_TABLE_CONSTRUCTOR: + // These have EXPR_ASSIGN's in them that don't + // correspond to actual assignments to variables, + // so we don't want to traverse them. + return TC_ABORTSTMT; + + case EXPR_RECORD_COERCE: + case EXPR_RECORD_CONSTRUCTOR: + // Note, record coercion behaves like constructors in terms of + // potentially executing &default functions. In either case, + // the type of the expression reflects the type we want to analyze + // for side effects. + if ( CheckRecordConstructor(e->GetType()) ) + return TC_ABORTALL; + break; + + case EXPR_INDEX: + case EXPR_FIELD: { + // We treat these together because they both have to be checked + // when inside an "add" or "delete" statement. 
+ auto aggr = e->GetOp1(); + auto aggr_t = aggr->GetType(); + + if ( in_aggr_mod_stmt ) { + auto aggr_id = aggr->AsNameExpr()->Id(); + + if ( CheckID(aggr_id, true) || CheckAggrMod(aggr_t) ) + return TC_ABORTALL; + } + + else if ( t == EXPR_INDEX && aggr_t->Tag() == TYPE_TABLE ) { + if ( CheckTableRef(aggr_t) ) + return TC_ABORTALL; + } + } break; + + default: break; + } + + return TC_CONTINUE; +} + +bool CSE_ValidityChecker::CheckID(const ID* id, bool ignore_orig) { + for ( auto i : ids ) { + if ( ignore_orig && i == ids.front() ) + continue; + + if ( id == i ) + return Invalid(); // reassignment + } + + return false; +} + +bool CSE_ValidityChecker::CheckAggrMod(const TypePtr& t) { + if ( ! IsAggr(t) ) + return false; + + for ( auto i : ids ) + if ( same_type(t, i->GetType()) ) + return Invalid(); + + return false; +} + +bool CSE_ValidityChecker::CheckRecordConstructor(const TypePtr& t) { + if ( t->Tag() != TYPE_RECORD ) + return false; + + return CheckSideEffects(SideEffectsOp::CONSTRUCTION, t); +} + +bool CSE_ValidityChecker::CheckTableMod(const TypePtr& t) { + if ( CheckAggrMod(t) ) + return true; + + if ( t->Tag() != TYPE_TABLE ) + return false; + + return CheckSideEffects(SideEffectsOp::WRITE, t); +} + +bool CSE_ValidityChecker::CheckTableRef(const TypePtr& t) { return CheckSideEffects(SideEffectsOp::READ, t); } + +bool CSE_ValidityChecker::CheckCall(const CallExpr* c) { + auto func = c->Func(); + std::string desc; + if ( func->Tag() != EXPR_NAME ) + // Can't analyze indirect calls. NOTE(review): the local "desc" declared above is not used anywhere in this function. 
+ return Invalid(); + + IDSet non_local_ids; + TypeSet aggrs; + bool is_unknown = false; + + auto resolved = pfs.GetCallSideEffects(func->AsNameExpr(), non_local_ids, aggrs, is_unknown); + ASSERT(resolved); + + if ( is_unknown || CheckSideEffects(non_local_ids, aggrs) ) + return Invalid(); + + return false; +} + +bool CSE_ValidityChecker::CheckSideEffects(SideEffectsOp::AccessType access, const TypePtr& t) { + IDSet non_local_ids; + TypeSet aggrs; + + if ( pfs.GetSideEffects(access, t.get(), non_local_ids, aggrs) ) + return Invalid(); + + return CheckSideEffects(non_local_ids, aggrs); +} + +bool CSE_ValidityChecker::CheckSideEffects(const IDSet& non_local_ids, const TypeSet& aggrs) { + if ( non_local_ids.empty() && aggrs.empty() ) + // This is far and away the most common case. + return false; + + for ( auto i : ids ) { + for ( auto nli : non_local_ids ) + if ( nli == i ) + return Invalid(); + + auto i_t = i->GetType(); + for ( auto a : aggrs ) + if ( same_type(a, i_t.get()) ) + return Invalid(); + } + + return false; +} + +} // namespace zeek::detail diff --git a/src/script_opt/CSE.h b/src/script_opt/CSE.h new file mode 100644 index 0000000000..40d600b3ae --- /dev/null +++ b/src/script_opt/CSE.h @@ -0,0 +1,116 @@ +// See the file "COPYING" in the main distribution directory for copyright. + +#pragma once + +#include "zeek/script_opt/ProfileFunc.h" + +namespace zeek::detail { + +class TempVar; + +// Helper class that walks an AST to determine whether it's safe to +// substitute a common subexpression (which at this point is an assignment +// to a variable) created using the assignment expression at position "start_e", +// at the location specified by the expression at position "end_e". +// +// See Reducer::ExprValid for a discussion of what's required for safety. 
+ +class CSE_ValidityChecker : public TraversalCallback { +public: + CSE_ValidityChecker(ProfileFuncs& pfs, const std::vector& ids, const Expr* start_e, + const Expr* end_e); + + TraversalCode PreStmt(const Stmt*) override; + TraversalCode PostStmt(const Stmt*) override; + TraversalCode PreExpr(const Expr*) override; + + // Returns the ultimate verdict re safety. + bool IsValid() const { + if ( ! is_valid ) + return false; + + if ( ! have_end_e ) + reporter->InternalError("CSE_ValidityChecker: saw start but not end"); + return true; + } + +protected: + // Returns true if an assignment involving the given identifier on + // the LHS is in conflict with the identifiers we're tracking ("ignore_orig" skips the first of them). + bool CheckID(const ID* id, bool ignore_orig); + + // Returns true if a modification to an aggregate of the given type + // potentially aliases with one of the identifiers we're tracking. + bool CheckAggrMod(const TypePtr& t); + + // Returns true if a record constructor/coercion of the given type has + // side effects and invalidates the CSE opportunity. + bool CheckRecordConstructor(const TypePtr& t); + + // The same for modifications to tables. + bool CheckTableMod(const TypePtr& t); + + // The same for accessing (reading) tables. + bool CheckTableRef(const TypePtr& t); + + // The same for the given function call. + bool CheckCall(const CallExpr* c); + + // True if the given form of access to the given type has side effects. + bool CheckSideEffects(SideEffectsOp::AccessType access, const TypePtr& t); + + // True if side effects to the given identifiers and aggregates invalidate + // the CSE opportunity. + bool CheckSideEffects(const IDSet& non_local_ids, const TypeSet& aggrs); + + // Helper function that marks the CSE opportunity as invalid and returns + // "true" (used by various methods to signal invalidation). + bool Invalid() { + is_valid = false; + return true; + } + + // Profile across all script functions. 
+ ProfileFuncs& pfs; + + // The list of identifiers for which an assignment to one of them + // renders the CSE unsafe. + const std::vector& ids; + + // Where in the AST to start our analysis. This is the initial + // assignment expression. + const Expr* start_e; + + // Where in the AST to end our analysis. + const Expr* end_e; + + // If what we're analyzing is a record element, then its offset. + // -1 if not. + int field; + + // The type of the record containing that element, if any. + TypePtr field_type; + + // The verdict so far. + bool is_valid = true; + + // Whether we've encountered the start/end expression in + // the AST traversal. + bool have_start_e = false; + bool have_end_e = false; + + // Whether analyzed expressions occur in the context of a statement + // that modifies an aggregate ("add" or "delete"), which changes the + // interpretation of the expressions. + bool in_aggr_mod_stmt = false; +}; + +// Used for debugging, to communicate which expression wasn't +// reduced when we expected them all to be. +extern const Expr* non_reduced_perp; +extern bool checking_reduction; + +// Used to report a non-reduced expression. +extern bool NonReduced(const Expr* perp); + +} // namespace zeek::detail diff --git a/src/script_opt/Reduce.cc b/src/script_opt/Reduce.cc index 3a0fb8ba0d..fb1478e70e 100644 --- a/src/script_opt/Reduce.cc +++ b/src/script_opt/Reduce.cc @@ -825,265 +825,6 @@ std::shared_ptr Reducer::FindTemporary(const ID* id) const { return tmp->second; } -CSE_ValidityChecker::CSE_ValidityChecker(ProfileFuncs& _pfs, const std::vector& _ids, const Expr* _start_e, - const Expr* _end_e) - : pfs(_pfs), ids(_ids) { - start_e = _start_e; - end_e = _end_e; - - // Track whether this is a record assignment, in which case - // we're attuned to assignments to the same field for the - // same type of record. 
- if ( start_e->Tag() == EXPR_FIELD ) { - field = start_e->AsFieldExpr()->Field(); - - // Track the type of the record, too, so we don't confuse - // field references to different records that happen to - // have the same offset as potential aliases. - field_type = start_e->GetOp1()->GetType(); - } - - else - field = -1; // flags that there's no relevant field -} - -TraversalCode CSE_ValidityChecker::PreStmt(const Stmt* s) { - auto t = s->Tag(); - - if ( t == STMT_WHEN ) { - // These are too hard to analyze - they result in lambda calls - // that can affect aggregates, etc. - is_valid = false; - return TC_ABORTALL; - } - - if ( t == STMT_ADD || t == STMT_DELETE ) - in_aggr_mod_stmt = true; - - return TC_CONTINUE; -} - -TraversalCode CSE_ValidityChecker::PostStmt(const Stmt* s) { - if ( s->Tag() == STMT_ADD || s->Tag() == STMT_DELETE ) - in_aggr_mod_stmt = false; - - return TC_CONTINUE; -} - -TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) { - if ( e == start_e ) { - ASSERT(! have_start_e); - have_start_e = true; - - // Don't analyze the expression, as it's our starting - // point and we don't want to conflate its properties - // with those of any intervening expressions. - return TC_CONTINUE; - } - - if ( e == end_e ) { - if ( ! have_start_e ) - reporter->InternalError("CSE_ValidityChecker: saw end but not start"); - - ASSERT(! have_end_e); - have_end_e = true; - - // ... and we're now done. - return TC_ABORTALL; - } - - if ( ! have_start_e ) - // We don't yet have a starting point. - return TC_CONTINUE; - - // We have a starting point, and not yet an ending point. - auto t = e->Tag(); - - switch ( t ) { - case EXPR_ASSIGN: { - auto lhs_ref = e->GetOp1()->AsRefExprPtr(); - auto lhs = lhs_ref->GetOp1()->AsNameExpr(); - - if ( CheckID(lhs->Id(), false) ) - return TC_ABORTALL; - - // Note, we don't use CheckAggrMod() because this is a plain - // assignment. 
It might be changing a variable's binding to - // an aggregate ("aggr_var = new_aggr_val"), but we don't - // introduce temporaries that are simply aliases of existing - // variables (e.g., we don't have "::#8 = aggr_var"), - // and so there's no concern that the temporary could now be - // referring to the wrong aggregate. If instead we have - // "::#8 = aggr_var$foo", then a reassignment here - // to "aggr_var" will already be caught by CheckID(). - } break; - - case EXPR_INDEX_ASSIGN: { - auto lhs_aggr = e->GetOp1(); - auto lhs_aggr_id = lhs_aggr->AsNameExpr()->Id(); - - if ( CheckID(lhs_aggr_id, true) || CheckTableMod(lhs_aggr->GetType()) ) - return TC_ABORTALL; - } break; - - case EXPR_FIELD_LHS_ASSIGN: { - auto lhs = e->GetOp1(); - auto lhs_aggr_id = lhs->AsNameExpr()->Id(); - auto lhs_field = e->AsFieldLHSAssignExpr()->Field(); - - if ( CheckID(lhs_aggr_id, true) ) - return TC_ABORTALL; - if ( lhs_field == field && same_type(lhs_aggr_id->GetType(), field_type) ) { - is_valid = false; - return TC_ABORTALL; - } - } break; - - case EXPR_APPEND_TO: - // This doesn't directly change any identifiers, but does - // alter an aggregate. - if ( CheckAggrMod(e->GetType()) ) - return TC_ABORTALL; - break; - - case EXPR_CALL: - if ( CheckCall(e->AsCallExpr()) ) - return TC_ABORTALL; - break; - - case EXPR_TABLE_CONSTRUCTOR: - // These have EXPR_ASSIGN's in them that don't - // correspond to actual assignments to variables, - // so we don't want to traverse them. - return TC_ABORTSTMT; - - case EXPR_RECORD_COERCE: - case EXPR_RECORD_CONSTRUCTOR: - // Note, record coercion behaves like constructors in terms of - // potentially executing &default functions. In either case, - // the type of the expression reflects the type we want to analyze - // for side effects. 
- if ( CheckRecordConstructor(e->GetType()) ) - return TC_ABORTALL; - break; - - case EXPR_INDEX: - case EXPR_FIELD: { - // We treat these together because they both have to be checked - // when inside an "add" or "delete" statement. - auto aggr = e->GetOp1(); - auto aggr_t = aggr->GetType(); - - if ( in_aggr_mod_stmt ) { - auto aggr_id = aggr->AsNameExpr()->Id(); - - if ( CheckID(aggr_id, true) || CheckAggrMod(aggr_t) ) - return TC_ABORTALL; - } - - else if ( t == EXPR_INDEX && aggr_t->Tag() == TYPE_TABLE ) { - if ( CheckTableRef(aggr_t) ) - return TC_ABORTALL; - } - } break; - - default: break; - } - - return TC_CONTINUE; -} - -bool CSE_ValidityChecker::CheckID(const ID* id, bool ignore_orig) { - for ( auto i : ids ) { - if ( ignore_orig && i == ids.front() ) - continue; - - if ( id == i ) - return Invalid(); // reassignment - } - - return false; -} - -bool CSE_ValidityChecker::CheckAggrMod(const TypePtr& t) { - if ( ! IsAggr(t) ) - return false; - - for ( auto i : ids ) - if ( same_type(t, i->GetType()) ) - return Invalid(); - - return false; -} - -bool CSE_ValidityChecker::CheckRecordConstructor(const TypePtr& t) { - if ( t->Tag() != TYPE_RECORD ) - return false; - - return CheckSideEffects(SideEffectsOp::CONSTRUCTION, t); -} - -bool CSE_ValidityChecker::CheckTableMod(const TypePtr& t) { - if ( CheckAggrMod(t) ) - return true; - - if ( t->Tag() != TYPE_TABLE ) - return false; - - return CheckSideEffects(SideEffectsOp::WRITE, t); -} - -bool CSE_ValidityChecker::CheckTableRef(const TypePtr& t) { return CheckSideEffects(SideEffectsOp::READ, t); } - -bool CSE_ValidityChecker::CheckCall(const CallExpr* c) { - auto func = c->Func(); - std::string desc; - if ( func->Tag() != EXPR_NAME ) - // Can't analyze indirect calls. 
- return Invalid(); - - IDSet non_local_ids; - TypeSet aggrs; - bool is_unknown = false; - - auto resolved = pfs.GetCallSideEffects(func->AsNameExpr(), non_local_ids, aggrs, is_unknown); - ASSERT(resolved); - - if ( is_unknown || CheckSideEffects(non_local_ids, aggrs) ) - return Invalid(); - - return false; -} - -bool CSE_ValidityChecker::CheckSideEffects(SideEffectsOp::AccessType access, const TypePtr& t) { - IDSet non_local_ids; - TypeSet aggrs; - - if ( pfs.GetSideEffects(access, t.get(), non_local_ids, aggrs) ) - return Invalid(); - - return CheckSideEffects(non_local_ids, aggrs); -} - -bool CSE_ValidityChecker::CheckSideEffects(const IDSet& non_local_ids, const TypeSet& aggrs) { - if ( non_local_ids.empty() && aggrs.empty() ) - // This is far and away the most common case. - return false; - - for ( auto i : ids ) { - for ( auto nli : non_local_ids ) - if ( nli == i ) - return Invalid(); - - auto i_t = i->GetType(); - for ( auto a : aggrs ) - if ( same_type(a, i_t.get()) ) - return Invalid(); - } - - return false; -} - const Expr* non_reduced_perp; bool checking_reduction; diff --git a/src/script_opt/Reduce.h b/src/script_opt/Reduce.h index b4f4cb4743..706ca1f7e8 100644 --- a/src/script_opt/Reduce.h +++ b/src/script_opt/Reduce.h @@ -6,6 +6,7 @@ #include "zeek/Scope.h" #include "zeek/Stmt.h" #include "zeek/Traverse.h" +#include "zeek/script_opt/CSE.h" #include "zeek/script_opt/ObjMgr.h" #include "zeek/script_opt/ProfileFunc.h" @@ -314,104 +315,6 @@ protected: bool opt_ready = false; }; -// Helper class that walks an AST to determine whether it's safe -// to substitute a common subexpression (which at this point is -// an assignment to a variable) created using the assignment -// expression at position "start_e", at the location specified by -// the expression at position "end_e". -// -// See Reducer::ExprValid for a discussion of what's required -// for safety. 
- -class CSE_ValidityChecker : public TraversalCallback { -public: - CSE_ValidityChecker(ProfileFuncs& pfs, const std::vector& ids, const Expr* start_e, const Expr* end_e); - - TraversalCode PreStmt(const Stmt*) override; - TraversalCode PostStmt(const Stmt*) override; - TraversalCode PreExpr(const Expr*) override; - - // Returns the ultimate verdict re safety. - bool IsValid() const { - if ( ! is_valid ) - return false; - - if ( ! have_end_e ) - reporter->InternalError("CSE_ValidityChecker: saw start but not end"); - return true; - } - -protected: - // Returns true if an assignment involving the given identifier on - // the LHS is in conflict with the identifiers we're tracking. - bool CheckID(const ID* id, bool ignore_orig); - - // Returns true if a modification to an aggregate of the given type - // potentially aliases with one of the identifiers we're tracking. - bool CheckAggrMod(const TypePtr& t); - - // Returns true if a record constructor/coercion of the given type has - // side effects and invalides the CSE opportunity. - bool CheckRecordConstructor(const TypePtr& t); - - // The same for modifications to tables. - bool CheckTableMod(const TypePtr& t); - - // The same for accessing (reading) tables. - bool CheckTableRef(const TypePtr& t); - - // The same for the given function call. - bool CheckCall(const CallExpr* c); - - // True if the given form of access to the given type has side effects. - bool CheckSideEffects(SideEffectsOp::AccessType access, const TypePtr& t); - - // True if side effects to the given identifiers and aggregates invalidate - // the CSE opportunity. - bool CheckSideEffects(const IDSet& non_local_ids, const TypeSet& aggrs); - - // Helper function that marks the CSE opportunity as invalid and returns - // "true" (used by various methods to signal invalidation). - bool Invalid() { - is_valid = false; - return true; - } - - // Profile across all script functions. 
- ProfileFuncs& pfs; - - // The list of identifiers for which an assignment to one of them - // renders the CSE unsafe. - const std::vector& ids; - - // Where in the AST to start our analysis. This is the initial - // assignment expression. - const Expr* start_e; - - // Where in the AST to end our analysis. - const Expr* end_e; - - // If what we're analyzing is a record element, then its offset. - // -1 if not. - int field; - - // The type of that record element, if any. - TypePtr field_type; - - // The verdict so far. - bool is_valid = true; - - // Whether we've encountered the start/end expression in - // the AST traversal. - bool have_start_e = false; - bool have_end_e = false; - - // Whether analyzed expressions occur in the context of a statement - // that modifies an aggregate ("add" or "delete"), which changes the - // interpretation of the expressions. - bool in_aggr_mod_stmt = false; -}; - // Used for debugging, to communicate which expression wasn't // reduced when we expected them all to be. extern const Expr* non_reduced_perp;