From 34ccd3e417c58159ea7e0f6a36f4899f1f5f8583 Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Sat, 27 Feb 2021 11:35:01 -0800 Subject: [PATCH] helper class checking if common-subexpression elimination opportunity is valid --- src/script_opt/Reduce.cc | 199 +++++++++++++++++++++++++++++++++++++++ src/script_opt/Reduce.h | 74 +++++++++++++++ 2 files changed, 273 insertions(+) diff --git a/src/script_opt/Reduce.cc b/src/script_opt/Reduce.cc index 3346f84197..6ff3e9acad 100644 --- a/src/script_opt/Reduce.cc +++ b/src/script_opt/Reduce.cc @@ -253,6 +253,205 @@ void Reducer::TrackExprReplacement(const Expr* orig, const Expr* e) } + +CSE_ValidityChecker::CSE_ValidityChecker(const std::vector<const ID*>& _ids, + const Expr* _start_e, const Expr* _end_e) +: ids(_ids) + { + start_e = _start_e; + end_e = _end_e; + + // Track whether this is a record assignment, in which case + // we're attuned to assignments to the same field for the + // same type of record. + if ( start_e->Tag() == EXPR_FIELD ) + { + field = start_e->AsFieldExpr()->Field(); + + // Track the type of the record, too, so we don't confuse + // field references to different records that happen to + // have the same offset as potential aliases. + field_type = start_e->GetOp1()->GetType(); + } + + else + field = -1; // flags that there's no relevant field + } + +TraversalCode CSE_ValidityChecker::PreStmt(const Stmt* s) + { + if ( s->Tag() == STMT_ADD || s->Tag() == STMT_DELETE ) + in_aggr_mod_stmt = true; + + return TC_CONTINUE; + } + +TraversalCode CSE_ValidityChecker::PostStmt(const Stmt* s) + { + if ( s->Tag() == STMT_ADD || s->Tag() == STMT_DELETE ) + in_aggr_mod_stmt = false; + + return TC_CONTINUE; + } + +TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) + { + if ( e == start_e ) + { + ASSERT(! have_start_e); + have_start_e = true; + + // Don't analyze the expression, as it's our starting + // point and we don't want to conflate its properties + // with those of any intervening expression.
+ return TC_CONTINUE; + } + + if ( e == end_e ) + { + if ( ! have_start_e ) + reporter->InternalError("CSE_ValidityChecker: saw end but not start"); + + ASSERT(! have_end_e); + have_end_e = true; + + // ... and we're now done. + return TC_ABORTALL; + } + + if ( ! have_start_e ) + // We don't yet have a starting point. + return TC_CONTINUE; + + // We have a starting point, and not yet an ending point. + auto t = e->Tag(); + + switch ( t ) { + case EXPR_ASSIGN: + { + auto lhs_ref = e->GetOp1()->AsRefExprPtr(); + auto lhs = lhs_ref->GetOp1()->AsNameExpr(); + + if ( CheckID(ids, lhs->Id(), false) ) + { + is_valid = false; + return TC_ABORTALL; + } + + // Note, we don't use CheckAggrMod() because this + // is a plain assignment. It might be changing a variable's + // binding to an aggregate, but it's not changing the + // aggregate itself. + } + break; + + case EXPR_INDEX_ASSIGN: + { + auto lhs_aggr = e->GetOp1(); + auto lhs_aggr_id = lhs_aggr->AsNameExpr()->Id(); + + if ( CheckID(ids, lhs_aggr_id, true) || CheckAggrMod(ids, e) ) + { + is_valid = false; + return TC_ABORTALL; + } + } + break; + + case EXPR_FIELD_LHS_ASSIGN: + { + auto lhs = e->GetOp1(); + auto lhs_aggr_id = lhs->AsNameExpr()->Id(); + auto lhs_field = e->AsFieldLHSAssignExpr()->Field(); + + if ( lhs_field == field && + same_type(lhs_aggr_id->GetType(), field_type) ) + { + // Potential assignment to the same field as for + // our expression of interest. Even if the + // identifier involved is not one we have our eye + // on, due to aggregate aliasing this could be + // altering the value of our expression, so bail.
+ is_valid = false; + return TC_ABORTALL; + } + + if ( CheckID(ids, lhs_aggr_id, true) || CheckAggrMod(ids, e) ) + { + is_valid = false; + return TC_ABORTALL; + } + } + break; + + case EXPR_CALL: + { + for ( auto i : ids ) + if ( i->IsGlobal() || IsAggr(i->GetType()) ) + { + is_valid = false; + return TC_ABORTALL; + } + } + break; + + default: + if ( in_aggr_mod_stmt && (t == EXPR_INDEX || t == EXPR_FIELD) ) + { + auto aggr = e->GetOp1(); + auto aggr_id = aggr->AsNameExpr()->Id(); + + if ( CheckID(ids, aggr_id, true) ) + { + is_valid = false; + return TC_ABORTALL; + } + } + + break; + } + + return TC_CONTINUE; + } + +bool CSE_ValidityChecker::CheckID(const std::vector<const ID*>& ids, + const ID* id, bool ignore_orig) const + { + // Only check type info for aggregates. + auto id_t = IsAggr(id->GetType()) ? id->GetType() : nullptr; + + for ( auto i : ids ) + { + if ( ignore_orig && i == ids.front() ) + continue; + + if ( id == i ) + return true; // reassignment + + if ( id_t && same_type(id_t, i->GetType()) ) + // Same-type aggregate. + return true; + } + + return false; + } + +bool CSE_ValidityChecker::CheckAggrMod(const std::vector<const ID*>& ids, + const Expr* e) const + { + auto e_i_t = e->GetType(); + if ( IsAggr(e_i_t) ) + { + // This assignment sets an aggregate value. + // Look for type matches.
+ for ( auto i : ids ) + if ( same_type(e_i_t, i->GetType()) ) + return true; + } + + return false; + } + + bool same_DPs(const DefPoints* dp1, const DefPoints* dp2) { if ( dp1 == dp2 ) diff --git a/src/script_opt/Reduce.h b/src/script_opt/Reduce.h index 3adab01f34..da4f0c318d 100644 --- a/src/script_opt/Reduce.h +++ b/src/script_opt/Reduce.h @@ -5,6 +5,7 @@ #include "zeek/Scope.h" #include "zeek/Expr.h" #include "zeek/Stmt.h" +#include "zeek/Traverse.h" #include "zeek/script_opt/DefSetsMgr.h" namespace zeek::detail { @@ -214,6 +215,79 @@ protected: const DefSetsMgr* mgr = nullptr; }; + +// Helper class that walks an AST to determine whether it's safe +// to substitute a common subexpression (which at this point is +// an assignment to a variable) created using the assignment +// expression at position "start_e", at the location specified by +// the expression at position "end_e". +// +// See Reducer::ExprValid for a discussion of what's required +// for safety. + +class CSE_ValidityChecker : public TraversalCallback { +public: + CSE_ValidityChecker(const std::vector<const ID*>& ids, + const Expr* start_e, const Expr* end_e); + + TraversalCode PreStmt(const Stmt*) override; + TraversalCode PostStmt(const Stmt*) override; + TraversalCode PreExpr(const Expr*) override; + + // Returns the ultimate verdict re safety. + bool IsValid() const + { + if ( ! is_valid ) + return false; + + if ( ! have_end_e ) + reporter->InternalError("CSE_ValidityChecker: saw start but not end"); + return true; + } + +protected: + // Returns true if an assignment involving the given identifier on + // the LHS is in conflict with the given list of identifiers. + bool CheckID(const std::vector<const ID*>& ids, const ID* id, + bool ignore_orig) const; + + // Returns true if the assignment given by 'e' modifies an aggregate + // with the same type as that of one of the identifiers.
+ bool CheckAggrMod(const std::vector<const ID*>& ids, + const Expr* e) const; + + // The list of identifiers for which an assignment to one of them + // renders the CSE unsafe. + const std::vector<const ID*>& ids; + + // Where in the AST to start our analysis. This is the initial + // assignment expression. + const Expr* start_e; + + // Where in the AST to end our analysis. + const Expr* end_e; + + // If what we're analyzing is a record element, then its offset. + // -1 if not. + int field; + + // The type of the record containing that element, if any. + TypePtr field_type; + + // The verdict so far. + bool is_valid = true; + + // Whether we've encountered the start/end expression in + // the AST traversal. + bool have_start_e = false; + bool have_end_e = false; + + // Whether analyzed expressions occur in the context of + // a statement that modifies an aggregate ("add" or "delete"). + bool in_aggr_mod_stmt = false; +}; + + extern bool same_DPs(const DefPoints* dp1, const DefPoints* dp2); // Used for debugging, to communicate which expression wasn't