diff --git a/src/Expr.cc b/src/Expr.cc index caa7d0e78a..2dc6656de9 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -25,6 +25,7 @@ #include "zeek/digest.h" #include "zeek/module_util.h" #include "zeek/script_opt/ExprOptInfo.h" +#include "zeek/script_opt/ScriptOpt.h" namespace zeek::detail { diff --git a/src/Func.cc b/src/Func.cc index 008aa29c5e..4ef2990934 100644 --- a/src/Func.cc +++ b/src/Func.cc @@ -138,6 +138,13 @@ void Func::AddBody(detail::StmtPtr new_body, const std::vector& n AddBody(new_body, new_inits, new_frame_size, priority, groups); } +void Func::AddBody(detail::StmtPtr new_body, size_t new_frame_size) + { + std::vector no_inits; + std::set no_groups; + AddBody(new_body, no_inits, new_frame_size, 0, no_groups); + } + void Func::AddBody(detail::StmtPtr /* new_body */, const std::vector& /* new_inits */, size_t /* new_frame_size */, int /* priority */, const std::set& /* groups */) @@ -288,10 +295,12 @@ void Func::CheckPluginResult(bool handled, const ValPtr& hook_result, FunctionFl namespace detail { -ScriptFunc::ScriptFunc(const IDPtr& arg_id) : Func(SCRIPT_FUNC) +ScriptFunc::ScriptFunc(const IDPtr& id) : ScriptFunc::ScriptFunc(id.get()) { } + +ScriptFunc::ScriptFunc(const ID* id) : Func(SCRIPT_FUNC) { - name = arg_id->Name(); - type = arg_id->GetType(); + name = id->Name(); + type = id->GetType(); frame_size = 0; } @@ -658,11 +667,32 @@ void ScriptFunc::ReplaceBody(const StmtPtr& old_body, StmtPtr new_body) bool ScriptFunc::DeserializeCaptures(const broker::vector& data) { - auto result = Frame::Unserialize(data, GetType()->GetCaptures()); + auto result = Frame::Unserialize(data); ASSERT(result.first); - SetCaptures(result.second.release()); + auto& f = result.second; + + if ( bodies[0].stmts->Tag() == STMT_ZAM ) + { + auto& captures = *type->GetCaptures(); + int n = f->FrameSize(); + + ASSERT(captures.size() == n); + + auto cvec = std::make_unique>(); + + for ( int i = 0; i < n; ++i ) + { + auto& f_i = f->GetElement(i); + cvec->push_back(ZVal(f_i, captures[i].Id()->GetType())); + } + + CreateCaptures(std::move(cvec)); + } + + else + SetCaptures(f.release()); return true; } diff --git a/src/Func.h b/src/Func.h index 53f1e6c19a..7d6c6a1212 100644 --- a/src/Func.h +++ b/src/Func.h @@ -172,6 +172,7 @@ class ScriptFunc : public Func { public: ScriptFunc(const IDPtr& id); + ScriptFunc(const ID* id); // For compiled scripts. ScriptFunc(std::string name, FuncTypePtr ft, std::vector bodies, diff --git a/src/script_opt/UseDefs.cc b/src/script_opt/UseDefs.cc index 8e54ab18b7..50996ef9a9 100644 --- a/src/script_opt/UseDefs.cc +++ b/src/script_opt/UseDefs.cc @@ -17,10 +17,11 @@ void UseDefSet::Dump() const printf(" %s", u->Name()); } -UseDefs::UseDefs(StmtPtr _body, std::shared_ptr _rc) +UseDefs::UseDefs(StmtPtr _body, std::shared_ptr _rc, FuncTypePtr _ft) { body = std::move(_body); rc = std::move(_rc); + ft = std::move(_ft); } void UseDefs::Analyze() @@ -164,6 +165,13 @@ bool UseDefs::CheckIfUnused(const Stmt* s, const ID* id, bool report) if ( id->IsGlobal() ) return false; + if ( auto& captures = ft->GetCaptures() ) + { + for ( auto& c : *captures ) + if ( c.Id() == id ) + return false; + } + auto uds = FindSuccUsage(s); if ( ! uds || ! uds->HasID(id) ) { @@ -283,9 +291,7 @@ UDs UseDefs::PropagateUDs(const Stmt* s, UDs succ_UDs, const Stmt* succ_stmt, bo auto true_UDs = PropagateUDs(i->TrueBranch(), succ_UDs, succ_stmt, second_pass); auto false_UDs = PropagateUDs(i->FalseBranch(), succ_UDs, succ_stmt, second_pass); - auto uds = CreateUDs(s, UD_Union(cond_UDs, true_UDs, false_UDs)); - - return uds; + return CreateUDs(s, UD_Union(cond_UDs, true_UDs, false_UDs)); } case STMT_INIT: @@ -450,6 +456,7 @@ UDs UseDefs::ExprUDs(const Expr* e) switch ( e->Tag() ) { case EXPR_NAME: + case EXPR_LAMBDA: AddInExprUDs(uds, e); break; @@ -482,19 +489,23 @@ UDs UseDefs::ExprUDs(const Expr* e) break; } - case EXPR_CONST: - break; - - case EXPR_LAMBDA: + case EXPR_TABLE_CONSTRUCTOR: { - auto l = static_cast(e); - auto ids = l->OuterIDs(); + auto t = static_cast(e); + AddInExprUDs(uds, t->GetOp1().get()); + + auto& t_attrs = t->GetAttrs(); + auto def_attr = t_attrs ? t_attrs->Find(ATTR_DEFAULT) : nullptr; + auto& def_expr = def_attr ? def_attr->GetExpr() : nullptr; + if ( def_expr && def_expr->Tag() == EXPR_LAMBDA ) + uds = ExprUDs(def_expr.get()); - for ( const auto& id : ids ) - AddID(uds, id); break; } + case EXPR_CONST: + break; + case EXPR_CALL: { auto c = e->AsCallExpr(); @@ -577,6 +588,14 @@ void UseDefs::AddInExprUDs(UDs uds, const Expr* e) AddInExprUDs(uds, e->AsFieldExpr()->Op()); break; + case EXPR_LAMBDA: + { + auto outer_ids = e->AsLambdaExpr()->OuterIDs(); + for ( auto& i : outer_ids ) + AddID(uds, i); + break; + } + case EXPR_CONST: // Nothing to do. break; diff --git a/src/script_opt/UseDefs.h b/src/script_opt/UseDefs.h index 03bdcc4735..4b53776005 100644 --- a/src/script_opt/UseDefs.h +++ b/src/script_opt/UseDefs.h @@ -51,7 +51,7 @@ class Reducer; class UseDefs { public: - UseDefs(StmtPtr body, std::shared_ptr rc); + UseDefs(StmtPtr body, std::shared_ptr rc, FuncTypePtr ft); // Does a full pass over the function body's AST. We can wind // up doing this multiple times because when we use use-defs to @@ -173,6 +173,7 @@ private: StmtPtr body; std::shared_ptr rc; + FuncTypePtr ft; }; } // zeek::detail diff --git a/src/script_opt/ZAM/AM-Opt.cc b/src/script_opt/ZAM/AM-Opt.cc index 7735167fc4..145d71ab74 100644 --- a/src/script_opt/ZAM/AM-Opt.cc +++ b/src/script_opt/ZAM/AM-Opt.cc @@ -258,9 +258,9 @@ bool ZAMCompiler::PruneUnused() KillInst(i); } - if ( inst->IsGlobalLoad() ) + if ( inst->IsNonLocalLoad() ) { - // Any straight-line load of the same global + // Any straight-line load of the same global/capture // is redundant. for ( unsigned int j = i + 1; j < insts1.size(); ++j ) { @@ -277,14 +277,14 @@ bool ZAMCompiler::PruneUnused() // Inbound branch ends block. break; - if ( i1->aux && i1->aux->can_change_globals ) + if ( i1->aux && i1->aux->can_change_non_locals ) break; - if ( ! i1->IsGlobalLoad() ) + if ( ! i1->IsNonLocalLoad() ) continue; - if ( i1->v2 == inst->v2 ) - { // Same global + if ( i1->v2 == inst->v2 && i1->IsGlobalLoad() == inst->IsGlobalLoad() ) + { // Same global/capture did_prune = true; KillInst(i1); } @@ -299,9 +299,10 @@ bool ZAMCompiler::PruneUnused() // Variable is used, keep assignment. continue; - if ( frame_denizens[slot]->IsGlobal() ) + auto& id = frame_denizens[slot]; + if ( id->IsGlobal() || IsCapture(id) ) { - // Extend the global's range to the end of the + // Extend the global/capture's range to the end of the // function. denizen_ending[slot] = insts1.back(); continue; @@ -466,18 +467,30 @@ void ZAMCompiler::ComputeFrameLifetimes() break; } + case OP_LAMBDA_VV: + { + auto aux = inst->aux; + int n = aux->n; + auto& slots = aux->slots; + for ( int i = 0; i < n; ++i ) + ExtendLifetime(slots[i], EndOfLoop(inst, 1)); + break; + } + default: // Look for slots in auxiliary information. auto aux = inst->aux; if ( ! aux || ! aux->slots ) break; - for ( auto j = 0; j < aux->n; ++j ) + int n = aux->n; + auto& slots = aux->slots; + for ( auto j = 0; j < n; ++j ) { - if ( aux->slots[j] < 0 ) + if ( slots[j] < 0 ) continue; - ExtendLifetime(aux->slots[j], EndOfLoop(inst, 1)); + ExtendLifetime(slots[j], EndOfLoop(inst, 1)); } break; } @@ -759,7 +772,6 @@ void ZAMCompiler::ReMapVar(const ID* id, int slot, zeek_uint_t inst) void ZAMCompiler::CheckSlotAssignment(int slot, const ZInstI* inst) { ASSERT(slot >= 0 && static_cast(slot) < frame_denizens.size()); - // We construct temporaries such that their values are never used // earlier than their definitions in loop bodies. For other // denizens, however, they can be, so in those cases we expand the diff --git a/src/script_opt/ZAM/Compile.h b/src/script_opt/ZAM/Compile.h index 1287dcdc6a..e6bce25b9e 100644 --- a/src/script_opt/ZAM/Compile.h +++ b/src/script_opt/ZAM/Compile.h @@ -92,6 +92,7 @@ private: void Init(); void InitGlobals(); void InitArgs(); + void InitCaptures(); void InitLocals(); void TrackMemoryManagement(); @@ -350,8 +351,15 @@ private: bool IsUnused(const IDPtr& id, const Stmt* where) const; + bool IsCapture(const IDPtr& id) const { return IsCapture(id.get()); } + bool IsCapture(const ID* id) const; + + int CaptureOffset(const IDPtr& id) const { return IsCapture(id.get()); } + int CaptureOffset(const ID* id) const; + void LoadParam(const ID* id); const ZAMStmt LoadGlobal(const ID* id); + const ZAMStmt LoadCapture(const ID* id); int AddToFrame(const ID*); @@ -599,8 +607,10 @@ private: // Used for communication between Frame1Slot and a subsequent // AddInst. If >= 0, then upon adding the next instruction, - // it should be followed by Store-Global for the given slot. + // it should be followed by Store-Global or Store-Capture for + // the given slot. int pending_global_store = -1; + int pending_capture_store = -1; }; // Invokes after compiling all of the function bodies. diff --git a/src/script_opt/ZAM/Driver.cc b/src/script_opt/ZAM/Driver.cc index 18426909da..c89abf0bfd 100644 --- a/src/script_opt/ZAM/Driver.cc +++ b/src/script_opt/ZAM/Driver.cc @@ -92,12 +92,25 @@ void ZAMCompiler::InitArgs() pop_scope(); } +void ZAMCompiler::InitCaptures() + { + for ( auto c : pf->Captures() ) + (void)AddToFrame(c); + } + void ZAMCompiler::InitLocals() { // Assign slots for locals (which includes temporaries). for ( auto l : pf->Locals() ) { + if ( IsCapture(l) ) + continue; + + if ( pf->WhenLocals().count(l) > 0 ) + continue; + auto non_const_l = const_cast(l); + // Don't add locals that were already added because they're // parameters. // diff --git a/src/script_opt/ZAM/Expr.cc b/src/script_opt/ZAM/Expr.cc index cbf29900fa..ebcd3bda6b 100644 --- a/src/script_opt/ZAM/Expr.cc +++ b/src/script_opt/ZAM/Expr.cc @@ -4,6 +4,7 @@ #include "zeek/Desc.h" #include "zeek/Reporter.h" +#include "zeek/script_opt/ProfileFunc.h" #include "zeek/script_opt/ZAM/Compile.h" namespace zeek::detail @@ -176,12 +177,6 @@ const ZAMStmt ZAMCompiler::CompileAssignExpr(const AssignExpr* e) auto r2 = rhs->GetOp2(); auto r3 = rhs->GetOp3(); - if ( rhs->Tag() == EXPR_LAMBDA ) - { - // reporter->Error("lambda expressions not supported for compiling"); - return ErrorStmt(); - } - if ( rhs->Tag() == EXPR_NAME ) return AssignVV(lhs, rhs->AsNameExpr()); @@ -213,6 +208,9 @@ const ZAMStmt ZAMCompiler::CompileAssignExpr(const AssignExpr* e) if ( rhs->Tag() == EXPR_ANY_INDEX ) return AnyIndexVVi(lhs, r1->AsNameExpr(), rhs->AsAnyIndexExpr()->Index()); + if ( rhs->Tag() == EXPR_LAMBDA ) + return BuildLambda(lhs, rhs->AsLambdaExpr()); + if ( rhs->Tag() == EXPR_COND && r1->GetType()->Tag() == TYPE_VECTOR ) return Bool_Vec_CondVVVV(lhs, r1->AsNameExpr(), r2->AsNameExpr(), r3->AsNameExpr()); @@ -747,6 +745,38 @@ const ZAMStmt ZAMCompiler::CompileIndex(const NameExpr* n1, int n2_slot, const T return AddInst(z); } +const ZAMStmt ZAMCompiler::BuildLambda(const NameExpr* n, LambdaExpr* le) + { + return BuildLambda(Frame1Slot(n, OP1_WRITE), le); + } + +const ZAMStmt ZAMCompiler::BuildLambda(int n_slot, LambdaExpr* le) + { + auto& captures = le->GetCaptures(); + int ncaptures = captures ? captures->size() : 0; + + auto aux = new ZInstAux(ncaptures); + aux->master_func = le->MasterFunc(); + aux->lambda_name = le->Name(); + aux->id_val = le->Ingredients()->GetID().get(); + + for ( int i = 0; i < ncaptures; ++i ) + { + auto& id_i = (*captures)[i].Id(); + + if ( pf->WhenLocals().count(id_i.get()) > 0 ) + aux->Add(i, nullptr); + else + aux->Add(i, FrameSlot(id_i), id_i->GetType()); + } + + auto z = ZInstI(OP_LAMBDA_VV, n_slot, le->MasterFunc()->FrameSize()); + z.op_type = OP_VV_I2; + z.aux = aux; + + return AddInst(z); + } + const ZAMStmt ZAMCompiler::AssignVecElems(const Expr* e) { auto index_assign = e->AsIndexAssignExpr(); @@ -1062,6 +1092,31 @@ const ZAMStmt ZAMCompiler::ConstructTable(const NameExpr* n, const Expr* e) z.t = tt; z.attrs = e->AsTableConstructorExpr()->GetAttrs(); + auto zstmt = AddInst(z); + + auto def_attr = z.attrs ? z.attrs->Find(ATTR_DEFAULT) : nullptr; + if ( ! def_attr || def_attr->GetExpr()->Tag() != EXPR_LAMBDA ) + return zstmt; + + auto def_lambda = def_attr->GetExpr()->AsLambdaExpr(); + auto dl_t = def_lambda->GetType()->AsFuncType(); + auto& captures = dl_t->GetCaptures(); + + if ( ! captures ) + return zstmt; + + // What a pain. The table's default value is a lambda that has + // captures. The semantics of this are that the captures are + // evaluated at table-construction time. We need to build the + // lambda and assign it as the table's default. + + auto slot = NewSlot(true); // since func_val's are managed + (void)BuildLambda(slot, def_lambda); + + z = GenInst(OP_SET_TABLE_DEFAULT_LAMBDA_VV, n, slot); + z.op_type = OP_VV; + z.t = def_lambda->GetType(); + return AddInst(z); } diff --git a/src/script_opt/ZAM/Low-Level.cc b/src/script_opt/ZAM/Low-Level.cc index 40fd4c30df..0db09a0147 100644 --- a/src/script_opt/ZAM/Low-Level.cc +++ b/src/script_opt/ZAM/Low-Level.cc @@ -149,6 +149,9 @@ const ZAMStmt ZAMCompiler::AddInst(const ZInstI& inst, bool suppress_non_local) if ( suppress_non_local ) return ZAMStmt(top_main_inst); + // Ensure we haven't confused ourselves about any pending stores. + ASSERT(pending_global_store == -1 || pending_capture_store == -1); + if ( pending_global_store >= 0 ) { auto gs = pending_global_store; @@ -161,6 +164,27 @@ const ZAMStmt ZAMCompiler::AddInst(const ZInstI& inst, bool suppress_non_local) return AddInst(store_inst); } + if ( pending_capture_store >= 0 ) + { + auto cs = pending_capture_store; + pending_capture_store = -1; + + auto& cv = *func->GetType()->AsFuncType()->GetCaptures(); + auto& c_id = cv[cs].Id(); + + ZOp op; + + if ( ZVal::IsManagedType(c_id->GetType()) ) + op = OP_STORE_MANAGED_CAPTURE_VV; + else + op = OP_STORE_CAPTURE_VV; + + auto store_inst = ZInstI(op, RawSlot(c_id.get()), cs); + store_inst.op_type = OP_VV_I2; + + return AddInst(store_inst); + } + return ZAMStmt(top_main_inst); } diff --git a/src/script_opt/ZAM/Ops.in b/src/script_opt/ZAM/Ops.in index 7350622cbb..f628e85f05 100644 --- a/src/script_opt/ZAM/Ops.in +++ b/src/script_opt/ZAM/Ops.in @@ -1092,6 +1092,15 @@ eval ConstructTableOrSetPre() } ConstructTableOrSetPost() +# When tables are constructed, if their &default is a lambda with captures +# then we need to explicitly set up the default. +internal-op Set-Table-Default-Lambda +type VV +op1-read +eval auto& tbl = frame[z.v1].table_val; + auto lambda = frame[z.v2].ToVal(z.t); + tbl->InitDefaultVal(lambda); + direct-unary-op Set-Constructor ConstructSet internal-op Construct-Set @@ -2007,12 +2016,41 @@ eval auto& v = frame[z.v1].type_val; auto t = globals[z.v2].id->GetType(); v = new TypeVal(t, true); +internal-op Load-Capture +type VV +eval frame[z.v1] = f->GetFunction()->GetCapturesVec()[z.v2]; + +internal-op Load-Managed-Capture +type VV +eval auto& lhs = frame[z.v1]; + auto& rhs = f->GetFunction()->GetCapturesVec()[z.v2]; + zeek::Ref(rhs.ManagedVal()); + ZVal::DeleteManagedType(lhs); + lhs = rhs; + internal-op Store-Global op1-internal type V eval auto& g = globals[z.v1]; g.id->SetVal(frame[g.slot].ToVal(z.t)); +# Both of these have the LHS as v2 not v1, to keep with existing +# conventions of OP_VV_I2 op type (as opposed to OP_VV_I1_V2, which doesn't +# currently exist, and would be a pain to add). +internal-op Store-Capture +op1-read +type VV +eval f->GetFunction()->GetCapturesVec()[z.v2] = frame[z.v1]; + +internal-op Store-Managed-Capture +op1-read +type VV +eval auto& lhs = f->GetFunction()->GetCapturesVec()[z.v2]; + auto& rhs = frame[z.v1]; + zeek::Ref(rhs.ManagedVal()); + ZVal::DeleteManagedType(lhs); + lhs = rhs; + internal-op Copy-To type VC @@ -2029,6 +2067,37 @@ eval flow = FLOW_BREAK; pc = end_pc; continue; +# Slot 2 gives frame size. +internal-op Lambda +type VV +eval auto& aux = z.aux; + auto& master_func = aux->master_func; + auto& body = master_func->GetBodies()[0].stmts; + ASSERT(body->Tag() == STMT_ZAM); + auto lamb = make_intrusive(aux->id_val); + lamb->AddBody(body, z.v2); + lamb->SetName(aux->lambda_name.c_str()); + if ( aux->n > 0 ) + { + auto captures = std::make_unique>(); + for ( auto i = 0; i < aux->n; ++i ) + { + auto slot = aux->slots[i]; + if ( slot >= 0 ) + { + auto& cp = frame[aux->slots[i]]; + if ( aux->is_managed[i] ) + zeek::Ref(cp.ManagedVal()); + captures->push_back(cp); + } + else + // Used for when-locals. + captures->push_back(ZVal()); + } + lamb->CreateCaptures(std::move(captures)); + } + ZVal::DeleteManagedType(frame[z.v1]); + frame[z.v1].func_val = lamb.release(); ######################################## # Built-in Functions