script optimization support

This commit is contained in:
Vern Paxson 2024-05-08 16:48:02 -07:00
parent 905ed55389
commit 265788b50b
12 changed files with 159 additions and 55 deletions

View file

@ -1403,7 +1403,7 @@ void AddExpr::Canonicalize() {
SwapOps(); SwapOps();
} }
AggrAddExpr::AggrAddExpr(ExprPtr _op) : UnaryExpr(EXPR_AGGR_ADD, std::move(_op)) { AggrAddExpr::AggrAddExpr(ExprPtr _op) : AggrAddDelExpr(EXPR_AGGR_ADD, std::move(_op)) {
if ( ! op->IsError() && ! op->CanAdd() ) if ( ! op->IsError() && ! op->CanAdd() )
ExprError("illegal add expression"); ExprError("illegal add expression");
@ -1412,7 +1412,7 @@ AggrAddExpr::AggrAddExpr(ExprPtr _op) : UnaryExpr(EXPR_AGGR_ADD, std::move(_op))
ValPtr AggrAddExpr::Eval(Frame* f) const { return op->Add(f); } ValPtr AggrAddExpr::Eval(Frame* f) const { return op->Add(f); }
AggrDelExpr::AggrDelExpr(ExprPtr _op) : UnaryExpr(EXPR_AGGR_DEL, std::move(_op)) { AggrDelExpr::AggrDelExpr(ExprPtr _op) : AggrAddDelExpr(EXPR_AGGR_DEL, std::move(_op)) {
if ( ! op->IsError() && ! op->CanDel() ) if ( ! op->IsError() && ! op->CanDel() )
Error("illegal delete expression"); Error("illegal delete expression");

View file

@ -707,31 +707,36 @@ protected:
ExprPtr BuildSub(const ExprPtr& op1, const ExprPtr& op2); ExprPtr BuildSub(const ExprPtr& op1, const ExprPtr& op2);
}; };
class AggrAddExpr final : public UnaryExpr { // A helper class that enables us to factor some common code.
class AggrAddDelExpr : public UnaryExpr {
public: public:
explicit AggrAddExpr(ExprPtr e); explicit AggrAddDelExpr(ExprTag _tag, ExprPtr _e) : UnaryExpr(_tag, std::move(_e)) {}
bool IsPure() const override { return false; } bool IsPure() const override { return false; }
// Optimization-related: // Optimization-related:
ExprPtr Duplicate() override;
bool IsReduced(Reducer* c) const override { return HasReducedOps(c); } bool IsReduced(Reducer* c) const override { return HasReducedOps(c); }
bool HasReducedOps(Reducer* c) const override { return op->HasReducedOps(c); }
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
};
class AggrAddExpr final : public AggrAddDelExpr {
public:
explicit AggrAddExpr(ExprPtr e);
// Optimization-related:
ExprPtr Duplicate() override;
protected: protected:
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
}; };
class AggrDelExpr final : public UnaryExpr { class AggrDelExpr final : public AggrAddDelExpr {
public: public:
explicit AggrDelExpr(ExprPtr e); explicit AggrDelExpr(ExprPtr e);
bool IsPure() const override { return false; }
// Optimization-related: // Optimization-related:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
bool IsReduced(Reducer* c) const override { return HasReducedOps(c); }
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
protected: protected:
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;

View file

@ -757,6 +757,8 @@ private:
std::string GenNameExpr(const NameExpr* ne, GenType gt); std::string GenNameExpr(const NameExpr* ne, GenType gt);
std::string GenConstExpr(const ConstExpr* c, GenType gt); std::string GenConstExpr(const ConstExpr* c, GenType gt);
std::string GenAggrAdd(const Expr* e);
std::string GenAggrDel(const Expr* e);
std::string GenIncrExpr(const Expr* e, GenType gt, bool is_incr, bool top_level); std::string GenIncrExpr(const Expr* e, GenType gt, bool is_incr, bool top_level);
std::string GenCondExpr(const Expr* e, GenType gt); std::string GenCondExpr(const Expr* e, GenType gt);
std::string GenCallExpr(const CallExpr* c, GenType gt, bool top_level); std::string GenCallExpr(const CallExpr* c, GenType gt, bool top_level);

View file

@ -52,6 +52,9 @@ string CPPCompile::GenExpr(const Expr* e, GenType gt, bool top_level) {
gen = GenExpr(e->GetOp1(), GEN_VAL_PTR) + "->Clone()"; gen = GenExpr(e->GetOp1(), GEN_VAL_PTR) + "->Clone()";
return GenericValPtrToGT(gen, e->GetType(), gt); return GenericValPtrToGT(gen, e->GetType(), gt);
case EXPR_AGGR_ADD: return GenAggrAdd(e);
case EXPR_AGGR_DEL: return GenAggrDel(e);
case EXPR_INCR: case EXPR_INCR:
case EXPR_DECR: return GenIncrExpr(e, gt, e->Tag() == EXPR_INCR, top_level); case EXPR_DECR: return GenIncrExpr(e, gt, e->Tag() == EXPR_INCR, top_level);
@ -183,6 +186,39 @@ string CPPCompile::GenConstExpr(const ConstExpr* c, GenType gt) {
return NativeToGT(GenVal(c->ValuePtr()), t, gt); return NativeToGT(GenVal(c->ValuePtr()), t, gt);
} }
string CPPCompile::GenAggrAdd(const Expr* e) {
auto op = e->GetOp1();
auto aggr = GenExpr(op->GetOp1(), GEN_DONT_CARE);
auto indices = GenExpr(op->GetOp2(), GEN_VAL_PTR);
return "add_element__CPP(" + aggr + ", index_val__CPP({" + indices + "}))";
}
string CPPCompile::GenAggrDel(const Expr* e) {
auto op = e->GetOp1();
if ( op->Tag() == EXPR_NAME ) {
auto aggr_gen = GenExpr(op, GEN_VAL_PTR);
if ( op->GetType()->Tag() == TYPE_TABLE )
return aggr_gen + "->RemoveAll()";
else
return aggr_gen + "->Resize(0)";
}
auto aggr = op->GetOp1();
auto aggr_gen = GenExpr(aggr, GEN_VAL_PTR);
if ( op->Tag() == EXPR_INDEX ) {
auto indices = GenExpr(op->GetOp2(), GEN_VAL_PTR);
return "remove_element__CPP(" + aggr_gen + ", index_val__CPP({" + indices + "}))";
}
ASSERT(op->Tag() == EXPR_FIELD);
auto field = GenField(aggr, op->AsFieldExpr()->Field());
return aggr_gen + "->Remove(" + field + ")";
}
string CPPCompile::GenIncrExpr(const Expr* e, GenType gt, bool is_incr, bool top_level) { string CPPCompile::GenIncrExpr(const Expr* e, GenType gt, bool is_incr, bool top_level) {
// For compound operands (table indexing, record fields), // For compound operands (table indexing, record fields),
// Zeek's interpreter will actually evaluate the operand // Zeek's interpreter will actually evaluate the operand

View file

@ -34,7 +34,9 @@ export ZEEK_OPT_FILES="testing/btest"
# export -n ZEEK_GEN_CPP ZEEK_CPP_DIR ZEEK_OPT_FUNCS ZEEK_OPT_FILES # export -n ZEEK_GEN_CPP ZEEK_CPP_DIR ZEEK_OPT_FUNCS ZEEK_OPT_FILES
unset ZEEK_GEN_CPP ZEEK_REPORT_UNCOMPILABLE ZEEK_CPP_DIR ZEEK_OPT_FILES unset ZEEK_GEN_CPP ZEEK_REPORT_UNCOMPILABLE ZEEK_CPP_DIR ZEEK_OPT_FILES
ls -l CPP-gen.cc
ninja ninja
ls -l src/zeek
( (
cd ../testing/btest cd ../testing/btest

View file

@ -38,16 +38,6 @@ TraversalCode CSE_ValidityChecker::PreStmt(const Stmt* s) {
return TC_ABORTALL; return TC_ABORTALL;
} }
if ( t == STMT_ADD || t == STMT_DELETE )
in_aggr_mod_stmt = true;
return TC_CONTINUE;
}
TraversalCode CSE_ValidityChecker::PostStmt(const Stmt* s) {
if ( s->Tag() == STMT_ADD || s->Tag() == STMT_DELETE )
in_aggr_mod_stmt = false;
return TC_CONTINUE; return TC_CONTINUE;
} }
@ -120,6 +110,9 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
} }
} break; } break;
case EXPR_AGGR_ADD:
case EXPR_AGGR_DEL: ++in_aggr_mod_expr; break;
case EXPR_APPEND_TO: case EXPR_APPEND_TO:
// This doesn't directly change any identifiers, but does // This doesn't directly change any identifiers, but does
// alter an aggregate. // alter an aggregate.
@ -155,7 +148,7 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
auto aggr = e->GetOp1(); auto aggr = e->GetOp1();
auto aggr_t = aggr->GetType(); auto aggr_t = aggr->GetType();
if ( in_aggr_mod_stmt ) { if ( in_aggr_mod_expr ) {
auto aggr_id = aggr->AsNameExpr()->Id(); auto aggr_id = aggr->AsNameExpr()->Id();
if ( CheckID(aggr_id, true) || CheckAggrMod(aggr_t) ) if ( CheckID(aggr_id, true) || CheckAggrMod(aggr_t) )
@ -174,6 +167,13 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
return TC_CONTINUE; return TC_CONTINUE;
} }
TraversalCode CSE_ValidityChecker::PostExpr(const Expr* e) {
if ( e->Tag() == EXPR_AGGR_ADD || e->Tag() == EXPR_AGGR_DEL )
--in_aggr_mod_expr;
return TC_CONTINUE;
}
bool CSE_ValidityChecker::CheckID(const ID* id, bool ignore_orig) { bool CSE_ValidityChecker::CheckID(const ID* id, bool ignore_orig) {
for ( auto i : ids ) { for ( auto i : ids ) {
if ( ignore_orig && i == ids.front() ) if ( ignore_orig && i == ids.front() )

View file

@ -21,8 +21,8 @@ public:
const Expr* end_e); const Expr* end_e);
TraversalCode PreStmt(const Stmt*) override; TraversalCode PreStmt(const Stmt*) override;
TraversalCode PostStmt(const Stmt*) override;
TraversalCode PreExpr(const Expr*) override; TraversalCode PreExpr(const Expr*) override;
TraversalCode PostExpr(const Expr*) override;
// Returns the ultimate verdict re safety. // Returns the ultimate verdict re safety.
bool IsValid() const { bool IsValid() const {
@ -99,10 +99,12 @@ protected:
bool have_start_e = false; bool have_start_e = false;
bool have_end_e = false; bool have_end_e = false;
// Whether analyzed expressions occur in the context of a statement // Whether analyzed expressions occur in the context of an expression
// that modifies an aggregate ("add" or "delete"), which changes the // that modifies an aggregate ("add" or "delete"), which changes the
// interpretation of the expressions. // interpretation of the expressions.
bool in_aggr_mod_stmt = false; //
// A count to allow for nesting.
int in_aggr_mod_expr = 0;
}; };
// Used for debugging, to communicate which expression wasn't // Used for debugging, to communicate which expression wasn't

View file

@ -723,30 +723,20 @@ ExprPtr AddExpr::BuildSub(const ExprPtr& op1, const ExprPtr& op2) {
return with_location_of(make_intrusive<SubExpr>(op1, rhs), this); return with_location_of(make_intrusive<SubExpr>(op1, rhs), this);
} }
ExprPtr AggrAddDelExpr::Reduce(Reducer* c, StmtPtr& red_stmt) {
if ( c->Optimizing() ) {
op = c->OptExpr(op);
return ThisPtr();
}
red_stmt = op->ReduceToSingletons(c);
return ThisPtr();
}
ExprPtr AggrAddExpr::Duplicate() { return SetSucc(new AggrAddExpr(op->Duplicate())); } ExprPtr AggrAddExpr::Duplicate() { return SetSucc(new AggrAddExpr(op->Duplicate())); }
ExprPtr AggrAddExpr::Reduce(Reducer* c, StmtPtr& red_stmt) {
if ( c->Optimizing() ) {
op = c->OptExpr(op);
return ThisPtr();
}
red_stmt = op->ReduceToSingletons(c);
return ThisPtr();
}
ExprPtr AggrDelExpr::Duplicate() { return SetSucc(new AggrDelExpr(op->Duplicate())); } ExprPtr AggrDelExpr::Duplicate() { return SetSucc(new AggrDelExpr(op->Duplicate())); }
ExprPtr AggrDelExpr::Reduce(Reducer* c, StmtPtr& red_stmt) {
if ( c->Optimizing() ) {
op = c->OptExpr(op);
return ThisPtr();
}
red_stmt = op->ReduceToSingletons(c);
return ThisPtr();
}
ExprPtr AddToExpr::Duplicate() { ExprPtr AddToExpr::Duplicate() {
auto op1_d = op1->Duplicate(); auto op1_d = op1->Duplicate();
auto op2_d = op2->Duplicate(); auto op2_d = op2->Duplicate();

View file

@ -157,17 +157,6 @@ TraversalCode ProfileFunc::PreStmt(const Stmt* s) {
expr_switches.insert(sw); expr_switches.insert(sw);
} break; } break;
case STMT_ADD:
case STMT_DELETE: {
auto ad_stmt = static_cast<const AddDelStmt*>(s);
auto ad_e = ad_stmt->StmtExpr();
auto lhs = ad_e->GetOp1();
if ( lhs )
aggr_mods.insert(lhs->GetType().get());
else
aggr_mods.insert(ad_e->GetType().get());
} break;
default: break; default: break;
} }
@ -339,6 +328,15 @@ TraversalCode ProfileFunc::PreExpr(const Expr* e) {
} }
} break; } break;
case EXPR_AGGR_ADD:
case EXPR_AGGR_DEL: {
auto lhs = e->GetOp1();
if ( lhs )
aggr_mods.insert(lhs->GetType().get());
else
aggr_mods.insert(e->GetType().get());
} break;
case EXPR_CALL: { case EXPR_CALL: {
auto c = e->AsCallExpr(); auto c = e->AsCallExpr();
auto args = c->Args(); auto args = c->Args();

View file

@ -436,6 +436,19 @@ UDs UseDefs::ExprUDs(const Expr* e) {
break; break;
} }
case EXPR_AGGR_ADD:
case EXPR_AGGR_DEL: {
auto op = e->GetOp1();
if ( op->Tag() == EXPR_INDEX ) {
AddInExprUDs(uds, op->GetOp1().get());
auto rhs_UDs = ExprUDs(op->GetOp2().get());
uds = UD_Union(uds, rhs_UDs);
}
else
AddInExprUDs(uds, op.get());
break;
}
case EXPR_INCR: case EXPR_INCR:
case EXPR_DECR: AddInExprUDs(uds, e->GetOp1()->AsRefExprPtr()->GetOp1().get()); break; case EXPR_DECR: AddInExprUDs(uds, e->GetOp1()->AsRefExprPtr()->GetOp1().get()); break;

View file

@ -186,6 +186,8 @@ private:
const ZAMStmt CompileIncrExpr(const IncrExpr* e); const ZAMStmt CompileIncrExpr(const IncrExpr* e);
const ZAMStmt CompileAppendToExpr(const AppendToExpr* e); const ZAMStmt CompileAppendToExpr(const AppendToExpr* e);
const ZAMStmt CompileAdd(const AggrAddExpr* e);
const ZAMStmt CompileDel(const AggrDelExpr* e);
const ZAMStmt CompileAddToExpr(const AddToExpr* e); const ZAMStmt CompileAddToExpr(const AddToExpr* e);
const ZAMStmt CompileRemoveFromExpr(const RemoveFromExpr* e); const ZAMStmt CompileRemoveFromExpr(const RemoveFromExpr* e);
const ZAMStmt CompileAssignExpr(const AssignExpr* e); const ZAMStmt CompileAssignExpr(const AssignExpr* e);

View file

@ -16,6 +16,10 @@ const ZAMStmt ZAMCompiler::CompileExpr(const Expr* e) {
case EXPR_APPEND_TO: return CompileAppendToExpr(static_cast<const AppendToExpr*>(e)); case EXPR_APPEND_TO: return CompileAppendToExpr(static_cast<const AppendToExpr*>(e));
case EXPR_AGGR_ADD: return CompileAdd(static_cast<const AggrAddExpr*>(e));
case EXPR_AGGR_DEL: return CompileDel(static_cast<const AggrDelExpr*>(e));
case EXPR_ADD_TO: return CompileAddToExpr(static_cast<const AddToExpr*>(e)); case EXPR_ADD_TO: return CompileAddToExpr(static_cast<const AddToExpr*>(e));
case EXPR_REMOVE_FROM: return CompileRemoveFromExpr(static_cast<const RemoveFromExpr*>(e)); case EXPR_REMOVE_FROM: return CompileRemoveFromExpr(static_cast<const RemoveFromExpr*>(e));
@ -78,6 +82,56 @@ const ZAMStmt ZAMCompiler::CompileAppendToExpr(const AppendToExpr* e) {
return n2 ? AppendToVV(n1, n2) : AppendToVC(n1, cc); return n2 ? AppendToVV(n1, n2) : AppendToVC(n1, cc);
} }
const ZAMStmt ZAMCompiler::CompileAdd(const AggrAddExpr* e) {
auto op = e->GetOp1();
auto aggr = op->GetOp1()->AsNameExpr();
auto index_list = op->GetOp2();
if ( index_list->Tag() != EXPR_LIST )
reporter->InternalError("non-list in \"add\"");
auto indices = index_list->AsListExprPtr();
auto& exprs = indices->Exprs();
if ( exprs.length() == 1 ) {
auto e1 = exprs[0];
if ( e1->Tag() == EXPR_NAME )
return AddStmt1VV(aggr, e1->AsNameExpr());
else
return AddStmt1VC(aggr, e1->AsConstExpr());
}
return AddStmtVO(aggr, BuildVals(indices));
}
const ZAMStmt ZAMCompiler::CompileDel(const AggrDelExpr* e) {
auto op = e->GetOp1();
if ( op->Tag() == EXPR_NAME ) {
auto n = op->AsNameExpr();
if ( n->GetType()->Tag() == TYPE_TABLE )
return ClearTableV(n);
else
return ClearVectorV(n);
}
auto aggr = op->GetOp1()->AsNameExpr();
if ( op->Tag() == EXPR_FIELD ) {
int field = op->AsFieldExpr()->Field();
return DelFieldVi(aggr, field);
}
auto index_list = op->GetOp2();
if ( index_list->Tag() != EXPR_LIST )
reporter->InternalError("non-list in \"delete\"");
auto internal_ind = std::unique_ptr<OpaqueVals>(BuildVals(index_list->AsListExprPtr()));
return DelTableVO(aggr, internal_ind.get());
}
const ZAMStmt ZAMCompiler::CompileAddToExpr(const AddToExpr* e) { const ZAMStmt ZAMCompiler::CompileAddToExpr(const AddToExpr* e) {
auto op1 = e->GetOp1(); auto op1 = e->GetOp1();
auto t1 = op1->GetType()->Tag(); auto t1 = op1->GetType()->Tag();