diff --git a/src/Stmt.h b/src/Stmt.h index 5e575f0555..33955bc66d 100644 --- a/src/Stmt.h +++ b/src/Stmt.h @@ -480,6 +480,47 @@ public: protected: bool IsPure() const override; + // These are used for script optimization, to find sequences of + // record assignments that form a chain that can be collapsed into + // specialized expressions/operations. + + // Starting a position i, looks for a chain of record assignments. + // Returns a position just past where the chain ends, so a return + // value of i means "not an assignment chain". + // + // At this point, chains are simply a series of assignment to the + // same record, but different fields, with the only restriction being + // that the record is a simple variable and not a compound like "x$a$b =". + // + // Note that chains can have length 1, which is still useful for + // optimization in some circumstances. + unsigned int FindRecAssignmentChain(unsigned int i) const; + + // For an assignment chain, maps RHS identifiers to their collection + // of operations (all of those that are of a type we know how to + // optimize, and that use the same RHS). The operations are captured + // as the underlying statements, which turns out to be convenient.. + using OpChain = std::map>; + + // For a given statement s that's part of an assignment chain, + // updates its corresponding OpChain, either the one for "x$a = y$b" + // ("assign") or "x$a += y$b" ("add"). Note that for this latter, + // the actual AST is "x$a = x$a + y$b". + void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpChain& add_chains) const; + + // Given an OpChain, transform it into one or more custom expressions + // for evaluating it.The tag t indicates whether this chain is + // a set of assignments or +='s. The statements in the chain should + // all be found in chain_stmts, and will be removed from it. + StmtPtr TransformChain(const OpChain& c, ExprTag t, std::set& chain_stmts) const; + + // Simplify the chain that runs from "start" to "end" by collapsing + // subsets of it into specialized operations. These are added to + // f_stmts, as are the statements in the chain that don't correspond + // to collapsible subsets. Returns true if simplification occurred, + // false if not. + bool SimplifyChain(unsigned int start, unsigned int end, std::vector& f_stmts) const; + std::vector stmts; // Optimization-related: diff --git a/src/script_opt/Stmt.cc b/src/script_opt/Stmt.cc index 6c8d067c8e..aefbac637b 100644 --- a/src/script_opt/Stmt.cc +++ b/src/script_opt/Stmt.cc @@ -760,42 +760,50 @@ StmtPtr StmtList::DoReduce(Reducer* c) { return ThisPtr(); } -// Returns just past the end of the chain. -static unsigned int FindAssignmentChain(const std::vector& stmts, unsigned int i) { +unsigned int StmtList::FindRecAssignmentChain(unsigned int i) const { const NameExpr* targ_rec = nullptr; std::set fields_seen; for ( ; i < stmts.size(); ++i ) { auto& s = stmts[i]; + // We're looking for either "x$a = y$b" or "x$a = x$a + y$b". if ( s->Tag() != STMT_EXPR ) + // No way it's an assignment. return i; auto se = s->AsExprStmt()->StmtExpr(); if ( se->Tag() != EXPR_ASSIGN ) return i; + // The LHS of an assignment starts with a RefExpr. auto lhs_ref = se->GetOp1(); ASSERT(lhs_ref->Tag() == EXPR_REF); auto lhs = lhs_ref->GetOp1(); if ( lhs->Tag() != EXPR_FIELD ) + // Not of the form "x$a = ...". return i; auto lhs_field = lhs->AsFieldExpr()->Field(); if ( fields_seen.count(lhs_field) > 0 ) + // Earlier in this chain we've already seen "x$a", so end the + // chain at this repeated use because it's no longer a simple + // block of field assignments. return i; + fields_seen.insert(lhs_field); auto lhs_rec = lhs->GetOp1(); if ( lhs_rec->Tag() != EXPR_NAME ) - // Not a simple field reference. + // Not a simple field reference, e.g. "x$y$a". return i; auto lhs_rec_n = lhs_rec->AsNameExpr(); if ( targ_rec ) { if ( lhs_rec_n->Id() != targ_rec->Id() ) + // It's no longer "x$..." but some new variable "z$...". return i; } else @@ -805,21 +813,24 @@ static unsigned int FindAssignmentChain(const std::vector& stmts, unsig return i; } -// Maps RHS identifiers to their collection of operations, expressed -// as the underlying statement. -using OpChain = std::map>; - -static void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpChain& add_chains) { +void StmtList::UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpChain& add_chains) const { auto se = s->AsExprStmt()->StmtExpr(); ASSERT(se->Tag() == EXPR_ASSIGN); + // We dig three times into the LHS. The first gets the EXPR_ASSIGN's + // first operand, which is a RefExpr; the second gets its operand, + // which we've guaranteed in FindRecAssignmentChain is a FieldExpr; + // and the third is the FieldExpr's operand, which we've guaranteed + // is a NameExpr. auto lhs_id = se->GetOp1()->GetOp1()->GetOp1()->AsNameExpr()->Id(); auto rhs = se->GetOp2(); const FieldExpr* f; OpChain* c; + // Check whether RHS is either "y$b" or "x$a + y$b". + if ( rhs->Tag() == EXPR_ADD ) { - auto rhs_op1 = rhs->GetOp1(); + auto rhs_op1 = rhs->GetOp1(); // need to see that it's "x$a" if ( rhs_op1->Tag() != EXPR_FIELD ) return; @@ -828,7 +839,7 @@ static void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpC if ( rhs_op1_rec->Tag() != EXPR_NAME || rhs_op1_rec->AsNameExpr()->Id() != lhs_id ) return; - auto rhs_op2 = rhs->GetOp2(); + auto rhs_op2 = rhs->GetOp2(); // need to see that it's "y$b" if ( rhs_op2->Tag() != EXPR_FIELD ) return; @@ -842,12 +853,15 @@ static void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpC } else + // Not a RHS we know how to leverage. return; auto f_rec = f->GetOp1(); if ( f_rec->Tag() != EXPR_NAME ) + // Not a simple RHS, instead something like "y$z$b". return; + // If we get here, it's a keeper, record the associated statement. auto id = f_rec->AsNameExpr()->Id(); auto cf = c->find(id); if ( cf == c->end() ) @@ -856,12 +870,17 @@ static void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpC cf->second.push_back(s.get()); } -static StmtPtr TransformChain(const OpChain& c, ExprTag t, std::set& chain_stmts, StmtPtr s0) { - auto sl = with_location_of(make_intrusive(), s0); - auto& stmts = sl->Stmts(); +StmtPtr StmtList::TransformChain(const OpChain& c, ExprTag t, std::set& chain_stmts) const { + IntrusivePtr sl; for ( auto& id_stmts : c ) { auto orig_s = id_stmts.second; + + if ( ! sl ) + // Now that we have a statement, create our list and associate + // its location with the statement. + sl = with_location_of(make_intrusive(), orig_s[0]); + ExprPtr e; if ( t == EXPR_ASSIGN ) e = make_intrusive(orig_s, chain_stmts); @@ -870,14 +889,13 @@ static StmtPtr TransformChain(const OpChain& c, ExprTag t, std::set e->SetLocationInfo(sl->GetLocationInfo()); auto es = with_location_of(make_intrusive(std::move(e)), sl); - stmts.emplace_back(std::move(es)); + sl->Stmts().emplace_back(std::move(es)); } return sl; } -static bool SimplifyChain(const std::vector& stmts, unsigned int start, unsigned int end, - std::vector& f_stmts) { +bool StmtList::SimplifyChain(unsigned int start, unsigned int end, std::vector& f_stmts) const { OpChain assign_chains; OpChain add_chains; std::set chain_stmts; @@ -888,8 +906,9 @@ static bool SimplifyChain(const std::vector& stmts, unsigned int start, UpdateAssignmentChains(s, assign_chains, add_chains); } - // An add-chain of any size is a win. For an assign-chain to be - // a win, it needs to have at least two elements. + // An add-chain of any size is a win. For an assign-chain to be a win, + // it needs to have at least two elements, because a single "x$a = y$b" + // can be expressed using one ZAM instructino (but "x$a += y$b" cannot). if ( add_chains.empty() ) { bool have_useful_assign_chain = false; for ( auto& ac : assign_chains ) @@ -899,12 +918,14 @@ static bool SimplifyChain(const std::vector& stmts, unsigned int start, } if ( ! have_useful_assign_chain ) + // No gains available. return false; } - f_stmts.push_back(TransformChain(assign_chains, EXPR_ASSIGN, chain_stmts, stmts[start])); - f_stmts.push_back(TransformChain(add_chains, EXPR_ADD, chain_stmts, stmts[start])); + f_stmts.push_back(TransformChain(assign_chains, EXPR_ASSIGN, chain_stmts)); + f_stmts.push_back(TransformChain(add_chains, EXPR_ADD, chain_stmts)); + // At this point, chain_stmts has only the remainders that weren't removed. for ( auto s : stmts ) if ( chain_stmts.count(s.get()) > 0 ) f_stmts.push_back(s); @@ -917,8 +938,8 @@ bool StmtList::ReduceStmt(unsigned int& s_i, std::vector& f_stmts, Redu auto& stmt_i = stmts[s_i]; auto old_stmt = stmt_i; - auto chain_end = FindAssignmentChain(stmts, s_i); - if ( chain_end > s_i && SimplifyChain(stmts, s_i, chain_end - 1, f_stmts) ) + auto chain_end = FindRecAssignmentChain(s_i); + if ( chain_end > s_i && SimplifyChain(s_i, chain_end - 1, f_stmts) ) return true; auto stmt = stmt_i->Reduce(c);