From 8fcf2f5d0e768899dcff16e8b70defe74b50afd3 Mon Sep 17 00:00:00 2001 From: Vern Paxson Date: Thu, 9 May 2024 13:45:04 -0700 Subject: [PATCH] initial framework in place to find chains --- src/script_opt/Expr.cc | 22 ++++++ src/script_opt/Reduce.cc | 24 ++++++- src/script_opt/Stmt.cc | 150 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 192 insertions(+), 4 deletions(-) diff --git a/src/script_opt/Expr.cc b/src/script_opt/Expr.cc index 6988f06354..950457ade1 100644 --- a/src/script_opt/Expr.cc +++ b/src/script_opt/Expr.cc @@ -1801,6 +1801,28 @@ ExprPtr RecordConstructorExpr::Reduce(Reducer* c, StmtPtr& red_stmt) { } StmtPtr RecordConstructorExpr::ReduceToSingletons(Reducer* c) { + auto& exprs2 = op->AsListExpr()->Exprs(); + int nfield = 0; + std::set field_names; + loop_over_list(exprs2, j) { + auto e_i = exprs2[j]; + auto fa_i = e_i->AsFieldAssignExprPtr(); + auto fa_i_rhs = e_i->GetOp1(); + + if ( fa_i_rhs->Tag() == EXPR_FIELD ) { + ++nfield; + auto op1 = fa_i_rhs->GetOp1(); + if ( op1->Tag() == EXPR_NAME ) + field_names.insert(op1->AsNameExpr()->Id()); + } + } + +#if 0 + if ( nfield > 0 ) + printf("constructor with %d fields spanning %lu names: %s\n", + nfield, field_names.size(), obj_desc(this).c_str()); +#endif + StmtPtr red_stmt; auto& exprs = op->AsListExpr()->Exprs(); diff --git a/src/script_opt/Reduce.cc b/src/script_opt/Reduce.cc index 14a9198f26..8afea01029 100644 --- a/src/script_opt/Reduce.cc +++ b/src/script_opt/Reduce.cc @@ -57,7 +57,7 @@ static bool same_op(const Expr* op1, const Expr* op2, bool check_defs) { return def_1 == def_2 && def_1 != NO_DEF; } - else if ( op1->Tag() == EXPR_CONST ) { + if ( op1->Tag() == EXPR_CONST ) { auto op1_c = op1->AsConstExpr(); auto op2_c = op2->AsConstExpr(); @@ -67,7 +67,7 @@ static bool same_op(const Expr* op1, const Expr* op2, bool check_defs) { return same_val(op1_v, op2_v); } - else if ( op1->Tag() == EXPR_LIST ) { + if ( op1->Tag() == EXPR_LIST ) { auto op1_l = op1->AsListExpr()->Exprs(); auto op2_l = 
op2->AsListExpr()->Exprs(); @@ -81,7 +81,25 @@ static bool same_op(const Expr* op1, const Expr* op2, bool check_defs) { return true; } - reporter->InternalError("bad singleton tag"); + // We only get here if dealing with non-reduced operands. + { + auto subop1_1 = op1->GetOp1(); + auto subop1_2 = op2->GetOp1(); + ASSERT(subop1_1 && subop1_2); + + if ( ! same_expr(subop1_1, subop1_2) ) + return false; + + auto subop2_1 = op1->GetOp2(); + auto subop2_2 = op2->GetOp2(); + if ( subop2_1 && ! same_expr(subop2_1, subop2_2) ) + return false; + + auto subop3_1 = op1->GetOp3(); + auto subop3_2 = op2->GetOp3(); + return ! subop3_1 || same_expr(subop3_1, subop3_2); + } + return false; } diff --git a/src/script_opt/Stmt.cc b/src/script_opt/Stmt.cc index e95395e7d4..101ab7ab62 100644 --- a/src/script_opt/Stmt.cc +++ b/src/script_opt/Stmt.cc @@ -143,8 +143,13 @@ StmtPtr ExprStmt::DoReduce(Reducer* c) { // it has a non-void type it'll generate an // assignment to a temporary. red_e_stmt = e->ReduceToSingletons(c); - else + else { e = e->Reduce(c, red_e_stmt); + // It's possible that 'e' has gone away because it was a call + // to an inlined function that doesn't have a return value. + if ( ! e ) + return red_e_stmt; + } if ( red_e_stmt ) { auto s = make_intrusive(red_e_stmt, ThisPtr()); @@ -755,11 +760,154 @@ StmtPtr StmtList::DoReduce(Reducer* c) { return ThisPtr(); } +// Returns just past the end of the chain. 
+//
+// A chain is a run of consecutive statements, beginning at stmts[i], in
+// which every statement is an expression-statement of the form
+// "x$field = <rhs>", all assigning to fields of the same record variable
+// "x", and with no field assigned more than once. The scan stops at the
+// first statement that breaks any of these conditions; if stmts[i] itself
+// does not qualify, the return value is 'i' (an empty chain).
+static unsigned int FindAssignmentChain(const std::vector<StmtPtr>& stmts, unsigned int i) {
+    // Record variable that the chain assigns into; fixed by the chain's
+    // first statement.
+    const NameExpr* targ_rec = nullptr;
+
+    // Fields of the target record already assigned within this chain.
+    std::set<int> fields_seen;
+
+    for ( ; i < stmts.size(); ++i ) {
+        const auto& s = stmts[i];
+
+        if ( s->Tag() != STMT_EXPR )
+            return i;
+
+        auto se = s->AsExprStmt()->StmtExpr();
+        if ( se->Tag() != EXPR_ASSIGN )
+            return i;
+
+        // The code relies on assignment targets being wrapped in a
+        // reference expression; peel it off to get the actual target.
+        auto lhs_ref = se->GetOp1();
+        ASSERT(lhs_ref->Tag() == EXPR_REF);
+
+        auto lhs = lhs_ref->GetOp1();
+        if ( lhs->Tag() != EXPR_FIELD )
+            return i;
+
+        // A second assignment to the same field ends the chain. insert()
+        // reports whether the field was new, so one lookup suffices.
+        if ( ! fields_seen.insert(lhs->AsFieldExpr()->Field()).second )
+            return i;
+
+        auto lhs_rec = lhs->GetOp1();
+        if ( lhs_rec->Tag() != EXPR_NAME )
+            // Not a simple field reference.
+            return i;
+
+        auto lhs_rec_n = lhs_rec->AsNameExpr();
+
+        if ( ! targ_rec )
+            targ_rec = lhs_rec_n;
+        else if ( lhs_rec_n->Id() != targ_rec->Id() )
+            // Assignment into a different record variable.
+            return i;
+    }
+
+    return i;
+}
+
+// Maps RHS identifiers to their collection of operations, expressed
+// as the underlying statement.
+using OpChain = std::map<const ID*, std::vector<const Stmt*>>;
+
+// Classifies one statement of an assignment chain by the shape of its RHS
+// and records it in the matching collection, keyed by the identifier of
+// the RHS record:
+//
+//   x$fn = y$fm          -> assign_chains[y]
+//   x$fn = x$fn + y$fm   -> add_chains[y]
+//
+// Statements with any other RHS are not recorded. 's' must already have
+// been vetted by FindAssignmentChain(), since the LHS is unpacked here
+// without further tag checks.
+static void UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, OpChain& add_chains) {
+    auto se = s->AsExprStmt()->StmtExpr();
+    ASSERT(se->Tag() == EXPR_ASSIGN);
+
+    // Assignment -> reference wrapper -> field expression "x$fn".
+    auto lhs_fe = se->GetOp1()->GetOp1()->AsFieldExpr();
+    auto lhs_id = lhs_fe->GetOp1()->AsNameExpr()->Id();
+    auto rhs = se->GetOp2();
+
+    const FieldExpr* f = nullptr; // the "y$fm" operand
+    OpChain* c = nullptr;         // collection that 's' belongs to
+
+    if ( rhs->Tag() == EXPR_ADD ) {
+        // The first operand has to be the assignment's own target, "x$fn".
+        auto rhs_op1 = rhs->GetOp1();
+        if ( rhs_op1->Tag() != EXPR_FIELD )
+            return;
+
+        auto rhs_op1_fe = rhs_op1->AsFieldExpr();
+        auto rhs_op1_rec = rhs_op1_fe->GetOp1();
+        if ( rhs_op1_rec->Tag() != EXPR_NAME || rhs_op1_rec->AsNameExpr()->Id() != lhs_id )
+            return;
+
+        // Same record is not enough: "x$a = x$b + y$c" is not an in-place
+        // add, so the field has to match the LHS field as well.
+        if ( rhs_op1_fe->Field() != lhs_fe->Field() )
+            return;
+
+        auto rhs_op2 = rhs->GetOp2();
+        if ( rhs_op2->Tag() != EXPR_FIELD )
+            return;
+
+        f = rhs_op2->AsFieldExpr();
+        c = &add_chains;
+    }
+
+    else if ( rhs->Tag() == EXPR_FIELD ) {
+        f = rhs->AsFieldExpr();
+        c = &assign_chains;
+    }
+
+    else
+        return;
+
+    auto f_rec = f->GetOp1();
+    if ( f_rec->Tag() != EXPR_NAME )
+        return;
+
+    // operator[] creates the (empty) entry on first use.
+    (*c)[f_rec->AsNameExpr()->Id()].push_back(s.get());
+}
+
+// Removes from 'chain_stmts' every statement recorded in 'c'. The
+// operation tag 't' identifies the kind of chain but is not used yet.
+static void TransformChain(const OpChain& c, ExprTag t, std::set<const Stmt*>& chain_stmts) {
+    for ( const auto& id_stmts : c )
+        for ( auto i_s : id_stmts.second ) {
+            ASSERT(chain_stmts.count(i_s) > 0);
+            chain_stmts.erase(i_s);
+        }
+}
+
+// Analyzes the assignment chain stmts[start..end] (both bounds inclusive)
+// and reports how many of its statements would remain after folding the
+// recognized assign-/add-chains.
+//
+// The return value is meant to indicate whether the chain was rewritten.
+// No rewriting is performed yet, so this always returns false and
+// 'f_stmts' is left untouched.
+static bool SimplifyChain(const std::vector<StmtPtr>& stmts, unsigned int start, unsigned int end,
+                          std::vector<StmtPtr>& f_stmts) {
+    OpChain assign_chains;
+    OpChain add_chains;
+    std::set<const Stmt*> chain_stmts;
+
+    for ( unsigned int i = start; i <= end; ++i ) {
+        const auto& s = stmts[i];
+        chain_stmts.insert(s.get());
+        UpdateAssignmentChains(s, assign_chains, add_chains);
+    }
+
+    // An add-chain of any size is a win. For an assign-chain to be
+    // a win, it needs to have at least two elements.
+    if ( add_chains.empty() ) {
+        bool have_useful_assign_chain = false;
+        for ( const auto& ac : assign_chains )
+            if ( ac.second.size() > 1 ) {
+                have_useful_assign_chain = true;
+                break;
+            }
+
+        if ( ! have_useful_assign_chain )
+            return false;
+    }
+
+    // After these calls 'chain_stmts' holds only the statements that
+    // belong to no recognized chain.
+    TransformChain(assign_chains, EXPR_ASSIGN, chain_stmts);
+    TransformChain(add_chains, EXPR_ADD, chain_stmts);
+
+    printf("chain reduction %u -> %zu starting at %s\n", end - start + 1, chain_stmts.size(),
+           obj_desc(stmts[start].get()).c_str());
+
+    return false;
+}
+
 bool StmtList::ReduceStmt(unsigned int& s_i, std::vector<StmtPtr>& f_stmts, Reducer* c) {
     bool did_change = false;
     auto& stmt_i = stmts[s_i];
     auto old_stmt = stmt_i;
 
+    auto chain_end = FindAssignmentChain(stmts, s_i);
+    if ( chain_end > s_i && SimplifyChain(stmts, s_i, chain_end - 1, f_stmts) )
+        return true;
+
     auto stmt = stmt_i->Reduce(c);
 
     if ( stmt != old_stmt )