reworked AST optimizers analysis of side effects during aggregate operations & calls

This commit is contained in:
Vern Paxson 2023-12-04 16:55:38 -08:00
parent c028901146
commit 740a087765
13 changed files with 1119 additions and 223 deletions

View file

@ -11,12 +11,14 @@
#include "zeek/Stmt.h"
#include "zeek/Var.h"
#include "zeek/script_opt/ExprOptInfo.h"
#include "zeek/script_opt/FuncInfo.h"
#include "zeek/script_opt/StmtOptInfo.h"
#include "zeek/script_opt/TempVar.h"
namespace zeek::detail {
Reducer::Reducer(const ScriptFunc* func, std::shared_ptr<ProfileFunc> _pf) : pf(std::move(_pf)) {
Reducer::Reducer(const ScriptFunc* func, std::shared_ptr<ProfileFunc> _pf, ProfileFuncs& _pfs)
: pf(std::move(_pf)), pfs(_pfs) {
auto& ft = func->GetType();
// Track the parameters so we don't remap them.
@ -424,6 +426,41 @@ IDPtr Reducer::FindExprTmp(const Expr* rhs, const Expr* a, const std::shared_ptr
}
bool Reducer::ExprValid(const ID* id, const Expr* e1, const Expr* e2) const {
// First check for whether e1 is already known to itself have side effects.
// If so, then it's never safe to reuse its associated identifier in lieu
// of e2.
std::optional<ExprSideEffects>& e1_se = e1->GetOptInfo()->SideEffects();
if ( ! e1_se ) {
bool has_side_effects = false;
auto e1_t = e1->GetType();
if ( e1_t->Tag() == TYPE_OPAQUE || e1_t->Tag() == TYPE_ANY )
// These have difficult-to-analyze semantics.
has_side_effects = true;
else if ( e1->Tag() == EXPR_INDEX ) {
auto aggr = e1->GetOp1();
auto aggr_t = aggr->GetType();
if ( pfs.HasSideEffects(SideEffectsOp::READ, aggr_t) )
has_side_effects = true;
else if ( aggr_t->Tag() == TYPE_TABLE && pfs.IsTableWithDefaultAggr(aggr_t.get()) )
has_side_effects = true;
}
else if ( e1->Tag() == EXPR_RECORD_CONSTRUCTOR || e1->Tag() == EXPR_RECORD_COERCE )
has_side_effects = pfs.HasSideEffects(SideEffectsOp::CONSTRUCTION, e1->GetType());
e1_se = ExprSideEffects(has_side_effects);
}
if ( e1_se->HasSideEffects() ) {
// We already know that e2 is structurally identical to e1.
e2->GetOptInfo()->SideEffects() = ExprSideEffects(true);
return false;
}
// Here are the considerations for expression validity.
//
// * None of the operands used in the given expression can
@ -437,11 +474,14 @@ bool Reducer::ExprValid(const ID* id, const Expr* e1, const Expr* e2) const {
// * Same goes to modifications of aggregates via "add" or "delete"
// or "+=" append.
//
// * No propagation of expressions that are based on aggregates
// across function calls.
// * Assessment of any record constructors or coercions, or
// table references or modifications, for possible invocation of
// associated handlers that have side effects.
//
// * No propagation of expressions that are based on globals
// across calls.
// * Assessment of function calls for potential side effects.
//
// These latter two are guided by the global profile of the full set
// of script functions.
// Tracks which ID's are germane for our analysis.
std::vector<const ID*> ids;
@ -456,7 +496,7 @@ bool Reducer::ExprValid(const ID* id, const Expr* e1, const Expr* e2) const {
if ( e1->Tag() == EXPR_NAME )
ids.push_back(e1->AsNameExpr()->Id());
CSE_ValidityChecker vc(ids, e1, e2);
CSE_ValidityChecker vc(pfs, ids, e1, e2);
reduction_root->Traverse(&vc);
return vc.IsValid();
@ -785,15 +825,12 @@ std::shared_ptr<TempVar> Reducer::FindTemporary(const ID* id) const {
return tmp->second;
}
CSE_ValidityChecker::CSE_ValidityChecker(const std::vector<const ID*>& _ids, const Expr* _start_e, const Expr* _end_e)
: ids(_ids) {
CSE_ValidityChecker::CSE_ValidityChecker(ProfileFuncs& _pfs, const std::vector<const ID*>& _ids, const Expr* _start_e,
const Expr* _end_e)
: pfs(_pfs), ids(_ids) {
start_e = _start_e;
end_e = _end_e;
for ( auto i : ids )
if ( i->IsGlobal() || IsAggr(i->GetType()) )
sensitive_to_calls = true;
// Track whether this is a record assignment, in which case
// we're attuned to assignments to the same field for the
// same type of record.
@ -811,7 +848,16 @@ CSE_ValidityChecker::CSE_ValidityChecker(const std::vector<const ID*>& _ids, con
}
TraversalCode CSE_ValidityChecker::PreStmt(const Stmt* s) {
if ( s->Tag() == STMT_ADD || s->Tag() == STMT_DELETE )
auto t = s->Tag();
if ( t == STMT_WHEN ) {
// These are too hard to analyze - they result in lambda calls
// that can affect aggregates, etc.
is_valid = false;
return TC_ABORTALL;
}
if ( t == STMT_ADD || t == STMT_DELETE )
in_aggr_mod_stmt = true;
return TC_CONTINUE;
@ -831,7 +877,7 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
// Don't analyze the expression, as it's our starting
// point and we don't want to conflate its properties
// with those of any intervening expression.
// with those of any intervening expressions.
return TC_CONTINUE;
}
@ -858,25 +904,26 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
auto lhs_ref = e->GetOp1()->AsRefExprPtr();
auto lhs = lhs_ref->GetOp1()->AsNameExpr();
if ( CheckID(ids, lhs->Id(), false) ) {
is_valid = false;
if ( CheckID(lhs->Id(), false) )
return TC_ABORTALL;
}
// Note, we don't use CheckAggrMod() because this
// is a plain assignment. It might be changing a variable's
// binding to an aggregate, but it's not changing the
// aggregate itself.
// Note, we don't use CheckAggrMod() because this is a plain
// assignment. It might be changing a variable's binding to
// an aggregate ("aggr_var = new_aggr_val"), but we don't
// introduce temporaries that are simply aliases of existing
// variables (e.g., we don't have "<internal>::#8 = aggr_var"),
// and so there's no concern that the temporary could now be
// referring to the wrong aggregate. If instead we have
// "<internal>::#8 = aggr_var$foo", then a reassignment here
// to "aggr_var" will already be caught by CheckID().
} break;
case EXPR_INDEX_ASSIGN: {
auto lhs_aggr = e->GetOp1();
auto lhs_aggr_id = lhs_aggr->AsNameExpr()->Id();
if ( CheckID(ids, lhs_aggr_id, true) || CheckAggrMod(ids, e) ) {
is_valid = false;
if ( CheckID(lhs_aggr_id, true) || CheckTableMod(lhs_aggr->GetType()) )
return TC_ABORTALL;
}
} break;
case EXPR_FIELD_LHS_ASSIGN: {
@ -884,17 +931,9 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
auto lhs_aggr_id = lhs->AsNameExpr()->Id();
auto lhs_field = e->AsFieldLHSAssignExpr()->Field();
if ( lhs_field == field && same_type(lhs_aggr_id->GetType(), field_type) ) {
// Potential assignment to the same field as for
// our expression of interest. Even if the
// identifier involved is not one we have our eye
// on, due to aggregate aliasing this could be
// altering the value of our expression, so bail.
is_valid = false;
if ( CheckID(lhs_aggr_id, true) )
return TC_ABORTALL;
}
if ( CheckID(ids, lhs_aggr_id, true) || CheckAggrMod(ids, e) ) {
if ( lhs_field == field && same_type(lhs_aggr_id->GetType(), field_type) ) {
is_valid = false;
return TC_ABORTALL;
}
@ -903,17 +942,13 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
case EXPR_APPEND_TO:
// This doesn't directly change any identifiers, but does
// alter an aggregate.
if ( CheckAggrMod(ids, e) ) {
is_valid = false;
if ( CheckAggrMod(e->GetType()) )
return TC_ABORTALL;
}
break;
case EXPR_CALL:
if ( sensitive_to_calls ) {
is_valid = false;
if ( CheckCall(e->AsCallExpr()) )
return TC_ABORTALL;
}
break;
case EXPR_TABLE_CONSTRUCTOR:
@ -922,49 +957,35 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
// so we don't want to traverse them.
return TC_ABORTSTMT;
case EXPR_RECORD_COERCE:
case EXPR_RECORD_CONSTRUCTOR:
// If these have initializations done at construction
// time, those can include function calls.
if ( sensitive_to_calls ) {
auto& et = e->GetType();
if ( et->Tag() == TYPE_RECORD && ! et->AsRecordType()->IdempotentCreation() ) {
is_valid = false;
return TC_ABORTALL;
}
}
// Note, record coercion behaves like constructors in terms of
// potentially executing &default functions. In either case,
// the type of the expression reflects the type we want to analyze
// for side effects.
if ( CheckRecordConstructor(e->GetType()) )
return TC_ABORTALL;
break;
case EXPR_INDEX:
case EXPR_FIELD:
// We treat these together because they both have
// to be checked when inside an "add" or "delete"
// statement.
case EXPR_FIELD: {
// We treat these together because they both have to be checked
// when inside an "add" or "delete" statement.
auto aggr = e->GetOp1();
auto aggr_t = aggr->GetType();
if ( in_aggr_mod_stmt ) {
auto aggr = e->GetOp1();
auto aggr_id = aggr->AsNameExpr()->Id();
if ( CheckID(ids, aggr_id, true) ) {
is_valid = false;
if ( CheckID(aggr_id, true) || CheckAggrMod(aggr_t) )
return TC_ABORTALL;
}
}
if ( t == EXPR_INDEX && sensitive_to_calls ) {
// Unfortunately in isolation we can't
// statically determine whether this table
// has a &default associated with it. In
// principle we could track all instances
// of the table type seen (across the
// entire set of scripts), and note whether
// any of those include an expression, but
// that's a lot of work for what might be
// minimal gain.
is_valid = false;
return TC_ABORTALL;
else if ( t == EXPR_INDEX && aggr_t->Tag() == TYPE_TABLE ) {
if ( CheckTableRef(aggr_t) )
return TC_ABORTALL;
}
break;
} break;
default: break;
}
@ -972,33 +993,92 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
return TC_CONTINUE;
}
bool CSE_ValidityChecker::CheckID(const std::vector<const ID*>& ids, const ID* id, bool ignore_orig) const {
// Only check type info for aggregates.
auto id_t = IsAggr(id->GetType()) ? id->GetType() : nullptr;
bool CSE_ValidityChecker::CheckID(const ID* id, bool ignore_orig) {
for ( auto i : ids ) {
if ( ignore_orig && i == ids.front() )
continue;
if ( id == i )
return true; // reassignment
if ( id_t && same_type(id_t, i->GetType()) )
// Same-type aggregate.
return true;
return Invalid(); // reassignment
}
return false;
}
bool CSE_ValidityChecker::CheckAggrMod(const std::vector<const ID*>& ids, const Expr* e) const {
const auto& e_i_t = e->GetType();
if ( IsAggr(e_i_t) ) {
// This assignment sets an aggregate value.
// Look for type matches.
for ( auto i : ids )
if ( same_type(e_i_t, i->GetType()) )
return true;
bool CSE_ValidityChecker::CheckAggrMod(const TypePtr& t) {
if ( ! IsAggr(t) )
return false;
for ( auto i : ids )
if ( same_type(t, i->GetType()) )
return Invalid();
return false;
}
bool CSE_ValidityChecker::CheckRecordConstructor(const TypePtr& t) {
if ( t->Tag() != TYPE_RECORD )
return false;
return CheckSideEffects(SideEffectsOp::CONSTRUCTION, t);
}
bool CSE_ValidityChecker::CheckTableMod(const TypePtr& t) {
if ( CheckAggrMod(t) )
return true;
if ( t->Tag() != TYPE_TABLE )
return false;
return CheckSideEffects(SideEffectsOp::WRITE, t);
}
bool CSE_ValidityChecker::CheckTableRef(const TypePtr& t) { return CheckSideEffects(SideEffectsOp::READ, t); }
bool CSE_ValidityChecker::CheckCall(const CallExpr* c) {
auto func = c->Func();
std::string desc;
if ( func->Tag() != EXPR_NAME )
// Can't analyze indirect calls.
return Invalid();
IDSet non_local_ids;
TypeSet aggrs;
bool is_unknown = false;
auto resolved = pfs.GetCallSideEffects(func->AsNameExpr(), non_local_ids, aggrs, is_unknown);
ASSERT(resolved);
if ( is_unknown || CheckSideEffects(non_local_ids, aggrs) )
return Invalid();
return false;
}
bool CSE_ValidityChecker::CheckSideEffects(SideEffectsOp::AccessType access, const TypePtr& t) {
IDSet non_local_ids;
TypeSet aggrs;
if ( pfs.GetSideEffects(access, t.get(), non_local_ids, aggrs) )
return Invalid();
return CheckSideEffects(non_local_ids, aggrs);
}
bool CSE_ValidityChecker::CheckSideEffects(const IDSet& non_local_ids, const TypeSet& aggrs) {
if ( non_local_ids.empty() && aggrs.empty() )
// This is far and away the most common case.
return false;
for ( auto i : ids ) {
for ( auto nli : non_local_ids )
if ( nli == i )
return Invalid();
auto i_t = i->GetType();
for ( auto a : aggrs )
if ( same_type(a, i_t.get()) )
return Invalid();
}
return false;