diff --git a/src/script_opt/Stmt.cc b/src/script_opt/Stmt.cc index aefbac637b..6747ceb8c3 100644 --- a/src/script_opt/Stmt.cc +++ b/src/script_opt/Stmt.cc @@ -765,7 +765,7 @@ unsigned int StmtList::FindRecAssignmentChain(unsigned int i) const { std::set fields_seen; for ( ; i < stmts.size(); ++i ) { - auto& s = stmts[i]; + const auto& s = stmts[i]; // We're looking for either "x$a = y$b" or "x$a = x$a + y$b". if ( s->Tag() != STMT_EXPR ) @@ -843,6 +843,10 @@ void StmtList::UpdateAssignmentChains(const StmtPtr& s, OpChain& assign_chains, if ( rhs_op2->Tag() != EXPR_FIELD ) return; + if ( ! IsArithmetic(rhs_op2->GetType()->Tag()) ) + // Avoid esoteric forms of adding. + return; + f = rhs_op2->AsFieldExpr(); c = &add_chains; } @@ -922,8 +926,15 @@ bool StmtList::SimplifyChain(unsigned int start, unsigned int end, std::vector& f_stmts, Redu auto old_stmt = stmt_i; auto chain_end = FindRecAssignmentChain(s_i); - if ( chain_end > s_i && SimplifyChain(s_i, chain_end - 1, f_stmts) ) + if ( chain_end > s_i && SimplifyChain(s_i, chain_end - 1, f_stmts) ) { + s_i = chain_end - 1; return true; + } auto stmt = stmt_i->Reduce(c); diff --git a/src/script_opt/ZAM/Compile.h b/src/script_opt/ZAM/Compile.h index 2ee365e64c..f62eace1a4 100644 --- a/src/script_opt/ZAM/Compile.h +++ b/src/script_opt/ZAM/Compile.h @@ -189,6 +189,7 @@ private: const ZAMStmt CompileAddToExpr(const AddToExpr* e); const ZAMStmt CompileRemoveFromExpr(const RemoveFromExpr* e); const ZAMStmt CompileAssignExpr(const AssignExpr* e); + const ZAMStmt CompileRecFieldUpdates(const RecordFieldUpdates* e); const ZAMStmt CompileZAMBuiltin(const NameExpr* lhs, const ScriptOptBuiltinExpr* zbi); const ZAMStmt CompileAssignToIndex(const NameExpr* lhs, const IndexExpr* rhs); const ZAMStmt CompileFieldLHSAssignExpr(const FieldLHSAssignExpr* e); diff --git a/src/script_opt/ZAM/Expr.cc b/src/script_opt/ZAM/Expr.cc index facc6952ff..aa01b4728b 100644 --- a/src/script_opt/ZAM/Expr.cc +++ b/src/script_opt/ZAM/Expr.cc @@ -22,6 +22,9 @@ const ZAMStmt ZAMCompiler::CompileExpr(const Expr* e) { case EXPR_ASSIGN: return CompileAssignExpr(static_cast(e)); + case EXPR_REC_ASSIGN_FIELDS: + case EXPR_REC_ADD_FIELDS: return CompileRecFieldUpdates(static_cast(e)); + case EXPR_INDEX_ASSIGN: { auto iae = static_cast(e); auto t = iae->GetOp1()->GetType()->Tag(); @@ -227,6 +230,64 @@ const ZAMStmt ZAMCompiler::CompileAssignExpr(const AssignExpr* e) { #include "ZAM-GenExprsDefsV.h" } +const ZAMStmt ZAMCompiler::CompileRecFieldUpdates(const RecordFieldUpdates* e) { + auto lhs = e->GetOp1()->AsNameExpr(); + auto rhs = e->GetOp2()->AsNameExpr(); + + auto& rhs_map = e->RHSMap(); + + auto aux = new ZInstAux(0); + aux->map = e->LHSMap(); + aux->rhs_map = e->RHSMap(); + + std::set field_tags; + + bool is_managed = false; + + for ( auto i : rhs_map ) { + auto rt = rhs->GetType()->AsRecordType(); + auto rt_ft_i = rt->GetFieldType(i); + field_tags.insert(rt_ft_i->Tag()); + + if ( ZVal::IsManagedType(rt_ft_i) ) { + aux->is_managed.push_back(true); + is_managed = true; + } + else + // This will only be needed if is_managed winds up being true, + // but it's harmless to build it up in any case. + aux->is_managed.push_back(false); + + // The following is only needed for non-homogeneous "add"s, but + // likewise it's harmless to build it anyway. + aux->types.push_back(rt_ft_i); + } + + bool homogeneous = field_tags.size() == 1; + if ( ! homogeneous && field_tags.size() == 2 && field_tags.count(TYPE_INT) > 0 && field_tags.count(TYPE_COUNT) > 0 ) + homogeneous = true; + + ZOp op; + + if ( e->Tag() == EXPR_REC_ASSIGN_FIELDS ) + op = is_managed ? OP_REC_ASSIGN_FIELDS_MANAGED_VV : OP_REC_ASSIGN_FIELDS_VV; + + else if ( homogeneous ) { + if ( field_tags.count(TYPE_DOUBLE) > 0 ) + op = OP_REC_ADD_DOUBLE_FIELDS_VV; + else + op = OP_REC_ADD_INT_FIELDS_VV; + } + + else + op = OP_REC_ADD_FIELDS_VV; + + auto z = GenInst(op, lhs, rhs); + z.aux = aux; + + return AddInst(z); +} + const ZAMStmt ZAMCompiler::CompileZAMBuiltin(const NameExpr* lhs, const ScriptOptBuiltinExpr* zbi) { auto op1 = zbi->GetOp1(); auto op2 = zbi->GetOp2(); diff --git a/src/script_opt/ZAM/IterInfo.h b/src/script_opt/ZAM/IterInfo.h index 5a66680cb3..be81e00e4b 100644 --- a/src/script_opt/ZAM/IterInfo.h +++ b/src/script_opt/ZAM/IterInfo.h @@ -51,9 +51,9 @@ public: if ( lv < 0 ) continue; auto& var = frame[lv]; - if ( aux->lvt_is_managed[i] ) + if ( aux->is_managed[i] ) ZVal::DeleteManagedType(var); - auto& t = aux->loop_var_types[i]; + auto& t = aux->types[i]; var = ZVal(ind_lv_p, t); } diff --git a/src/script_opt/ZAM/Ops.in b/src/script_opt/ZAM/Ops.in index 387bd6b382..05d3213b32 100644 --- a/src/script_opt/ZAM/Ops.in +++ b/src/script_opt/ZAM/Ops.in @@ -1268,6 +1268,71 @@ eval auto init_vals = z.aux->ToZValVecWithMap(frame); init_vals[z.v2] = ZVal(run_state::network_time); ConstructRecordPost() +macro SetUpRecFieldOps() + auto lhs = frame[z.v1].record_val; + auto rhs = frame[z.v2].record_val; + auto aux = z.aux; + auto& lhs_map = aux->map; + auto& rhs_map = aux->rhs_map; + auto n = rhs_map.size(); + +op Rec-Assign-Fields +op1-read +type VV +eval SetUpRecFieldOps() + for ( size_t i = 0U; i < n; ++i ) + lhs->RawOptField(lhs_map[i]) = rhs->RawField(rhs_map[i]); + +op Rec-Assign-Fields-Managed +op1-read +type VV +eval SetUpRecFieldOps() + auto is_managed = aux->is_managed; + for ( auto i = 0; i < n; ++i ) + if ( is_managed[i] ) + { + auto& lhs_i = lhs->RawOptField(lhs_map[i]); + auto rhs_i = rhs->RawField(rhs_map[i]); + zeek::Ref(rhs_i.ManagedVal()); + if ( lhs_i ) + ZVal::DeleteManagedType(*lhs_i); + lhs_i = rhs_i; + } + else + lhs->RawOptField(lhs_map[i]) = rhs->RawField(rhs_map[i]); + +op Rec-Add-Int-Fields +op1-read +type VV +eval SetUpRecFieldOps() + for ( size_t i = 0U; i < n; ++i ) + lhs->RawField(lhs_map[i]).int_val += rhs->RawField(rhs_map[i]).int_val; + +op Rec-Add-Double-Fields +op1-read +type VV +eval SetUpRecFieldOps() + for ( size_t i = 0U; i < n; ++i ) + lhs->RawField(lhs_map[i]).double_val += rhs->RawField(rhs_map[i]).double_val; + +op Rec-Add-Fields +op1-read +type VV +eval SetUpRecFieldOps() + auto& types = aux->types; + for ( size_t i = 0U; i < n; ++i ) + { + auto& lhs_i = lhs->RawField(lhs_map[i]); + auto rhs_i = rhs->RawField(rhs_map[i]); + auto tag = types[i]->Tag(); + if ( tag == TYPE_INT ) + lhs_i.int_val += rhs_i.int_val; + else if ( tag == TYPE_COUNT ) + lhs_i.uint_val += rhs_i.uint_val; + else + lhs_i.double_val += rhs_i.double_val; + } + # Special instruction for concretizing vectors that are fields in a # newly-constructed record. "aux" holds which fields in the record to # inspect. diff --git a/src/script_opt/ZAM/Stmt.cc b/src/script_opt/ZAM/Stmt.cc index 425062157d..04dd18f41c 100644 --- a/src/script_opt/ZAM/Stmt.cc +++ b/src/script_opt/ZAM/Stmt.cc @@ -765,8 +765,8 @@ const ZAMStmt ZAMCompiler::LoopOverTable(const ForStmt* f, const NameExpr* val) int slot = id->IsBlank() ? -1 : FrameSlot(id); aux->loop_vars.push_back(slot); auto& t = id->GetType(); - aux->loop_var_types.push_back(t); - aux->lvt_is_managed.push_back(ZVal::IsManagedType(t)); + aux->types.push_back(t); + aux->is_managed.push_back(ZVal::IsManagedType(t)); } bool no_loop_vars = (num_unused == loop_vars->length()); diff --git a/src/script_opt/ZAM/ZInst.h b/src/script_opt/ZAM/ZInst.h index 8bde9bc571..69c2f6452f 100644 --- a/src/script_opt/ZAM/ZInst.h +++ b/src/script_opt/ZAM/ZInst.h @@ -484,20 +484,27 @@ public: // store here. bool can_change_non_locals = false; - // The following is used for constructing records, to map elements in - // slots/constants/types to record field offsets. + // The following is used for constructing records or in record chain + // operations, to map elements in slots/constants/types to record field + // offsets. std::vector map; + // The following is used when we need two maps, a LHS one (done with + // the above) and a RHS one. + std::vector rhs_map; + + // For operations that need to track types corresponding to other vectors. + std::vector types; + + // For operations that mix managed and unmanaged assignments. + std::vector is_managed; + ///// The following four apply to looping over the elements of tables. // Frame slots of iteration variables, such as "[v1, v2, v3] in aggr". // A negative value means "skip assignment". std::vector loop_vars; - // Their types and whether they're managed. - std::vector loop_var_types; - std::vector lvt_is_managed; - // Type associated with the "value" entry, for "k, value in aggr" // iteration. TypePtr value_var_type;