mirror of
https://github.com/zeek/zeek.git
synced 2025-10-09 10:08:20 +00:00
move prototype scaffolding into regular member functions
This commit is contained in:
parent
694095c56f
commit
3a0b9325cc
2 changed files with 84 additions and 22 deletions
41
src/Stmt.h
41
src/Stmt.h
|
@ -480,6 +480,47 @@ public:
|
||||||
protected:
|
protected:
|
||||||
bool IsPure() const override;
|
bool IsPure() const override;
|
||||||
|
|
||||||
|
// These are used for script optimization, to find sequences of
|
||||||
|
// record assignments that form a chain that can be collapsed into
|
||||||
|
// specialized expressions/operations.
|
||||||
|
|
||||||
|
// Starting a position i, looks for a chain of record assignments.
|
||||||
|
// Returns a position just past where the chain ends, so a return
|
||||||
|
// value of i means "not an assignment chain".
|
||||||
|
//
|
||||||
|
// At this point, chains are simply a series of assignment to the
|
||||||
|
// same record, but different fields, with the only restriction being
|
||||||
|
// that the record is a simple variable and not a compound like "x$a$b =".
|
||||||
|
//
|
||||||
|
// Note that chains can have length 1, which is still useful for
|
||||||
|
// optimization in some circumstances.
|
||||||
|
unsigned int FindRecAssignmentChain(unsigned int i) const;
|
||||||
|
|
||||||
|
// For an assignment chain, maps RHS identifiers to their collection
|
||||||
|
// of operations (all of those that are of a type we know how to
|
||||||
|
// optimize, and that use the same RHS). The operations are captured
|
||||||
|
// as the underlying statements, which turns out to be convenient..
|
||||||
|
using OpChain = std::map<const ID*, std::vector<const Stmt*>>;
|
||||||
|
|
||||||
|
// For a given statement s that's part of an assignment chain,
|
||||||
|
// updates its corresponding OpChain, either the one for "x$a = y$b"
|
||||||
|
// ("assign") or "x$a += y$b" ("add"). Note that for this latter,
|
||||||
|
// the actual AST is "x$a = x$a + y$b".
|
||||||
|
void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpChain& add_chains) const;
|
||||||
|
|
||||||
|
// Given an OpChain, transform it into one or more custom expressions
|
||||||
|
// for evaluating it.The tag t indicates whether this chain is
|
||||||
|
// a set of assignments or +='s. The statements in the chain should
|
||||||
|
// all be found in chain_stmts, and will be removed from it.
|
||||||
|
StmtPtr TransformChain(const OpChain& c, ExprTag t, std::set<const Stmt*>& chain_stmts) const;
|
||||||
|
|
||||||
|
// Simplify the chain that runs from "start" to "end" by collapsing
|
||||||
|
// subsets of it into specialized operations. These are added to
|
||||||
|
// f_stmts, as are the statements in the chain that don't correspond
|
||||||
|
// to collapsible subsets. Returns true if simplification occurred,
|
||||||
|
// false if not.
|
||||||
|
bool SimplifyChain(unsigned int start, unsigned int end, std::vector<StmtPtr>& f_stmts) const;
|
||||||
|
|
||||||
std::vector<StmtPtr> stmts;
|
std::vector<StmtPtr> stmts;
|
||||||
|
|
||||||
// Optimization-related:
|
// Optimization-related:
|
||||||
|
|
|
@ -760,42 +760,50 @@ StmtPtr StmtList::DoReduce(Reducer* c) {
|
||||||
return ThisPtr();
|
return ThisPtr();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns just past the end of the chain.
|
unsigned int StmtList::FindRecAssignmentChain(unsigned int i) const {
|
||||||
static unsigned int FindAssignmentChain(const std::vector<StmtPtr>& stmts, unsigned int i) {
|
|
||||||
const NameExpr* targ_rec = nullptr;
|
const NameExpr* targ_rec = nullptr;
|
||||||
std::set<int> fields_seen;
|
std::set<int> fields_seen;
|
||||||
|
|
||||||
for ( ; i < stmts.size(); ++i ) {
|
for ( ; i < stmts.size(); ++i ) {
|
||||||
auto& s = stmts[i];
|
auto& s = stmts[i];
|
||||||
|
|
||||||
|
// We're looking for either "x$a = y$b" or "x$a = x$a + y$b".
|
||||||
if ( s->Tag() != STMT_EXPR )
|
if ( s->Tag() != STMT_EXPR )
|
||||||
|
// No way it's an assignment.
|
||||||
return i;
|
return i;
|
||||||
|
|
||||||
auto se = s->AsExprStmt()->StmtExpr();
|
auto se = s->AsExprStmt()->StmtExpr();
|
||||||
if ( se->Tag() != EXPR_ASSIGN )
|
if ( se->Tag() != EXPR_ASSIGN )
|
||||||
return i;
|
return i;
|
||||||
|
|
||||||
|
// The LHS of an assignment starts with a RefExpr.
|
||||||
auto lhs_ref = se->GetOp1();
|
auto lhs_ref = se->GetOp1();
|
||||||
ASSERT(lhs_ref->Tag() == EXPR_REF);
|
ASSERT(lhs_ref->Tag() == EXPR_REF);
|
||||||
|
|
||||||
auto lhs = lhs_ref->GetOp1();
|
auto lhs = lhs_ref->GetOp1();
|
||||||
if ( lhs->Tag() != EXPR_FIELD )
|
if ( lhs->Tag() != EXPR_FIELD )
|
||||||
|
// Not of the form "x$a = ...".
|
||||||
return i;
|
return i;
|
||||||
|
|
||||||
auto lhs_field = lhs->AsFieldExpr()->Field();
|
auto lhs_field = lhs->AsFieldExpr()->Field();
|
||||||
if ( fields_seen.count(lhs_field) > 0 )
|
if ( fields_seen.count(lhs_field) > 0 )
|
||||||
|
// Earlier in this chain we've already seen "x$a", so end the
|
||||||
|
// chain at this repeated use because it's no longer a simple
|
||||||
|
// block of field assignments.
|
||||||
return i;
|
return i;
|
||||||
|
|
||||||
fields_seen.insert(lhs_field);
|
fields_seen.insert(lhs_field);
|
||||||
|
|
||||||
auto lhs_rec = lhs->GetOp1();
|
auto lhs_rec = lhs->GetOp1();
|
||||||
if ( lhs_rec->Tag() != EXPR_NAME )
|
if ( lhs_rec->Tag() != EXPR_NAME )
|
||||||
// Not a simple field reference.
|
// Not a simple field reference, e.g. "x$y$a".
|
||||||
return i;
|
return i;
|
||||||
|
|
||||||
auto lhs_rec_n = lhs_rec->AsNameExpr();
|
auto lhs_rec_n = lhs_rec->AsNameExpr();
|
||||||
|
|
||||||
if ( targ_rec ) {
|
if ( targ_rec ) {
|
||||||
if ( lhs_rec_n->Id() != targ_rec->Id() )
|
if ( lhs_rec_n->Id() != targ_rec->Id() )
|
||||||
|
// It's no longer "x$..." but some new variable "z$...".
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
@ -805,21 +813,24 @@ static unsigned int FindAssignmentChain(const std::vector<StmtPtr>& stmts, unsig
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Maps RHS identifiers to their collection of operations, expressed
|
void StmtList::UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpChain& add_chains) const {
|
||||||
// as the underlying statement.
|
|
||||||
using OpChain = std::map<const ID*, std::vector<const Stmt*>>;
|
|
||||||
|
|
||||||
static void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpChain& add_chains) {
|
|
||||||
auto se = s->AsExprStmt()->StmtExpr();
|
auto se = s->AsExprStmt()->StmtExpr();
|
||||||
ASSERT(se->Tag() == EXPR_ASSIGN);
|
ASSERT(se->Tag() == EXPR_ASSIGN);
|
||||||
|
|
||||||
|
// We dig three times into the LHS. The first gets the EXPR_ASSIGN's
|
||||||
|
// first operand, which is a RefExpr; the second gets its operand,
|
||||||
|
// which we've guaranteed in FindRecAssignmentChain is a FieldExpr;
|
||||||
|
// and the third is the FieldExpr's operand, which we've guaranteed
|
||||||
|
// is a NameExpr.
|
||||||
auto lhs_id = se->GetOp1()->GetOp1()->GetOp1()->AsNameExpr()->Id();
|
auto lhs_id = se->GetOp1()->GetOp1()->GetOp1()->AsNameExpr()->Id();
|
||||||
auto rhs = se->GetOp2();
|
auto rhs = se->GetOp2();
|
||||||
const FieldExpr* f;
|
const FieldExpr* f;
|
||||||
OpChain* c;
|
OpChain* c;
|
||||||
|
|
||||||
|
// Check whether RHS is either "y$b" or "x$a + y$b".
|
||||||
|
|
||||||
if ( rhs->Tag() == EXPR_ADD ) {
|
if ( rhs->Tag() == EXPR_ADD ) {
|
||||||
auto rhs_op1 = rhs->GetOp1();
|
auto rhs_op1 = rhs->GetOp1(); // need to see that it's "x$a"
|
||||||
|
|
||||||
if ( rhs_op1->Tag() != EXPR_FIELD )
|
if ( rhs_op1->Tag() != EXPR_FIELD )
|
||||||
return;
|
return;
|
||||||
|
@ -828,7 +839,7 @@ static void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpC
|
||||||
if ( rhs_op1_rec->Tag() != EXPR_NAME || rhs_op1_rec->AsNameExpr()->Id() != lhs_id )
|
if ( rhs_op1_rec->Tag() != EXPR_NAME || rhs_op1_rec->AsNameExpr()->Id() != lhs_id )
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto rhs_op2 = rhs->GetOp2();
|
auto rhs_op2 = rhs->GetOp2(); // need to see that it's "y$b"
|
||||||
if ( rhs_op2->Tag() != EXPR_FIELD )
|
if ( rhs_op2->Tag() != EXPR_FIELD )
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
@ -842,12 +853,15 @@ static void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpC
|
||||||
}
|
}
|
||||||
|
|
||||||
else
|
else
|
||||||
|
// Not a RHS we know how to leverage.
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto f_rec = f->GetOp1();
|
auto f_rec = f->GetOp1();
|
||||||
if ( f_rec->Tag() != EXPR_NAME )
|
if ( f_rec->Tag() != EXPR_NAME )
|
||||||
|
// Not a simple RHS, instead something like "y$z$b".
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
// If we get here, it's a keeper, record the associated statement.
|
||||||
auto id = f_rec->AsNameExpr()->Id();
|
auto id = f_rec->AsNameExpr()->Id();
|
||||||
auto cf = c->find(id);
|
auto cf = c->find(id);
|
||||||
if ( cf == c->end() )
|
if ( cf == c->end() )
|
||||||
|
@ -856,12 +870,17 @@ static void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpC
|
||||||
cf->second.push_back(s.get());
|
cf->second.push_back(s.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
static StmtPtr TransformChain(const OpChain& c, ExprTag t, std::set<const Stmt*>& chain_stmts, StmtPtr s0) {
|
StmtPtr StmtList::TransformChain(const OpChain& c, ExprTag t, std::set<const Stmt*>& chain_stmts) const {
|
||||||
auto sl = with_location_of(make_intrusive<StmtList>(), s0);
|
IntrusivePtr<StmtList> sl;
|
||||||
auto& stmts = sl->Stmts();
|
|
||||||
|
|
||||||
for ( auto& id_stmts : c ) {
|
for ( auto& id_stmts : c ) {
|
||||||
auto orig_s = id_stmts.second;
|
auto orig_s = id_stmts.second;
|
||||||
|
|
||||||
|
if ( ! sl )
|
||||||
|
// Now that we have a statement, create our list and associate
|
||||||
|
// its location with the statement.
|
||||||
|
sl = with_location_of(make_intrusive<StmtList>(), orig_s[0]);
|
||||||
|
|
||||||
ExprPtr e;
|
ExprPtr e;
|
||||||
if ( t == EXPR_ASSIGN )
|
if ( t == EXPR_ASSIGN )
|
||||||
e = make_intrusive<AssignRecordFields>(orig_s, chain_stmts);
|
e = make_intrusive<AssignRecordFields>(orig_s, chain_stmts);
|
||||||
|
@ -870,14 +889,13 @@ static StmtPtr TransformChain(const OpChain& c, ExprTag t, std::set<const Stmt*>
|
||||||
|
|
||||||
e->SetLocationInfo(sl->GetLocationInfo());
|
e->SetLocationInfo(sl->GetLocationInfo());
|
||||||
auto es = with_location_of(make_intrusive<ExprStmt>(std::move(e)), sl);
|
auto es = with_location_of(make_intrusive<ExprStmt>(std::move(e)), sl);
|
||||||
stmts.emplace_back(std::move(es));
|
sl->Stmts().emplace_back(std::move(es));
|
||||||
}
|
}
|
||||||
|
|
||||||
return sl;
|
return sl;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool SimplifyChain(const std::vector<StmtPtr>& stmts, unsigned int start, unsigned int end,
|
bool StmtList::SimplifyChain(unsigned int start, unsigned int end, std::vector<StmtPtr>& f_stmts) const {
|
||||||
std::vector<StmtPtr>& f_stmts) {
|
|
||||||
OpChain assign_chains;
|
OpChain assign_chains;
|
||||||
OpChain add_chains;
|
OpChain add_chains;
|
||||||
std::set<const Stmt*> chain_stmts;
|
std::set<const Stmt*> chain_stmts;
|
||||||
|
@ -888,8 +906,9 @@ static bool SimplifyChain(const std::vector<StmtPtr>& stmts, unsigned int start,
|
||||||
UpdateAssignmentChains(s, assign_chains, add_chains);
|
UpdateAssignmentChains(s, assign_chains, add_chains);
|
||||||
}
|
}
|
||||||
|
|
||||||
// An add-chain of any size is a win. For an assign-chain to be
|
// An add-chain of any size is a win. For an assign-chain to be a win,
|
||||||
// a win, it needs to have at least two elements.
|
// it needs to have at least two elements, because a single "x$a = y$b"
|
||||||
|
// can be expressed using one ZAM instructino (but "x$a += y$b" cannot).
|
||||||
if ( add_chains.empty() ) {
|
if ( add_chains.empty() ) {
|
||||||
bool have_useful_assign_chain = false;
|
bool have_useful_assign_chain = false;
|
||||||
for ( auto& ac : assign_chains )
|
for ( auto& ac : assign_chains )
|
||||||
|
@ -899,12 +918,14 @@ static bool SimplifyChain(const std::vector<StmtPtr>& stmts, unsigned int start,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ( ! have_useful_assign_chain )
|
if ( ! have_useful_assign_chain )
|
||||||
|
// No gains available.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
f_stmts.push_back(TransformChain(assign_chains, EXPR_ASSIGN, chain_stmts, stmts[start]));
|
f_stmts.push_back(TransformChain(assign_chains, EXPR_ASSIGN, chain_stmts));
|
||||||
f_stmts.push_back(TransformChain(add_chains, EXPR_ADD, chain_stmts, stmts[start]));
|
f_stmts.push_back(TransformChain(add_chains, EXPR_ADD, chain_stmts));
|
||||||
|
|
||||||
|
// At this point, chain_stmts has only the remainders that weren't removed.
|
||||||
for ( auto s : stmts )
|
for ( auto s : stmts )
|
||||||
if ( chain_stmts.count(s.get()) > 0 )
|
if ( chain_stmts.count(s.get()) > 0 )
|
||||||
f_stmts.push_back(s);
|
f_stmts.push_back(s);
|
||||||
|
@ -917,8 +938,8 @@ bool StmtList::ReduceStmt(unsigned int& s_i, std::vector<StmtPtr>& f_stmts, Redu
|
||||||
auto& stmt_i = stmts[s_i];
|
auto& stmt_i = stmts[s_i];
|
||||||
auto old_stmt = stmt_i;
|
auto old_stmt = stmt_i;
|
||||||
|
|
||||||
auto chain_end = FindAssignmentChain(stmts, s_i);
|
auto chain_end = FindRecAssignmentChain(s_i);
|
||||||
if ( chain_end > s_i && SimplifyChain(stmts, s_i, chain_end - 1, f_stmts) )
|
if ( chain_end > s_i && SimplifyChain(s_i, chain_end - 1, f_stmts) )
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
auto stmt = stmt_i->Reduce(c);
|
auto stmt = stmt_i->Reduce(c);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue