diff --git a/src/Expr.cc b/src/Expr.cc index 23a0863678..7827f50713 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -1403,7 +1403,7 @@ void AddExpr::Canonicalize() { SwapOps(); } -AggrAddExpr::AggrAddExpr(ExprPtr _op) : UnaryExpr(EXPR_AGGR_ADD, std::move(_op)) { +AggrAddExpr::AggrAddExpr(ExprPtr _op) : AggrAddDelExpr(EXPR_AGGR_ADD, std::move(_op)) { if ( ! op->IsError() && ! op->CanAdd() ) ExprError("illegal add expression"); @@ -1412,7 +1412,7 @@ AggrAddExpr::AggrAddExpr(ExprPtr _op) : UnaryExpr(EXPR_AGGR_ADD, std::move(_op)) ValPtr AggrAddExpr::Eval(Frame* f) const { return op->Add(f); } -AggrDelExpr::AggrDelExpr(ExprPtr _op) : UnaryExpr(EXPR_AGGR_DEL, std::move(_op)) { +AggrDelExpr::AggrDelExpr(ExprPtr _op) : AggrAddDelExpr(EXPR_AGGR_DEL, std::move(_op)) { if ( ! op->IsError() && ! op->CanDel() ) Error("illegal delete expression"); diff --git a/src/Expr.h b/src/Expr.h index 9ca5434c10..d937870e79 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -707,31 +707,36 @@ protected: ExprPtr BuildSub(const ExprPtr& op1, const ExprPtr& op2); }; -class AggrAddExpr final : public UnaryExpr { +// A helper class that enables us to factor some common code. +class AggrAddDelExpr : public UnaryExpr { public: - explicit AggrAddExpr(ExprPtr e); + explicit AggrAddDelExpr(ExprTag _tag, ExprPtr _e) : UnaryExpr(_tag, std::move(_e)) {} bool IsPure() const override { return false; } // Optimization-related: - ExprPtr Duplicate() override; bool IsReduced(Reducer* c) const override { return HasReducedOps(c); } + bool HasReducedOps(Reducer* c) const override { return op->HasReducedOps(c); } ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; +}; + +class AggrAddExpr final : public AggrAddDelExpr { +public: + explicit AggrAddExpr(ExprPtr e); + + // Optimization-related: + ExprPtr Duplicate() override; protected: ValPtr Eval(Frame* f) const override; }; -class AggrDelExpr final : public UnaryExpr { +class AggrDelExpr final : public AggrAddDelExpr { public: explicit AggrDelExpr(ExprPtr e); - bool IsPure() const override { return false; } - // Optimization-related: ExprPtr Duplicate() override; - bool IsReduced(Reducer* c) const override { return HasReducedOps(c); } - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; protected: ValPtr Eval(Frame* f) const override; diff --git a/src/script_opt/CPP/Compile.h b/src/script_opt/CPP/Compile.h index 5b05cc07e9..fb7889e6f4 100644 --- a/src/script_opt/CPP/Compile.h +++ b/src/script_opt/CPP/Compile.h @@ -757,6 +757,8 @@ private: std::string GenNameExpr(const NameExpr* ne, GenType gt); std::string GenConstExpr(const ConstExpr* c, GenType gt); + std::string GenAggrAdd(const Expr* e); + std::string GenAggrDel(const Expr* e); std::string GenIncrExpr(const Expr* e, GenType gt, bool is_incr, bool top_level); std::string GenCondExpr(const Expr* e, GenType gt); std::string GenCallExpr(const CallExpr* c, GenType gt, bool top_level); diff --git a/src/script_opt/CPP/Exprs.cc b/src/script_opt/CPP/Exprs.cc index 24ab36f4b4..40eba322b6 100644 --- a/src/script_opt/CPP/Exprs.cc +++ b/src/script_opt/CPP/Exprs.cc @@ -52,6 +52,9 @@ string CPPCompile::GenExpr(const Expr* e, GenType gt, bool top_level) { gen = GenExpr(e->GetOp1(), GEN_VAL_PTR) + "->Clone()"; return GenericValPtrToGT(gen, e->GetType(), gt); + case EXPR_AGGR_ADD: return GenAggrAdd(e); + case EXPR_AGGR_DEL: return GenAggrDel(e); + case EXPR_INCR: case EXPR_DECR: return GenIncrExpr(e, gt, e->Tag() == EXPR_INCR, top_level); @@ -183,6 +186,39 @@ string CPPCompile::GenConstExpr(const ConstExpr* c, GenType gt) { return NativeToGT(GenVal(c->ValuePtr()), t, gt); } +string CPPCompile::GenAggrAdd(const Expr* e) { + auto op = e->GetOp1(); + auto aggr = GenExpr(op->GetOp1(), GEN_DONT_CARE); + auto indices = GenExpr(op->GetOp2(), GEN_VAL_PTR); + + return "add_element__CPP(" + aggr + ", index_val__CPP({" + indices + "}))"; +} + +string CPPCompile::GenAggrDel(const Expr* e) { + auto op = e->GetOp1(); + + if ( op->Tag() == EXPR_NAME ) { + auto aggr_gen = GenExpr(op, GEN_VAL_PTR); + + if ( op->GetType()->Tag() == TYPE_TABLE ) + return aggr_gen + "->RemoveAll()"; + else + return aggr_gen + "->Resize(0)"; + } + + auto aggr = op->GetOp1(); + auto aggr_gen = GenExpr(aggr, GEN_VAL_PTR); + + if ( op->Tag() == EXPR_INDEX ) { + auto indices = GenExpr(op->GetOp2(), GEN_VAL_PTR); + return "remove_element__CPP(" + aggr_gen + ", index_val__CPP({" + indices + "}))"; + } + + ASSERT(op->Tag() == EXPR_FIELD); + auto field = GenField(aggr, op->AsFieldExpr()->Field()); + return aggr_gen + "->Remove(" + field + ")"; +} + string CPPCompile::GenIncrExpr(const Expr* e, GenType gt, bool is_incr, bool top_level) { // For compound operands (table indexing, record fields), // Zeek's interpreter will actually evaluate the operand diff --git a/src/script_opt/CPP/maint/do-CPP-btest.sh b/src/script_opt/CPP/maint/do-CPP-btest.sh index 5c3aff7c62..b7a7e47489 100755 --- a/src/script_opt/CPP/maint/do-CPP-btest.sh +++ b/src/script_opt/CPP/maint/do-CPP-btest.sh @@ -34,7 +34,9 @@ export ZEEK_OPT_FILES="testing/btest" # export -n ZEEK_GEN_CPP ZEEK_CPP_DIR ZEEK_OPT_FUNCS ZEEK_OPT_FILES unset ZEEK_GEN_CPP ZEEK_REPORT_UNCOMPILABLE ZEEK_CPP_DIR ZEEK_OPT_FILES +ls -l CPP-gen.cc ninja +ls -l src/zeek ( cd ../testing/btest diff --git a/src/script_opt/CSE.cc b/src/script_opt/CSE.cc index 2d97038670..3e217ee322 100644 --- a/src/script_opt/CSE.cc +++ b/src/script_opt/CSE.cc @@ -38,16 +38,6 @@ TraversalCode CSE_ValidityChecker::PreStmt(const Stmt* s) { return TC_ABORTALL; } - if ( t == STMT_ADD || t == STMT_DELETE ) - in_aggr_mod_stmt = true; - - return TC_CONTINUE; -} - -TraversalCode CSE_ValidityChecker::PostStmt(const Stmt* s) { - if ( s->Tag() == STMT_ADD || s->Tag() == STMT_DELETE ) - in_aggr_mod_stmt = false; - return TC_CONTINUE; } @@ -120,6 +110,9 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) { } } break; + case EXPR_AGGR_ADD: + case EXPR_AGGR_DEL: ++in_aggr_mod_expr; break; + case EXPR_APPEND_TO: // This doesn't directly change any identifiers, but does // alter an aggregate. @@ -155,7 +148,7 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) { auto aggr = e->GetOp1(); auto aggr_t = aggr->GetType(); - if ( in_aggr_mod_stmt ) { + if ( in_aggr_mod_expr ) { auto aggr_id = aggr->AsNameExpr()->Id(); if ( CheckID(aggr_id, true) || CheckAggrMod(aggr_t) ) @@ -174,6 +167,13 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) { return TC_CONTINUE; } +TraversalCode CSE_ValidityChecker::PostExpr(const Expr* e) { + if ( e->Tag() == EXPR_AGGR_ADD || e->Tag() == EXPR_AGGR_DEL ) + --in_aggr_mod_expr; + + return TC_CONTINUE; +} + bool CSE_ValidityChecker::CheckID(const ID* id, bool ignore_orig) { for ( auto i : ids ) { if ( ignore_orig && i == ids.front() ) diff --git a/src/script_opt/CSE.h b/src/script_opt/CSE.h index 9507cd493e..56d6572d9a 100644 --- a/src/script_opt/CSE.h +++ b/src/script_opt/CSE.h @@ -21,8 +21,8 @@ public: const Expr* end_e); TraversalCode PreStmt(const Stmt*) override; - TraversalCode PostStmt(const Stmt*) override; TraversalCode PreExpr(const Expr*) override; + TraversalCode PostExpr(const Expr*) override; // Returns the ultimate verdict re safety. bool IsValid() const { @@ -99,10 +99,12 @@ protected: bool have_start_e = false; bool have_end_e = false; - // Whether analyzed expressions occur in the context of a statement + // Whether analyzed expressions occur in the context of an expression // that modifies an aggregate ("add" or "delete"), which changes the // interpretation of the expressions. - bool in_aggr_mod_stmt = false; + // + // A count to allow for nesting. + int in_aggr_mod_expr = 0; }; // Used for debugging, to communicate which expression wasn't diff --git a/src/script_opt/Expr.cc b/src/script_opt/Expr.cc index f597d4e969..9985448cd6 100644 --- a/src/script_opt/Expr.cc +++ b/src/script_opt/Expr.cc @@ -723,30 +723,20 @@ ExprPtr AddExpr::BuildSub(const ExprPtr& op1, const ExprPtr& op2) { return with_location_of(make_intrusive(op1, rhs), this); } +ExprPtr AggrAddDelExpr::Reduce(Reducer* c, StmtPtr& red_stmt) { + if ( c->Optimizing() ) { + op = c->OptExpr(op); + return ThisPtr(); + } + + red_stmt = op->ReduceToSingletons(c); + return ThisPtr(); +} + ExprPtr AggrAddExpr::Duplicate() { return SetSucc(new AggrAddExpr(op->Duplicate())); } -ExprPtr AggrAddExpr::Reduce(Reducer* c, StmtPtr& red_stmt) { - if ( c->Optimizing() ) { - op = c->OptExpr(op); - return ThisPtr(); - } - - red_stmt = op->ReduceToSingletons(c); - return ThisPtr(); -} - ExprPtr AggrDelExpr::Duplicate() { return SetSucc(new AggrDelExpr(op->Duplicate())); } -ExprPtr AggrDelExpr::Reduce(Reducer* c, StmtPtr& red_stmt) { - if ( c->Optimizing() ) { - op = c->OptExpr(op); - return ThisPtr(); - } - - red_stmt = op->ReduceToSingletons(c); - return ThisPtr(); -} - ExprPtr AddToExpr::Duplicate() { auto op1_d = op1->Duplicate(); auto op2_d = op2->Duplicate(); diff --git a/src/script_opt/ProfileFunc.cc b/src/script_opt/ProfileFunc.cc index 29a0b02005..b10fbda940 100644 --- a/src/script_opt/ProfileFunc.cc +++ b/src/script_opt/ProfileFunc.cc @@ -157,17 +157,6 @@ TraversalCode ProfileFunc::PreStmt(const Stmt* s) { expr_switches.insert(sw); } break; - case STMT_ADD: - case STMT_DELETE: { - auto ad_stmt = static_cast(s); - auto ad_e = ad_stmt->StmtExpr(); - auto lhs = ad_e->GetOp1(); - if ( lhs ) - aggr_mods.insert(lhs->GetType().get()); - else - aggr_mods.insert(ad_e->GetType().get()); - } break; - default: break; } @@ -339,6 +328,15 @@ TraversalCode ProfileFunc::PreExpr(const Expr* e) { } } break; + case EXPR_AGGR_ADD: + case EXPR_AGGR_DEL: { + auto lhs = e->GetOp1(); + if ( lhs ) + aggr_mods.insert(lhs->GetType().get()); + else + aggr_mods.insert(e->GetType().get()); + } break; + case EXPR_CALL: { auto c = e->AsCallExpr(); auto args = c->Args(); diff --git a/src/script_opt/UseDefs.cc b/src/script_opt/UseDefs.cc index 1e786c9403..c9d17a3038 100644 --- a/src/script_opt/UseDefs.cc +++ b/src/script_opt/UseDefs.cc @@ -436,6 +436,19 @@ UDs UseDefs::ExprUDs(const Expr* e) { break; } + case EXPR_AGGR_ADD: + case EXPR_AGGR_DEL: { + auto op = e->GetOp1(); + if ( op->Tag() == EXPR_INDEX ) { + AddInExprUDs(uds, op->GetOp1().get()); + auto rhs_UDs = ExprUDs(op->GetOp2().get()); + uds = UD_Union(uds, rhs_UDs); + } + else + AddInExprUDs(uds, op.get()); + break; + } + case EXPR_INCR: case EXPR_DECR: AddInExprUDs(uds, e->GetOp1()->AsRefExprPtr()->GetOp1().get()); break; diff --git a/src/script_opt/ZAM/Compile.h b/src/script_opt/ZAM/Compile.h index 2ee365e64c..dbe50dc2d9 100644 --- a/src/script_opt/ZAM/Compile.h +++ b/src/script_opt/ZAM/Compile.h @@ -186,6 +186,8 @@ private: const ZAMStmt CompileIncrExpr(const IncrExpr* e); const ZAMStmt CompileAppendToExpr(const AppendToExpr* e); + const ZAMStmt CompileAdd(const AggrAddExpr* e); + const ZAMStmt CompileDel(const AggrDelExpr* e); const ZAMStmt CompileAddToExpr(const AddToExpr* e); const ZAMStmt CompileRemoveFromExpr(const RemoveFromExpr* e); const ZAMStmt CompileAssignExpr(const AssignExpr* e); diff --git a/src/script_opt/ZAM/Expr.cc b/src/script_opt/ZAM/Expr.cc index facc6952ff..02eefb9939 100644 --- a/src/script_opt/ZAM/Expr.cc +++ b/src/script_opt/ZAM/Expr.cc @@ -16,6 +16,10 @@ const ZAMStmt ZAMCompiler::CompileExpr(const Expr* e) { case EXPR_APPEND_TO: return CompileAppendToExpr(static_cast(e)); + case EXPR_AGGR_ADD: return CompileAdd(static_cast(e)); + + case EXPR_AGGR_DEL: return CompileDel(static_cast(e)); + case EXPR_ADD_TO: return CompileAddToExpr(static_cast(e)); case EXPR_REMOVE_FROM: return CompileRemoveFromExpr(static_cast(e)); @@ -78,6 +82,56 @@ const ZAMStmt ZAMCompiler::CompileAppendToExpr(const AppendToExpr* e) { return n2 ? AppendToVV(n1, n2) : AppendToVC(n1, cc); } +const ZAMStmt ZAMCompiler::CompileAdd(const AggrAddExpr* e) { + auto op = e->GetOp1(); + auto aggr = op->GetOp1()->AsNameExpr(); + auto index_list = op->GetOp2(); + + if ( index_list->Tag() != EXPR_LIST ) + reporter->InternalError("non-list in \"add\""); + + auto indices = index_list->AsListExprPtr(); + auto& exprs = indices->Exprs(); + + if ( exprs.length() == 1 ) { + auto e1 = exprs[0]; + if ( e1->Tag() == EXPR_NAME ) + return AddStmt1VV(aggr, e1->AsNameExpr()); + else + return AddStmt1VC(aggr, e1->AsConstExpr()); + } + + return AddStmtVO(aggr, BuildVals(indices)); +} + +const ZAMStmt ZAMCompiler::CompileDel(const AggrDelExpr* e) { + auto op = e->GetOp1(); + + if ( op->Tag() == EXPR_NAME ) { + auto n = op->AsNameExpr(); + + if ( n->GetType()->Tag() == TYPE_TABLE ) + return ClearTableV(n); + else + return ClearVectorV(n); + } + + auto aggr = op->GetOp1()->AsNameExpr(); + + if ( op->Tag() == EXPR_FIELD ) { + int field = op->AsFieldExpr()->Field(); + return DelFieldVi(aggr, field); + } + + auto index_list = op->GetOp2(); + + if ( index_list->Tag() != EXPR_LIST ) + reporter->InternalError("non-list in \"delete\""); + + auto internal_ind = std::unique_ptr(BuildVals(index_list->AsListExprPtr())); + return DelTableVO(aggr, internal_ind.get()); +} + const ZAMStmt ZAMCompiler::CompileAddToExpr(const AddToExpr* e) { auto op1 = e->GetOp1(); auto t1 = op1->GetType()->Tag();