script optimization support for "add" and "delete" being expressions

This commit is contained in:
Vern Paxson 2024-05-16 14:37:30 -07:00 committed by Tim Wojtulewicz
parent 0e5bece385
commit 37c1f6641c
11 changed files with 132 additions and 144 deletions

View file

@ -688,8 +688,6 @@ private:
void GenIfStmt(const IfStmt* i);
void GenWhileStmt(const WhileStmt* w);
void GenReturnStmt(const ReturnStmt* r);
void GenAddStmt(const ExprStmt* es);
void GenDeleteStmt(const ExprStmt* es);
void GenEventStmt(const EventStmt* ev);
void GenSwitchStmt(const SwitchStmt* sw);
@ -757,6 +755,8 @@ private:
std::string GenNameExpr(const NameExpr* ne, GenType gt);
std::string GenConstExpr(const ConstExpr* c, GenType gt);
std::string GenAggrAdd(const Expr* e);
std::string GenAggrDel(const Expr* e);
std::string GenIncrExpr(const Expr* e, GenType gt, bool is_incr, bool top_level);
std::string GenCondExpr(const Expr* e, GenType gt);
std::string GenCallExpr(const CallExpr* c, GenType gt, bool top_level);

View file

@ -52,6 +52,9 @@ string CPPCompile::GenExpr(const Expr* e, GenType gt, bool top_level) {
gen = GenExpr(e->GetOp1(), GEN_VAL_PTR) + "->Clone()";
return GenericValPtrToGT(gen, e->GetType(), gt);
case EXPR_AGGR_ADD: return GenAggrAdd(e);
case EXPR_AGGR_DEL: return GenAggrDel(e);
case EXPR_INCR:
case EXPR_DECR: return GenIncrExpr(e, gt, e->Tag() == EXPR_INCR, top_level);
@ -183,6 +186,39 @@ string CPPCompile::GenConstExpr(const ConstExpr* c, GenType gt) {
return NativeToGT(GenVal(c->ValuePtr()), t, gt);
}
string CPPCompile::GenAggrAdd(const Expr* e) {
auto op = e->GetOp1();
auto aggr = GenExpr(op->GetOp1(), GEN_DONT_CARE);
auto indices = GenExpr(op->GetOp2(), GEN_VAL_PTR);
return "add_element__CPP(" + aggr + ", index_val__CPP({" + indices + "}))";
}
string CPPCompile::GenAggrDel(const Expr* e) {
auto op = e->GetOp1();
if ( op->Tag() == EXPR_NAME ) {
auto aggr_gen = GenExpr(op, GEN_VAL_PTR);
if ( op->GetType()->Tag() == TYPE_TABLE )
return aggr_gen + "->RemoveAll()";
else
return aggr_gen + "->Resize(0)";
}
auto aggr = op->GetOp1();
auto aggr_gen = GenExpr(aggr, GEN_VAL_PTR);
if ( op->Tag() == EXPR_INDEX ) {
auto indices = GenExpr(op->GetOp2(), GEN_VAL_PTR);
return "remove_element__CPP(" + aggr_gen + ", index_val__CPP({" + indices + "}))";
}
ASSERT(op->Tag() == EXPR_FIELD);
auto field = GenField(aggr, op->AsFieldExpr()->Field());
return aggr_gen + "->Remove(" + field + ")";
}
string CPPCompile::GenIncrExpr(const Expr* e, GenType gt, bool is_incr, bool top_level) {
// For compound operands (table indexing, record fields),
// Zeek's interpreter will actually evaluate the operand

View file

@ -39,10 +39,6 @@ void CPPCompile::GenStmt(const Stmt* s) {
case STMT_RETURN: GenReturnStmt(s->AsReturnStmt()); break;
case STMT_ADD: GenAddStmt(static_cast<const ExprStmt*>(s)); break;
case STMT_DELETE: GenDeleteStmt(static_cast<const ExprStmt*>(s)); break;
case STMT_EVENT: GenEventStmt(static_cast<const EventStmt*>(s)); break;
case STMT_SWITCH: GenSwitchStmt(static_cast<const SwitchStmt*>(s)); break;
@ -149,41 +145,6 @@ void CPPCompile::GenReturnStmt(const ReturnStmt* r) {
}
}
void CPPCompile::GenAddStmt(const ExprStmt* es) {
auto op = es->StmtExpr();
auto aggr = GenExpr(op->GetOp1(), GEN_DONT_CARE);
auto indices = op->GetOp2();
Emit("add_element__CPP(%s, index_val__CPP({%s}));", aggr, GenExpr(indices, GEN_VAL_PTR));
}
void CPPCompile::GenDeleteStmt(const ExprStmt* es) {
auto op = es->StmtExpr();
if ( op->Tag() == EXPR_NAME ) {
if ( op->GetType()->Tag() == TYPE_TABLE )
Emit("%s->RemoveAll();", GenExpr(op, GEN_VAL_PTR));
else
Emit("%s->Resize(0);", GenExpr(op, GEN_VAL_PTR));
return;
}
auto aggr = op->GetOp1();
auto aggr_gen = GenExpr(aggr, GEN_VAL_PTR);
if ( op->Tag() == EXPR_INDEX ) {
auto indices = op->GetOp2();
Emit("remove_element__CPP(%s, index_val__CPP({%s}));", aggr_gen, GenExpr(indices, GEN_VAL_PTR));
}
else {
ASSERT(op->Tag() == EXPR_FIELD);
auto field = GenField(aggr, op->AsFieldExpr()->Field());
Emit("%s->Remove(%s);", aggr_gen, field);
}
}
void CPPCompile::GenEventStmt(const EventStmt* ev) {
auto ev_s = ev->StmtExprPtr();
auto ev_e = cast_intrusive<EventExpr>(ev_s);

View file

@ -38,16 +38,6 @@ TraversalCode CSE_ValidityChecker::PreStmt(const Stmt* s) {
return TC_ABORTALL;
}
if ( t == STMT_ADD || t == STMT_DELETE )
in_aggr_mod_stmt = true;
return TC_CONTINUE;
}
TraversalCode CSE_ValidityChecker::PostStmt(const Stmt* s) {
if ( s->Tag() == STMT_ADD || s->Tag() == STMT_DELETE )
in_aggr_mod_stmt = false;
return TC_CONTINUE;
}
@ -120,6 +110,9 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
}
} break;
case EXPR_AGGR_ADD:
case EXPR_AGGR_DEL: ++in_aggr_mod_expr; break;
case EXPR_APPEND_TO:
// This doesn't directly change any identifiers, but does
// alter an aggregate.
@ -155,7 +148,7 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
auto aggr = e->GetOp1();
auto aggr_t = aggr->GetType();
if ( in_aggr_mod_stmt ) {
if ( in_aggr_mod_expr > 0 ) {
auto aggr_id = aggr->AsNameExpr()->Id();
if ( CheckID(aggr_id, true) || CheckAggrMod(aggr_t) )
@ -174,6 +167,13 @@ TraversalCode CSE_ValidityChecker::PreExpr(const Expr* e) {
return TC_CONTINUE;
}
TraversalCode CSE_ValidityChecker::PostExpr(const Expr* e) {
if ( have_start_e && (e->Tag() == EXPR_AGGR_ADD || e->Tag() == EXPR_AGGR_DEL) )
--in_aggr_mod_expr;
return TC_CONTINUE;
}
bool CSE_ValidityChecker::CheckID(const ID* id, bool ignore_orig) {
for ( auto i : ids ) {
if ( ignore_orig && i == ids.front() )

View file

@ -21,8 +21,8 @@ public:
const Expr* end_e);
TraversalCode PreStmt(const Stmt*) override;
TraversalCode PostStmt(const Stmt*) override;
TraversalCode PreExpr(const Expr*) override;
TraversalCode PostExpr(const Expr*) override;
// Returns the ultimate verdict re safety.
bool IsValid() const {
@ -99,10 +99,12 @@ protected:
bool have_start_e = false;
bool have_end_e = false;
// Whether analyzed expressions occur in the context of a statement
// Whether analyzed expressions occur in the context of an expression
// that modifies an aggregate ("add" or "delete"), which changes the
// interpretation of the expressions.
bool in_aggr_mod_stmt = false;
//
// A count to allow for nesting.
int in_aggr_mod_expr = 0;
};
// Used for debugging, to communicate which expression wasn't

View file

@ -157,17 +157,6 @@ TraversalCode ProfileFunc::PreStmt(const Stmt* s) {
expr_switches.insert(sw);
} break;
case STMT_ADD:
case STMT_DELETE: {
auto ad_stmt = static_cast<const AddDelStmt*>(s);
auto ad_e = ad_stmt->StmtExpr();
auto lhs = ad_e->GetOp1();
if ( lhs )
aggr_mods.insert(lhs->GetType().get());
else
aggr_mods.insert(ad_e->GetType().get());
} break;
default: break;
}
@ -339,6 +328,15 @@ TraversalCode ProfileFunc::PreExpr(const Expr* e) {
}
} break;
case EXPR_AGGR_ADD:
case EXPR_AGGR_DEL: {
auto lhs = e->GetOp1();
if ( lhs )
aggr_mods.insert(lhs->GetType().get());
else
aggr_mods.insert(e->GetType().get());
} break;
case EXPR_CALL: {
auto c = e->AsCallExpr();
auto args = c->Args();

View file

@ -486,26 +486,6 @@ bool SwitchStmt::CouldReturn(bool ignore_break) const {
return false;
}
bool AddDelStmt::IsReduced(Reducer* c) const { return e->HasReducedOps(c); }
StmtPtr AddDelStmt::DoReduce(Reducer* c) {
if ( c->Optimizing() ) {
e = c->OptExpr(e);
return ThisPtr();
}
auto red_e_stmt = e->ReduceToSingletons(c);
if ( red_e_stmt )
return TransformMe(make_intrusive<StmtList>(red_e_stmt, ThisPtr()), c);
else
return ThisPtr();
}
StmtPtr AddStmt::Duplicate() { return SetSucc(new AddStmt(e->Duplicate())); }
StmtPtr DelStmt::Duplicate() { return SetSucc(new DelStmt(e->Duplicate())); }
StmtPtr EventStmt::Duplicate() { return SetSucc(new EventStmt(e->Duplicate()->AsEventExprPtr())); }
StmtPtr EventStmt::DoReduce(Reducer* c) {

View file

@ -223,8 +223,6 @@ UDs UseDefs::PropagateUDs(const Stmt* s, UDs succ_UDs, const Stmt* succ_stmt, bo
case STMT_EVENT:
case STMT_CHECK_ANY_LEN:
case STMT_ADD:
case STMT_DELETE:
case STMT_RETURN: {
auto e = static_cast<const ExprStmt*>(s)->StmtExpr();
@ -436,6 +434,19 @@ UDs UseDefs::ExprUDs(const Expr* e) {
break;
}
case EXPR_AGGR_ADD:
case EXPR_AGGR_DEL: {
auto op = e->GetOp1();
if ( op->Tag() == EXPR_INDEX ) {
AddInExprUDs(uds, op->GetOp1().get());
auto rhs_UDs = ExprUDs(op->GetOp2().get());
uds = UD_Union(uds, rhs_UDs);
}
else
AddInExprUDs(uds, op.get());
break;
}
case EXPR_INCR:
case EXPR_DECR: AddInExprUDs(uds, e->GetOp1()->AsRefExprPtr()->GetOp1().get()); break;

View file

@ -137,8 +137,6 @@ private:
const ZAMStmt CompileExpr(const ExprStmt* es);
const ZAMStmt CompileIf(const IfStmt* is);
const ZAMStmt CompileSwitch(const SwitchStmt* sw);
const ZAMStmt CompileAdd(const AddStmt* as);
const ZAMStmt CompileDel(const DelStmt* ds);
const ZAMStmt CompileWhile(const WhileStmt* ws);
const ZAMStmt CompileFor(const ForStmt* f);
const ZAMStmt CompileReturn(const ReturnStmt* r);
@ -186,6 +184,8 @@ private:
const ZAMStmt CompileIncrExpr(const IncrExpr* e);
const ZAMStmt CompileAppendToExpr(const AppendToExpr* e);
const ZAMStmt CompileAdd(const AggrAddExpr* e);
const ZAMStmt CompileDel(const AggrDelExpr* e);
const ZAMStmt CompileAddToExpr(const AddToExpr* e);
const ZAMStmt CompileRemoveFromExpr(const RemoveFromExpr* e);
const ZAMStmt CompileAssignExpr(const AssignExpr* e);

View file

@ -16,6 +16,10 @@ const ZAMStmt ZAMCompiler::CompileExpr(const Expr* e) {
case EXPR_APPEND_TO: return CompileAppendToExpr(static_cast<const AppendToExpr*>(e));
case EXPR_AGGR_ADD: return CompileAdd(static_cast<const AggrAddExpr*>(e));
case EXPR_AGGR_DEL: return CompileDel(static_cast<const AggrDelExpr*>(e));
case EXPR_ADD_TO: return CompileAddToExpr(static_cast<const AddToExpr*>(e));
case EXPR_REMOVE_FROM: return CompileRemoveFromExpr(static_cast<const RemoveFromExpr*>(e));
@ -78,6 +82,56 @@ const ZAMStmt ZAMCompiler::CompileAppendToExpr(const AppendToExpr* e) {
return n2 ? AppendToVV(n1, n2) : AppendToVC(n1, cc);
}
const ZAMStmt ZAMCompiler::CompileAdd(const AggrAddExpr* e) {
auto op = e->GetOp1();
auto aggr = op->GetOp1()->AsNameExpr();
auto index_list = op->GetOp2();
if ( index_list->Tag() != EXPR_LIST )
reporter->InternalError("non-list in \"add\"");
auto indices = index_list->AsListExprPtr();
auto& exprs = indices->Exprs();
if ( exprs.length() == 1 ) {
auto e1 = exprs[0];
if ( e1->Tag() == EXPR_NAME )
return AddStmt1VV(aggr, e1->AsNameExpr());
else
return AddStmt1VC(aggr, e1->AsConstExpr());
}
return AddStmtVO(aggr, BuildVals(indices));
}
const ZAMStmt ZAMCompiler::CompileDel(const AggrDelExpr* e) {
auto op = e->GetOp1();
if ( op->Tag() == EXPR_NAME ) {
auto n = op->AsNameExpr();
if ( n->GetType()->Tag() == TYPE_TABLE )
return ClearTableV(n);
else
return ClearVectorV(n);
}
auto aggr = op->GetOp1()->AsNameExpr();
if ( op->Tag() == EXPR_FIELD ) {
int field = op->AsFieldExpr()->Field();
return DelFieldVi(aggr, field);
}
auto index_list = op->GetOp2();
if ( index_list->Tag() != EXPR_LIST )
reporter->InternalError("non-list in \"delete\"");
auto internal_ind = std::unique_ptr<OpaqueVals>(BuildVals(index_list->AsListExprPtr()));
return DelTableVO(aggr, internal_ind.get());
}
const ZAMStmt ZAMCompiler::CompileAddToExpr(const AddToExpr* e) {
auto op1 = e->GetOp1();
auto t1 = op1->GetType()->Tag();

View file

@ -28,10 +28,6 @@ const ZAMStmt ZAMCompiler::CompileStmt(const Stmt* s) {
case STMT_SWITCH: return CompileSwitch(static_cast<const SwitchStmt*>(s));
case STMT_ADD: return CompileAdd(static_cast<const AddStmt*>(s));
case STMT_DELETE: return CompileDel(static_cast<const DelStmt*>(s));
case STMT_EVENT: {
auto es = static_cast<const EventStmt*>(s);
auto e = static_cast<const EventExpr*>(es->StmtExpr());
@ -623,56 +619,6 @@ const ZAMStmt ZAMCompiler::TypeSwitch(const SwitchStmt* sw, const NameExpr* v, c
return body_end;
}
const ZAMStmt ZAMCompiler::CompileAdd(const AddStmt* as) {
auto e = as->StmtExprPtr();
auto aggr = e->GetOp1()->AsNameExpr();
auto index_list = e->GetOp2();
if ( index_list->Tag() != EXPR_LIST )
reporter->InternalError("non-list in \"add\"");
auto indices = index_list->AsListExprPtr();
auto& exprs = indices->Exprs();
if ( exprs.length() == 1 ) {
auto e1 = exprs[0];
if ( e1->Tag() == EXPR_NAME )
return AddStmt1VV(aggr, e1->AsNameExpr());
else
return AddStmt1VC(aggr, e1->AsConstExpr());
}
return AddStmtVO(aggr, BuildVals(indices));
}
const ZAMStmt ZAMCompiler::CompileDel(const DelStmt* ds) {
auto e = ds->StmtExprPtr();
if ( e->Tag() == EXPR_NAME ) {
auto n = e->AsNameExpr();
if ( n->GetType()->Tag() == TYPE_TABLE )
return ClearTableV(n);
else
return ClearVectorV(n);
}
auto aggr = e->GetOp1()->AsNameExpr();
if ( e->Tag() == EXPR_FIELD ) {
int field = e->AsFieldExpr()->Field();
return DelFieldVi(aggr, field);
}
auto index_list = e->GetOp2();
if ( index_list->Tag() != EXPR_LIST )
reporter->InternalError("non-list in \"delete\"");
auto internal_ind = std::unique_ptr<OpaqueVals>(BuildVals(index_list->AsListExprPtr()));
return DelTableVO(aggr, internal_ind.get());
}
const ZAMStmt ZAMCompiler::CompileWhile(const WhileStmt* ws) {
auto loop_condition = ws->Condition();