// See the file "COPYING" in the main distribution directory for copyright. #include "zeek/script_opt/CSE.h" #include "zeek/script_opt/Expr.h" namespace zeek::detail { CSE_ValidityChecker::CSE_ValidityChecker(std::shared_ptr _pfs, const std::vector& _ids, const Expr* _start_e, const Expr* _end_e) : pfs(std::move(_pfs)), ids(_ids) { start_e = _start_e; end_e = _end_e; // For validity checking, if end_e is inside a loop and start_e is // outside that loop, then we need to extend the checking beyond end_e // to the end of the loop, to account for correctness after iterating // through the loop. We do that as follows. Upon entering an outer // loop, we set end_s to that loop. (We can tell it's an outer loop if, // upon entering, end_s is nil.) (1) If we encounter end_e while inside // that loop (which we can tell because end_s is non-nil), then we clear // end_e to signal that we're now using end_s to terminate the traversal. // (2) If we complete the loop without encountering end_e (which we can // tell because after traversal end_e is non-nil), then we clear end_s // to mark that the traversal is now not inside a loop. end_s = nullptr; // Track whether this is a record assignment, in which case // we're attuned to assignments to the same field for the // same type of record. if ( start_e->Tag() == EXPR_FIELD ) { field = start_e->AsFieldExpr()->Field(); // Track the type of the record, too, so we don't confuse // field references to different records that happen to // have the same offset as potential aliases. field_type = start_e->GetOp1()->GetType(); } else field = -1; // flags that there's no relevant field } TraversalCode CSE_ValidityChecker::PreStmt(const Stmt* s) { auto t = s->Tag(); if ( t == STMT_WHEN ) { // These are too hard to analyze - they result in lambda calls // that can affect aggregates, etc. is_valid = false; return TC_ABORTALL; } if ( (t == STMT_WHILE || t == STMT_FOR) && have_start_e && ! end_s ) // We've started the traversal and are entering an outer loop. end_s = s; return TC_CONTINUE; } TraversalCode CSE_ValidityChecker::PostStmt(const Stmt* s) { if ( end_s == s ) { if ( ! end_e ) // We've done the outer loop containing the end expression. return TC_ABORTALL; // We're no longer doing an outer loop. end_s = nullptr; } return TC_CONTINUE; } TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) { if ( e == start_e ) { if ( ! have_start_e ) { have_start_e = true; // Don't analyze the expression, as it's our starting // point and we don't want to conflate its properties // with those of any intervening expressions. return TC_CONTINUE; } } if ( e == end_e ) { if ( ! have_start_e ) reporter->InternalError("CSE_ValidityChecker: saw end but not start"); ASSERT(! have_end_e); have_end_e = true; if ( ! end_s ) // We're now done. return TC_ABORTALL; // Need to finish the loop before we mark things as done. // Signal to the statement traversal that we're in that state. end_e = nullptr; } if ( ! have_start_e ) // We don't yet have a starting point. return TC_CONTINUE; // We have a starting point, and not yet an ending point. auto t = e->Tag(); switch ( t ) { case EXPR_ASSIGN: { auto lhs_ref = e->GetOp1()->AsRefExprPtr(); auto lhs = lhs_ref->GetOp1()->AsNameExpr(); if ( CheckID(lhs->Id(), false) ) return TC_ABORTALL; // Note, we don't use CheckAggrMod() because this is a plain // assignment. It might be changing a variable's binding to // an aggregate ("aggr_var = new_aggr_val"), but we don't // introduce temporaries that are simply aliases of existing // variables (e.g., we don't have "::#8 = aggr_var"), // and so there's no concern that the temporary could now be // referring to the wrong aggregate. If instead we have // "::#8 = aggr_var$foo", then a reassignment here // to "aggr_var" will already be caught by CheckID(). } break; case EXPR_INDEX_ASSIGN: { auto lhs_aggr = e->GetOp1(); auto lhs_aggr_id = lhs_aggr->AsNameExpr()->Id(); if ( CheckID(lhs_aggr_id, true) || CheckTableMod(lhs_aggr->GetType()) ) return TC_ABORTALL; } break; case EXPR_FIELD_LHS_ASSIGN: { auto lhs = e->GetOp1(); auto lhs_aggr_id = lhs->AsNameExpr()->Id(); auto lhs_field = static_cast(e)->Field(); if ( CheckID(lhs_aggr_id, true) ) return TC_ABORTALL; if ( lhs_field == field && same_type(lhs_aggr_id->GetType(), field_type) ) { is_valid = false; return TC_ABORTALL; } } break; case EXPR_AGGR_ADD: case EXPR_AGGR_DEL: ++in_aggr_mod_expr; break; case EXPR_APPEND_TO: // This doesn't directly change any identifiers, but does // alter an aggregate. if ( CheckAggrMod(e->GetType()) ) return TC_ABORTALL; break; case EXPR_CALL: if ( CheckCall(e->AsCallExpr()) ) return TC_ABORTALL; break; case EXPR_TABLE_CONSTRUCTOR: // These have EXPR_ASSIGN's in them that don't // correspond to actual assignments to variables, // so we don't want to traverse them. return TC_ABORTSTMT; case EXPR_RECORD_COERCE: case EXPR_RECORD_CONSTRUCTOR: case EXPR_REC_CONSTRUCT_WITH_REC: // Note, record coercion behaves like constructors in terms of // potentially executing &default functions. In either case, // the type of the expression reflects the type we want to analyze // for side effects. if ( CheckRecordConstructor(e->GetType()) ) return TC_ABORTALL; break; case EXPR_INDEX: case EXPR_FIELD: { // We treat these together because they both have to be checked // when inside an "add" or "delete" statement. auto aggr = e->GetOp1(); auto aggr_t = aggr->GetType(); if ( in_aggr_mod_expr > 0 ) { auto aggr_id = aggr->AsNameExpr()->Id(); if ( CheckID(aggr_id, true) || CheckAggrMod(aggr_t) ) return TC_ABORTALL; } else if ( t == EXPR_INDEX && aggr_t->Tag() == TYPE_TABLE ) { if ( CheckTableRef(aggr_t) ) return TC_ABORTALL; } } break; default: break; } return TC_CONTINUE; } TraversalCode CSE_ValidityChecker::PostExpr(const Expr* e) { if ( have_start_e && (e->Tag() == EXPR_AGGR_ADD || e->Tag() == EXPR_AGGR_DEL) ) --in_aggr_mod_expr; return TC_CONTINUE; } bool CSE_ValidityChecker::CheckID(const ID* id, bool ignore_orig) { for ( auto i : ids ) { if ( ignore_orig && i == ids.front() ) continue; if ( id == i ) return Invalid(); // reassignment } return false; } bool CSE_ValidityChecker::CheckAggrMod(const TypePtr& t) { if ( ! IsAggr(t) ) return false; for ( auto i : ids ) if ( same_type(t, i->GetType()) ) return Invalid(); return false; } bool CSE_ValidityChecker::CheckRecordConstructor(const TypePtr& t) { if ( t->Tag() != TYPE_RECORD ) return false; return CheckSideEffects(SideEffectsOp::CONSTRUCTION, t); } bool CSE_ValidityChecker::CheckTableMod(const TypePtr& t) { if ( CheckAggrMod(t) ) return true; if ( t->Tag() != TYPE_TABLE ) return false; return CheckSideEffects(SideEffectsOp::WRITE, t); } bool CSE_ValidityChecker::CheckTableRef(const TypePtr& t) { return CheckSideEffects(SideEffectsOp::READ, t); } bool CSE_ValidityChecker::CheckCall(const CallExpr* c) { auto func = c->Func(); if ( func->Tag() != EXPR_NAME ) // Can't analyze indirect calls. return Invalid(); IDSet non_local_ids; TypeSet aggrs; bool is_unknown = false; auto resolved = pfs->GetCallSideEffects(func->AsNameExpr(), non_local_ids, aggrs, is_unknown); ASSERT(resolved); if ( is_unknown || CheckSideEffects(non_local_ids, aggrs) ) return Invalid(); return false; } bool CSE_ValidityChecker::CheckSideEffects(SideEffectsOp::AccessType access, const TypePtr& t) { IDSet non_local_ids; TypeSet aggrs; if ( pfs->GetSideEffects(access, t.get(), non_local_ids, aggrs) ) return Invalid(); return CheckSideEffects(non_local_ids, aggrs); } bool CSE_ValidityChecker::CheckSideEffects(const IDSet& non_local_ids, const TypeSet& aggrs) { if ( non_local_ids.empty() && aggrs.empty() ) // This is far and away the most common case. return false; for ( auto i : ids ) { for ( auto nli : non_local_ids ) if ( nli == i ) return Invalid(); auto i_t = i->GetType(); for ( auto a : aggrs ) if ( same_type(a, i_t.get()) ) return Invalid(); } return false; } } // namespace zeek::detail