ZAM support for lambdas

This commit is contained in:
Vern Paxson 2023-06-16 16:08:54 -07:00 committed by Arne Welzel
parent 0a40aec4a6
commit 7d5760ac74
11 changed files with 272 additions and 37 deletions

View file

@ -25,6 +25,7 @@
#include "zeek/digest.h" #include "zeek/digest.h"
#include "zeek/module_util.h" #include "zeek/module_util.h"
#include "zeek/script_opt/ExprOptInfo.h" #include "zeek/script_opt/ExprOptInfo.h"
#include "zeek/script_opt/ScriptOpt.h"
namespace zeek::detail namespace zeek::detail
{ {

View file

@ -138,6 +138,13 @@ void Func::AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& n
AddBody(new_body, new_inits, new_frame_size, priority, groups); AddBody(new_body, new_inits, new_frame_size, priority, groups);
} }
void Func::AddBody(detail::StmtPtr new_body, size_t new_frame_size)
{
	// Convenience overload: forwards to the full AddBody() with no
	// initializations, priority 0, and no event groups.
	std::set<EventGroupPtr> empty_groups;
	std::vector<detail::IDPtr> empty_inits;

	AddBody(new_body, empty_inits, new_frame_size, 0, empty_groups);
}
void Func::AddBody(detail::StmtPtr /* new_body */, void Func::AddBody(detail::StmtPtr /* new_body */,
const std::vector<detail::IDPtr>& /* new_inits */, size_t /* new_frame_size */, const std::vector<detail::IDPtr>& /* new_inits */, size_t /* new_frame_size */,
int /* priority */, const std::set<EventGroupPtr>& /* groups */) int /* priority */, const std::set<EventGroupPtr>& /* groups */)
@ -288,10 +295,12 @@ void Func::CheckPluginResult(bool handled, const ValPtr& hook_result, FunctionFl
namespace detail namespace detail
{ {
ScriptFunc::ScriptFunc(const IDPtr& arg_id) : Func(SCRIPT_FUNC) ScriptFunc::ScriptFunc(const IDPtr& id) : ScriptFunc::ScriptFunc(id.get()) { }
ScriptFunc::ScriptFunc(const ID* id) : Func(SCRIPT_FUNC)
{ {
name = arg_id->Name(); name = id->Name();
type = arg_id->GetType<zeek::FuncType>(); type = id->GetType<zeek::FuncType>();
frame_size = 0; frame_size = 0;
} }
@ -658,11 +667,32 @@ void ScriptFunc::ReplaceBody(const StmtPtr& old_body, StmtPtr new_body)
bool ScriptFunc::DeserializeCaptures(const broker::vector& data) bool ScriptFunc::DeserializeCaptures(const broker::vector& data)
{ {
auto result = Frame::Unserialize(data, GetType()->GetCaptures()); auto result = Frame::Unserialize(data);
ASSERT(result.first); ASSERT(result.first);
SetCaptures(result.second.release()); auto& f = result.second;
if ( bodies[0].stmts->Tag() == STMT_ZAM )
{
auto& captures = *type->GetCaptures();
int n = f->FrameSize();
ASSERT(captures.size() == n);
auto cvec = std::make_unique<std::vector<ZVal>>();
for ( int i = 0; i < n; ++i )
{
auto& f_i = f->GetElement(i);
cvec->push_back(ZVal(f_i, captures[i].Id()->GetType()));
}
CreateCaptures(std::move(cvec));
}
else
SetCaptures(f.release());
return true; return true;
} }

View file

@ -172,6 +172,7 @@ class ScriptFunc : public Func
{ {
public: public:
ScriptFunc(const IDPtr& id); ScriptFunc(const IDPtr& id);
ScriptFunc(const ID* id);
// For compiled scripts. // For compiled scripts.
ScriptFunc(std::string name, FuncTypePtr ft, std::vector<StmtPtr> bodies, ScriptFunc(std::string name, FuncTypePtr ft, std::vector<StmtPtr> bodies,

View file

@ -17,10 +17,11 @@ void UseDefSet::Dump() const
printf(" %s", u->Name()); printf(" %s", u->Name());
} }
UseDefs::UseDefs(StmtPtr _body, std::shared_ptr<Reducer> _rc) UseDefs::UseDefs(StmtPtr _body, std::shared_ptr<Reducer> _rc, FuncTypePtr _ft)
{ {
body = std::move(_body); body = std::move(_body);
rc = std::move(_rc); rc = std::move(_rc);
ft = std::move(_ft);
} }
void UseDefs::Analyze() void UseDefs::Analyze()
@ -164,6 +165,13 @@ bool UseDefs::CheckIfUnused(const Stmt* s, const ID* id, bool report)
if ( id->IsGlobal() ) if ( id->IsGlobal() )
return false; return false;
if ( auto& captures = ft->GetCaptures() )
{
for ( auto& c : *captures )
if ( c.Id() == id )
return false;
}
auto uds = FindSuccUsage(s); auto uds = FindSuccUsage(s);
if ( ! uds || ! uds->HasID(id) ) if ( ! uds || ! uds->HasID(id) )
{ {
@ -283,9 +291,7 @@ UDs UseDefs::PropagateUDs(const Stmt* s, UDs succ_UDs, const Stmt* succ_stmt, bo
auto true_UDs = PropagateUDs(i->TrueBranch(), succ_UDs, succ_stmt, second_pass); auto true_UDs = PropagateUDs(i->TrueBranch(), succ_UDs, succ_stmt, second_pass);
auto false_UDs = PropagateUDs(i->FalseBranch(), succ_UDs, succ_stmt, second_pass); auto false_UDs = PropagateUDs(i->FalseBranch(), succ_UDs, succ_stmt, second_pass);
auto uds = CreateUDs(s, UD_Union(cond_UDs, true_UDs, false_UDs)); return CreateUDs(s, UD_Union(cond_UDs, true_UDs, false_UDs));
return uds;
} }
case STMT_INIT: case STMT_INIT:
@ -450,6 +456,7 @@ UDs UseDefs::ExprUDs(const Expr* e)
switch ( e->Tag() ) switch ( e->Tag() )
{ {
case EXPR_NAME: case EXPR_NAME:
case EXPR_LAMBDA:
AddInExprUDs(uds, e); AddInExprUDs(uds, e);
break; break;
@ -482,19 +489,23 @@ UDs UseDefs::ExprUDs(const Expr* e)
break; break;
} }
case EXPR_CONST: case EXPR_TABLE_CONSTRUCTOR:
break;
case EXPR_LAMBDA:
{ {
auto l = static_cast<const LambdaExpr*>(e); auto t = static_cast<const TableConstructorExpr*>(e);
auto ids = l->OuterIDs(); AddInExprUDs(uds, t->GetOp1().get());
auto& t_attrs = t->GetAttrs();
auto def_attr = t_attrs ? t_attrs->Find(ATTR_DEFAULT) : nullptr;
auto& def_expr = def_attr ? def_attr->GetExpr() : nullptr;
if ( def_expr && def_expr->Tag() == EXPR_LAMBDA )
uds = ExprUDs(def_expr.get());
for ( const auto& id : ids )
AddID(uds, id);
break; break;
} }
case EXPR_CONST:
break;
case EXPR_CALL: case EXPR_CALL:
{ {
auto c = e->AsCallExpr(); auto c = e->AsCallExpr();
@ -577,6 +588,14 @@ void UseDefs::AddInExprUDs(UDs uds, const Expr* e)
AddInExprUDs(uds, e->AsFieldExpr()->Op()); AddInExprUDs(uds, e->AsFieldExpr()->Op());
break; break;
case EXPR_LAMBDA:
{
auto outer_ids = e->AsLambdaExpr()->OuterIDs();
for ( auto& i : outer_ids )
AddID(uds, i);
break;
}
case EXPR_CONST: case EXPR_CONST:
// Nothing to do. // Nothing to do.
break; break;

View file

@ -51,7 +51,7 @@ class Reducer;
class UseDefs class UseDefs
{ {
public: public:
UseDefs(StmtPtr body, std::shared_ptr<Reducer> rc); UseDefs(StmtPtr body, std::shared_ptr<Reducer> rc, FuncTypePtr ft);
// Does a full pass over the function body's AST. We can wind // Does a full pass over the function body's AST. We can wind
// up doing this multiple times because when we use use-defs to // up doing this multiple times because when we use use-defs to
@ -173,6 +173,7 @@ private:
StmtPtr body; StmtPtr body;
std::shared_ptr<Reducer> rc; std::shared_ptr<Reducer> rc;
FuncTypePtr ft;
}; };
} // zeek::detail } // zeek::detail

View file

@ -258,9 +258,9 @@ bool ZAMCompiler::PruneUnused()
KillInst(i); KillInst(i);
} }
if ( inst->IsGlobalLoad() ) if ( inst->IsNonLocalLoad() )
{ {
// Any straight-line load of the same global // Any straight-line load of the same global/capture
// is redundant. // is redundant.
for ( unsigned int j = i + 1; j < insts1.size(); ++j ) for ( unsigned int j = i + 1; j < insts1.size(); ++j )
{ {
@ -277,14 +277,14 @@ bool ZAMCompiler::PruneUnused()
// Inbound branch ends block. // Inbound branch ends block.
break; break;
if ( i1->aux && i1->aux->can_change_globals ) if ( i1->aux && i1->aux->can_change_non_locals )
break; break;
if ( ! i1->IsGlobalLoad() ) if ( ! i1->IsNonLocalLoad() )
continue; continue;
if ( i1->v2 == inst->v2 ) if ( i1->v2 == inst->v2 && i1->IsGlobalLoad() == inst->IsGlobalLoad() )
{ // Same global { // Same global/capture
did_prune = true; did_prune = true;
KillInst(i1); KillInst(i1);
} }
@ -299,9 +299,10 @@ bool ZAMCompiler::PruneUnused()
// Variable is used, keep assignment. // Variable is used, keep assignment.
continue; continue;
if ( frame_denizens[slot]->IsGlobal() ) auto& id = frame_denizens[slot];
if ( id->IsGlobal() || IsCapture(id) )
{ {
// Extend the global's range to the end of the // Extend the global/capture's range to the end of the
// function. // function.
denizen_ending[slot] = insts1.back(); denizen_ending[slot] = insts1.back();
continue; continue;
@ -466,18 +467,30 @@ void ZAMCompiler::ComputeFrameLifetimes()
break; break;
} }
case OP_LAMBDA_VV:
{
auto aux = inst->aux;
int n = aux->n;
auto& slots = aux->slots;
for ( int i = 0; i < n; ++i )
ExtendLifetime(slots[i], EndOfLoop(inst, 1));
break;
}
default: default:
// Look for slots in auxiliary information. // Look for slots in auxiliary information.
auto aux = inst->aux; auto aux = inst->aux;
if ( ! aux || ! aux->slots ) if ( ! aux || ! aux->slots )
break; break;
for ( auto j = 0; j < aux->n; ++j ) int n = aux->n;
auto& slots = aux->slots;
for ( auto j = 0; j < n; ++j )
{ {
if ( aux->slots[j] < 0 ) if ( slots[j] < 0 )
continue; continue;
ExtendLifetime(aux->slots[j], EndOfLoop(inst, 1)); ExtendLifetime(slots[j], EndOfLoop(inst, 1));
} }
break; break;
} }
@ -759,7 +772,6 @@ void ZAMCompiler::ReMapVar(const ID* id, int slot, zeek_uint_t inst)
void ZAMCompiler::CheckSlotAssignment(int slot, const ZInstI* inst) void ZAMCompiler::CheckSlotAssignment(int slot, const ZInstI* inst)
{ {
ASSERT(slot >= 0 && static_cast<zeek_uint_t>(slot) < frame_denizens.size()); ASSERT(slot >= 0 && static_cast<zeek_uint_t>(slot) < frame_denizens.size());
// We construct temporaries such that their values are never used // We construct temporaries such that their values are never used
// earlier than their definitions in loop bodies. For other // earlier than their definitions in loop bodies. For other
// denizens, however, they can be, so in those cases we expand the // denizens, however, they can be, so in those cases we expand the

View file

@ -92,6 +92,7 @@ private:
void Init(); void Init();
void InitGlobals(); void InitGlobals();
void InitArgs(); void InitArgs();
void InitCaptures();
void InitLocals(); void InitLocals();
void TrackMemoryManagement(); void TrackMemoryManagement();
@ -350,8 +351,15 @@ private:
bool IsUnused(const IDPtr& id, const Stmt* where) const; bool IsUnused(const IDPtr& id, const Stmt* where) const;
bool IsCapture(const IDPtr& id) const { return IsCapture(id.get()); }
bool IsCapture(const ID* id) const;
// Convenience overload for IDPtr: forwards to CaptureOffset(const ID*).
// (Forwarding to IsCapture() instead would convert its bool result to
// int, so the "offset" could only ever be 0 or 1.)
int CaptureOffset(const IDPtr& id) const { return CaptureOffset(id.get()); }
int CaptureOffset(const ID* id) const;
void LoadParam(const ID* id); void LoadParam(const ID* id);
const ZAMStmt LoadGlobal(const ID* id); const ZAMStmt LoadGlobal(const ID* id);
const ZAMStmt LoadCapture(const ID* id);
int AddToFrame(const ID*); int AddToFrame(const ID*);
@ -599,8 +607,10 @@ private:
// Used for communication between Frame1Slot and a subsequent // Used for communication between Frame1Slot and a subsequent
// AddInst. If >= 0, then upon adding the next instruction, // AddInst. If >= 0, then upon adding the next instruction,
// it should be followed by Store-Global for the given slot. // it should be followed by Store-Global or Store-Capture for
// the given slot.
int pending_global_store = -1; int pending_global_store = -1;
int pending_capture_store = -1;
}; };
// Invokes after compiling all of the function bodies. // Invokes after compiling all of the function bodies.

View file

@ -92,12 +92,25 @@ void ZAMCompiler::InitArgs()
pop_scope(); pop_scope();
} }
void ZAMCompiler::InitCaptures()
{
	// Give every capture reported by the function's profile a place in
	// the frame. The value AddToFrame() returns isn't needed here.
	for ( auto capture : pf->Captures() )
		{
		static_cast<void>(AddToFrame(capture));
		}
}
void ZAMCompiler::InitLocals() void ZAMCompiler::InitLocals()
{ {
// Assign slots for locals (which includes temporaries). // Assign slots for locals (which includes temporaries).
for ( auto l : pf->Locals() ) for ( auto l : pf->Locals() )
{ {
if ( IsCapture(l) )
continue;
if ( pf->WhenLocals().count(l) > 0 )
continue;
auto non_const_l = const_cast<ID*>(l); auto non_const_l = const_cast<ID*>(l);
// Don't add locals that were already added because they're // Don't add locals that were already added because they're
// parameters. // parameters.
// //

View file

@ -4,6 +4,7 @@
#include "zeek/Desc.h" #include "zeek/Desc.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/script_opt/ProfileFunc.h"
#include "zeek/script_opt/ZAM/Compile.h" #include "zeek/script_opt/ZAM/Compile.h"
namespace zeek::detail namespace zeek::detail
@ -176,12 +177,6 @@ const ZAMStmt ZAMCompiler::CompileAssignExpr(const AssignExpr* e)
auto r2 = rhs->GetOp2(); auto r2 = rhs->GetOp2();
auto r3 = rhs->GetOp3(); auto r3 = rhs->GetOp3();
if ( rhs->Tag() == EXPR_LAMBDA )
{
// reporter->Error("lambda expressions not supported for compiling");
return ErrorStmt();
}
if ( rhs->Tag() == EXPR_NAME ) if ( rhs->Tag() == EXPR_NAME )
return AssignVV(lhs, rhs->AsNameExpr()); return AssignVV(lhs, rhs->AsNameExpr());
@ -213,6 +208,9 @@ const ZAMStmt ZAMCompiler::CompileAssignExpr(const AssignExpr* e)
if ( rhs->Tag() == EXPR_ANY_INDEX ) if ( rhs->Tag() == EXPR_ANY_INDEX )
return AnyIndexVVi(lhs, r1->AsNameExpr(), rhs->AsAnyIndexExpr()->Index()); return AnyIndexVVi(lhs, r1->AsNameExpr(), rhs->AsAnyIndexExpr()->Index());
if ( rhs->Tag() == EXPR_LAMBDA )
return BuildLambda(lhs, rhs->AsLambdaExpr());
if ( rhs->Tag() == EXPR_COND && r1->GetType()->Tag() == TYPE_VECTOR ) if ( rhs->Tag() == EXPR_COND && r1->GetType()->Tag() == TYPE_VECTOR )
return Bool_Vec_CondVVVV(lhs, r1->AsNameExpr(), r2->AsNameExpr(), r3->AsNameExpr()); return Bool_Vec_CondVVVV(lhs, r1->AsNameExpr(), r2->AsNameExpr(), r3->AsNameExpr());
@ -747,6 +745,38 @@ const ZAMStmt ZAMCompiler::CompileIndex(const NameExpr* n1, int n2_slot, const T
return AddInst(z); return AddInst(z);
} }
const ZAMStmt ZAMCompiler::BuildLambda(const NameExpr* n, LambdaExpr* le)
{
	// Resolve the assignment target to its frame slot (as a write),
	// then defer to the slot-based variant.
	auto target_slot = Frame1Slot(n, OP1_WRITE);
	return BuildLambda(target_slot, le);
}
const ZAMStmt ZAMCompiler::BuildLambda(int n_slot, LambdaExpr* le)
{
	// Generates an OP_LAMBDA_VV instruction that assigns to n_slot.
	// The auxiliary information records the lambda's master function,
	// name and ID, plus one entry per capture.
	auto& lambda_captures = le->GetCaptures();

	int num_captures = 0;
	if ( lambda_captures )
		num_captures = lambda_captures->size();

	auto lambda_aux = new ZInstAux(num_captures);
	lambda_aux->master_func = le->MasterFunc();
	lambda_aux->lambda_name = le->Name();
	lambda_aux->id_val = le->Ingredients()->GetID().get();

	for ( int idx = 0; idx < num_captures; ++idx )
		{
		auto& capture_id = (*lambda_captures)[idx].Id();

		if ( pf->WhenLocals().count(capture_id.get()) > 0 )
			{
			// Captures that are when-locals get no frame slot.
			lambda_aux->Add(idx, nullptr);
			continue;
			}

		lambda_aux->Add(idx, FrameSlot(capture_id), capture_id->GetType());
		}

	// The second operand carries the master function's frame size.
	ZInstI lambda_inst(OP_LAMBDA_VV, n_slot, le->MasterFunc()->FrameSize());
	lambda_inst.op_type = OP_VV_I2;
	lambda_inst.aux = lambda_aux;

	return AddInst(lambda_inst);
}
const ZAMStmt ZAMCompiler::AssignVecElems(const Expr* e) const ZAMStmt ZAMCompiler::AssignVecElems(const Expr* e)
{ {
auto index_assign = e->AsIndexAssignExpr(); auto index_assign = e->AsIndexAssignExpr();
@ -1062,6 +1092,31 @@ const ZAMStmt ZAMCompiler::ConstructTable(const NameExpr* n, const Expr* e)
z.t = tt; z.t = tt;
z.attrs = e->AsTableConstructorExpr()->GetAttrs(); z.attrs = e->AsTableConstructorExpr()->GetAttrs();
auto zstmt = AddInst(z);
auto def_attr = z.attrs ? z.attrs->Find(ATTR_DEFAULT) : nullptr;
if ( ! def_attr || def_attr->GetExpr()->Tag() != EXPR_LAMBDA )
return zstmt;
auto def_lambda = def_attr->GetExpr()->AsLambdaExpr();
auto dl_t = def_lambda->GetType()->AsFuncType();
auto& captures = dl_t->GetCaptures();
if ( ! captures )
return zstmt;
// What a pain. The table's default value is a lambda that has
// captures. The semantics of this are that the captures are
// evaluated at table-construction time. We need to build the
// lambda and assign it as the table's default.
auto slot = NewSlot(true); // since func_val's are managed
(void)BuildLambda(slot, def_lambda);
z = GenInst(OP_SET_TABLE_DEFAULT_LAMBDA_VV, n, slot);
z.op_type = OP_VV;
z.t = def_lambda->GetType();
return AddInst(z); return AddInst(z);
} }

View file

@ -149,6 +149,9 @@ const ZAMStmt ZAMCompiler::AddInst(const ZInstI& inst, bool suppress_non_local)
if ( suppress_non_local ) if ( suppress_non_local )
return ZAMStmt(top_main_inst); return ZAMStmt(top_main_inst);
// Ensure we haven't confused ourselves about any pending stores.
ASSERT(pending_global_store == -1 || pending_capture_store == -1);
if ( pending_global_store >= 0 ) if ( pending_global_store >= 0 )
{ {
auto gs = pending_global_store; auto gs = pending_global_store;
@ -161,6 +164,27 @@ const ZAMStmt ZAMCompiler::AddInst(const ZInstI& inst, bool suppress_non_local)
return AddInst(store_inst); return AddInst(store_inst);
} }
if ( pending_capture_store >= 0 )
{
auto cs = pending_capture_store;
pending_capture_store = -1;
auto& cv = *func->GetType()->AsFuncType()->GetCaptures();
auto& c_id = cv[cs].Id();
ZOp op;
if ( ZVal::IsManagedType(c_id->GetType()) )
op = OP_STORE_MANAGED_CAPTURE_VV;
else
op = OP_STORE_CAPTURE_VV;
auto store_inst = ZInstI(op, RawSlot(c_id.get()), cs);
store_inst.op_type = OP_VV_I2;
return AddInst(store_inst);
}
return ZAMStmt(top_main_inst); return ZAMStmt(top_main_inst);
} }

View file

@ -1092,6 +1092,15 @@ eval ConstructTableOrSetPre()
} }
ConstructTableOrSetPost() ConstructTableOrSetPost()
# When tables are constructed, if their &default is a lambda with captures
# then we need to explicitly set up the default.
#
# Slot 1 holds the table (it is read, not assigned). Slot 2 holds the
# lambda, which is converted to a Val using the instruction's type (z.t)
# and handed to the table's InitDefaultVal().
internal-op Set-Table-Default-Lambda
type VV
op1-read
eval auto& tbl = frame[z.v1].table_val;
auto lambda = frame[z.v2].ToVal(z.t);
tbl->InitDefaultVal(lambda);
direct-unary-op Set-Constructor ConstructSet direct-unary-op Set-Constructor ConstructSet
internal-op Construct-Set internal-op Construct-Set
@ -2007,12 +2016,41 @@ eval auto& v = frame[z.v1].type_val;
auto t = globals[z.v2].id->GetType(); auto t = globals[z.v2].id->GetType();
v = new TypeVal(t, true); v = new TypeVal(t, true);
# Copies element z.v2 of the executing function's captures vector into
# frame slot z.v1. This is a plain assignment with no reference-count
# adjustment.
internal-op Load-Capture
type VV
eval frame[z.v1] = f->GetFunction()->GetCapturesVec()[z.v2];
# Loads element z.v2 of the executing function's captures vector into
# frame slot z.v1, for values accessed via ManagedVal(). A reference is
# taken on the capture first, then ZVal::DeleteManagedType() is applied to
# the slot's previous contents, and only then is the slot overwritten.
internal-op Load-Managed-Capture
type VV
eval auto& lhs = frame[z.v1];
auto& rhs = f->GetFunction()->GetCapturesVec()[z.v2];
zeek::Ref(rhs.ManagedVal());
ZVal::DeleteManagedType(lhs);
lhs = rhs;
internal-op Store-Global internal-op Store-Global
op1-internal op1-internal
type V type V
eval auto& g = globals[z.v1]; eval auto& g = globals[z.v1];
g.id->SetVal(frame[g.slot].ToVal(z.t)); g.id->SetVal(frame[g.slot].ToVal(z.t));
# Both of these have the LHS as v2 not v1, to keep with existing
# conventions of OP_VV_I2 op type (as opposed to OP_VV_I1_V2, which doesn't
# currently exist, and would be a pain to add).
#
# Store-Capture: copies frame slot z.v1 into element z.v2 of the executing
# function's captures vector, with no reference-count adjustment.
internal-op Store-Capture
op1-read
type VV
eval f->GetFunction()->GetCapturesVec()[z.v2] = frame[z.v1];
# Store-Managed-Capture: same direction of assignment, for values accessed
# via ManagedVal(). A reference is taken on the frame value first, then
# ZVal::DeleteManagedType() is applied to the capture's previous contents
# before it is overwritten.
internal-op Store-Managed-Capture
op1-read
type VV
eval auto& lhs = f->GetFunction()->GetCapturesVec()[z.v2];
auto& rhs = frame[z.v1];
zeek::Ref(rhs.ManagedVal());
ZVal::DeleteManagedType(lhs);
lhs = rhs;
internal-op Copy-To internal-op Copy-To
type VC type VC
@ -2029,6 +2067,37 @@ eval flow = FLOW_BREAK;
pc = end_pc; pc = end_pc;
continue; continue;
# Slot 2 gives frame size.
#
# Instantiates a lambda at run-time and stores it in frame slot z.v1.
# The auxiliary information supplies the master function (whose first body
# is reused and must already be a ZAM body), the ID used to construct the
# new ScriptFunc, the lambda's name, and - one per capture - the frame slot
# holding the capture's current value.
internal-op Lambda
type VV
eval auto& aux = z.aux;
auto& master_func = aux->master_func;
auto& body = master_func->GetBodies()[0].stmts;
ASSERT(body->Tag() == STMT_ZAM);
// Build a fresh function object that shares the master function's body.
auto lamb = make_intrusive<ScriptFunc>(aux->id_val);
lamb->AddBody(body, z.v2);
lamb->SetName(aux->lambda_name.c_str());
if ( aux->n > 0 )
{
// Snapshot the captures' values from the current frame.
auto captures = std::make_unique<std::vector<ZVal>>();
for ( auto i = 0; i < aux->n; ++i )
{
auto slot = aux->slots[i];
if ( slot >= 0 )
{
auto& cp = frame[aux->slots[i]];
// The captures vector holds its own reference to managed values.
if ( aux->is_managed[i] )
zeek::Ref(cp.ManagedVal());
captures->push_back(cp);
}
else
// Used for when-locals.
captures->push_back(ZVal());
}
lamb->CreateCaptures(std::move(captures));
}
// Dispose of the target slot's previous contents before overwriting it.
ZVal::DeleteManagedType(frame[z.v1]);
frame[z.v1].func_val = lamb.release();
######################################## ########################################
# Built-in Functions # Built-in Functions