From a3001f1b2b31e263d4cd22d12dd1780cb9ef0e06 Mon Sep 17 00:00:00 2001 From: Zeke Medley Date: Wed, 12 Jun 2019 14:40:40 -0700 Subject: [PATCH] Add lambda expressions with closures to Zeek. This allows anonymous functions in Zeek to capture their closures. They do so by creating a copy of their enclosing frame and joining that with their own frame. There is no way to specify which specific items to capture from the closure, as in C++, nor is there a nonlocal keyword as in Python. Attempting to declare a local variable that has already been captured by the closure will error nicely. At worst this is an inconvenience for people who are using lambdas which use the same variable names as their closures. As a result of functions copying their enclosing frames, there is no way for a function with a closure to reach back up and modify the state of the frame that it was created in. This lets functions that generate functions work as expected. The function can reach back and modify its copy of the frame that it is captured in, though. Implementation-wise, this is done by creating two new subclasses in Zeek. The first is a LambdaExpr, which can be thought of as a function generator. It gathers all of the ingredients for a function at parse time, and then when evaluated creates a new version of that function with the frame it is being evaluated in as a closure. The second subclass is a ClosureFrame. This acts for most intents and purposes like a regular Frame, but it routes lookups of values to its closure as needed. 
--- src/Expr.cc | 51 ++++++- src/Expr.h | 34 ++++- src/Frame.cc | 142 ++++++++++++++++++ src/Frame.h | 79 +++++++++- src/Func.cc | 95 +++++++++++- src/Func.h | 40 +++++ src/Obj.h | 2 +- src/Stmt.cc | 12 +- src/Val.cc | 2 +- src/Val.h | 1 + src/Var.cc | 81 +++++++--- src/Var.h | 6 + src/parse.y | 36 ++++- .../Baseline/language.function-closures/out | 22 +++ .../Baseline/language.outer_param_binding/out | 4 +- testing/btest/language/function-closures.zeek | 77 ++++++++++ .../btest/language/outer_param_binding.zeek | 4 + 17 files changed, 636 insertions(+), 52 deletions(-) create mode 100644 testing/btest/Baseline/language.function-closures/out create mode 100644 testing/btest/language/function-closures.zeek diff --git a/src/Expr.cc b/src/Expr.cc index 1ec357e945..2b1b56fffb 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -282,7 +282,7 @@ Val* NameExpr::Eval(Frame* f) const v = id->ID_Val(); else if ( f ) - v = f->NthElement(id->Offset()); + v = f->GetElement(id); else // No frame - evaluating for Simplify() purposes @@ -316,7 +316,7 @@ void NameExpr::Assign(Frame* f, Val* v, Opcode op) if ( id->IsGlobal() ) id->SetVal(v, op); else - f->SetElement(id->Offset(), v); + f->SetElement(id, v); } int NameExpr::IsPure() const @@ -4472,6 +4472,7 @@ FlattenExpr::FlattenExpr(Expr* arg_op) SetType(tl); } + Val* FlattenExpr::Fold(Val* v) const { RecordVal* rv = v->AsRecordVal(); @@ -4991,6 +4992,52 @@ bool CallExpr::DoUnserialize(UnserialInfo* info) return args != 0; } +LambdaExpr::LambdaExpr(std::unique_ptr ingredients, + std::shared_ptr outer_ids) + { + this->ingredients = std::move(ingredients); + this->outer_ids = std::move(outer_ids); + SetType(this->ingredients->id->Type()->Ref()); + } + +Val* LambdaExpr::Eval(Frame* f) const + { + BroFunc* lamb = new BroFunc( + ingredients->id, + ingredients->body, + ingredients->inits, + ingredients->frame_size, + ingredients->priority); + + lamb->AddClosure(outer_ids, f); + + ingredients->id->SetVal((new Val(lamb))->Ref()); + 
ingredients->id->SetConst(); + ingredients->id->ID_Val()->AsFunc()->SetScope(ingredients->scope); + + return ingredients->id->ID_Val(); + } + +void LambdaExpr::ExprDescribe(ODesc* d) const + { + d->Add("Lambda Expression"); + d->Add("{"); + ingredients->body->Describe(d); + d->Add("}"); + } + +TraversalCode LambdaExpr::Traverse(TraversalCallback* cb) const + { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = ingredients->body->Traverse(cb); + HANDLE_TC_EXPR_POST(tc); + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); + } + EventExpr::EventExpr(const char* arg_name, ListExpr* arg_args) : Expr(EXPR_EVENT) { diff --git a/src/Expr.h b/src/Expr.h index e268f07648..6843c04318 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -12,6 +12,10 @@ #include "Debug.h" #include "EventHandler.h" #include "TraverseTypes.h" +#include "Func.h" // function_ingredients + +#include // std::shared_ptr +#include // std::move typedef enum { EXPR_ANY = -1, @@ -62,6 +66,8 @@ class AssignExpr; class CallExpr; class EventExpr; +struct function_ingredients; + class Expr : public BroObj { public: @@ -997,7 +1003,7 @@ public: protected: friend class Expr; - CallExpr() { func = 0; args = 0; } + CallExpr() { func = 0; args = 0; } void ExprDescribe(ODesc* d) const override; @@ -1007,6 +1013,32 @@ protected: ListExpr* args; }; +/* + Class to handle the creation of anonymous functions with closures. + + Facts: + - LambdaExpr creates a new BroFunc on every call to Eval. + - LambdaExpr must be given all the information to create a BroFunc on + construction except for the closure. + - The closure for created BroFuncs is the frame that the LambdaExpr is + evaluated in. 
+*/ +class LambdaExpr : public Expr { +public: + LambdaExpr(std::unique_ptr ingredients, + std::shared_ptr outer_ids); + + Val* Eval(Frame* f) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; + +protected: + void ExprDescribe(ODesc* d) const override; + +private: + std::unique_ptr ingredients; + std::shared_ptr outer_ids; +}; + class EventExpr : public Expr { public: EventExpr(const char* name, ListExpr* args); diff --git a/src/Frame.cc b/src/Frame.cc index d065fb440a..3640964829 100644 --- a/src/Frame.cc +++ b/src/Frame.cc @@ -7,6 +7,8 @@ #include "Func.h" #include "Trigger.h" +#include + vector g_frame_stack; Frame::Frame(int arg_size, const BroFunc* func, const val_list* fn_args) @@ -27,12 +29,54 @@ Frame::Frame(int arg_size, const BroFunc* func, const val_list* fn_args) Clear(); } +Frame::Frame(const Frame* other) + { + this->size = other->size; + this->frame = other->frame; + this->function = other->function; + this->func_args = other->func_args; + + this->next_stmt = 0; + this->break_before_next_stmt = false; + this->break_on_return = false; + this->delayed = false; + + // We need to Ref this because the + // destructor will Unref. 
+ if ( other->trigger ) + Ref(other->trigger); + + this->trigger = other->trigger; + this->call = other->call; + } + Frame::~Frame() { Unref(trigger); Release(); } +void Frame::SetElement(int n, Val* v) + { + Unref(frame[n]); + frame[n] = v; + } +void Frame::SetElement(const ID* id, Val* v) + { + SetElement(id->Offset(), v); + } + + +Val* Frame::GetElement(ID* id) const + { + return this->frame[id->Offset()]; + } + +void Frame::AddElement(ID* id, Val* v) + { + this->SetElement(id, v); + } + void Frame::Reset(int startIdx) { for ( int i = startIdx; i < size; ++i ) @@ -109,3 +153,101 @@ void Frame::ClearTrigger() Unref(trigger); trigger = 0; } + +ClosureFrame::ClosureFrame(Frame* closure, Frame* not_closure, + std::shared_ptr outer_ids) : Frame(not_closure) + { + assert(closure); + + this->closure = closure; + Ref(this->closure); + this->body = not_closure; + Ref(this->body); + + // To clone a ClosureFrame we null outer_ids and then copy + // the set over directly, hence the check. + if (outer_ids) + { + // Install the closure IDs + id_list* tmp = outer_ids.get(); + loop_over_list(*tmp, i) + { + ID* id = (*tmp)[i]; + if (id) + this->closure_elements.insert(id->Name()); + } + } + } + +ClosureFrame::~ClosureFrame() + { + Unref(this->closure); + Unref(this->body); + } + +Val* ClosureFrame::GetElement(ID* id) const + { + if (this->closure_elements.find(id->Name()) != this->closure_elements.end()) + return ClosureFrame::GatherFromClosure(this, id); + + return this->NthElement(id->Offset()); + } + +void ClosureFrame::SetElement(const ID* id, Val* v) + { + if (this->closure_elements.find(id->Name()) != this->closure_elements.end()) + ClosureFrame::SetInClosure(this, id, v); + else + this->Frame::SetElement(id->Offset(), v); + } + +Frame* ClosureFrame::Clone() + { + Frame* new_closure = this->closure->Clone(); + Frame* new_regular = this->body->Clone(); + + ClosureFrame* cf = new ClosureFrame(new_closure, new_regular, nullptr); + cf->closure_elements = 
this->closure_elements; + return cf; + } + + +// Each ClosureFrame knows all of the outer IDs that are used inside of it. This is known at +// parse time. These leverage that. If frame_1 encloses frame_2 then the location of a lookup +// for an outer id in frame_2 can be determined by checking if that id is also an outer id in +// frame_2. If it is not, then frame_2 owns the id and the lookup is done there, otherwise, +// go deeper. + +// Note the useage of dynamic_cast. + + +Val* ClosureFrame::GatherFromClosure(const Frame* start, const ID* id) + { + const ClosureFrame* conductor = dynamic_cast(start); + + auto closure_contains = [] (const ClosureFrame* cf, const ID* i) + { return cf->closure_elements.find(i->Name()) != cf->closure_elements.end(); }; + + if ( ! conductor ) + return start->NthElement(id->Offset()); + + if (closure_contains(conductor, id)) + return ClosureFrame::GatherFromClosure(conductor->closure, id); + + return conductor->NthElement(id->Offset()); + } + +void ClosureFrame::SetInClosure(Frame* start, const ID* id, Val* val) + { + ClosureFrame* conductor = dynamic_cast(start); + + auto closure_contains = [] (const ClosureFrame* cf, const ID* i) + { return cf->closure_elements.find(i->Name()) != cf->closure_elements.end(); }; + + if ( ! conductor ) + start->SetElement(id->Offset(), val); + else if (closure_contains(conductor, id)) + ClosureFrame::SetInClosure(conductor->closure, id, val); + else + conductor->Frame::SetElement(id->Offset(), val); + } diff --git a/src/Frame.h b/src/Frame.h index 1469543e10..93d9677844 100644 --- a/src/Frame.h +++ b/src/Frame.h @@ -4,25 +4,31 @@ #define frame_h #include -using namespace std; +#include +#include #include "Val.h" +using namespace std; + class BroFunc; class Trigger; class CallExpr; +class Val; class Frame : public BroObj { public: Frame(int size, const BroFunc* func, const val_list *fn_args); + // The constructed frame becomes a view of the input frame. No copying is done. 
+ Frame(const Frame* other); ~Frame() override; - Val* NthElement(int n) { return frame[n]; } - void SetElement(int n, Val* v) - { - Unref(frame[n]); - frame[n] = v; - } + Val* NthElement(int n) const { return frame[n]; } + void SetElement(int n, Val* v); + virtual void SetElement(const ID* id, Val* v); + + virtual Val* GetElement(ID* id) const; + void AddElement(ID* id, Val* v); void Reset(int startIdx); void Release(); @@ -49,7 +55,9 @@ public: bool BreakOnReturn() const { return break_on_return; } // Deep-copies values. - Frame* Clone(); + virtual Frame* Clone(); + // Only deep-copies values corresponding to requested IDs. + Frame* SelectiveClone(id_list* selection); // If the frame is run in the context of a trigger condition evaluation, // the trigger needs to be registered. @@ -82,6 +90,61 @@ protected: bool delayed; }; +/* +Class that allows for lookups in both a closure frame and a regular frame +according to a list of IDs passed into the constructor. + +Facts: + - A ClosureFrame is created from two frames: a closure and a regular frame. + - ALL operations except GetElement operations operate on the regular frame. + - A ClosureFrame requires a list of outside ID's captured by the closure. + - Get operations on those IDs will be performed on the closure frame. + +ClosureFrame allows functions that generate functions to be passed between +different sized frames and still properly capture their closures. It also allows for +cleaner handling of closures. +*/ +class ClosureFrame : public Frame { +public: + ClosureFrame(Frame* closure, Frame* body, + std::shared_ptr outer_ids); + ~ClosureFrame() override; + Val* GetElement(ID* id) const override; + void SetElement(const ID* id, Val* v) override; + Frame* Clone() override; + +private: + Frame* closure; + Frame* body; + + // Searches this frame and all sub-frame's closures for a value corresponding + // to the id. 
+ static Val* GatherFromClosure(const Frame* start, const ID* id); + // Moves through the closure frames and associates val with id. + static void SetInClosure(Frame* start, const ID* id, Val* val); + + // Hashes c style strings. The strings need to be null-terminated. + struct const_char_hasher { + size_t operator()(const char* in) const + { + // http://www.cse.yorku.ca/~oz/hash.html + size_t h = 5381; + int c; + + while ((c = *in++)) + h = ((h << 5) + h) + c; + + return h; + } + }; + + // NOTE: In a perfect world this would be best done with a trie or bloom + // filter. We only need to check if things are NOT in the closure. + // In reality though the size of a closure is small enough that operatons are + // fairly quick anyway. + std::unordered_set closure_elements; +}; + extern vector g_frame_stack; #endif diff --git a/src/Func.cc b/src/Func.cc index 90515a0f8f..16debd0fb8 100644 --- a/src/Func.cc +++ b/src/Func.cc @@ -232,6 +232,15 @@ bool Func::DoUnserialize(UnserialInfo* info) return true; } +Val* Func::DoClone() + { + // By default, ok just to return a reference. Func does not have any "state". + // That is different across instances. + Val* v = new Val(this); + Ref(v); + return v; + } + void Func::DescribeDebug(ODesc* d, const val_list* args) const { d->Add(Name()); @@ -369,6 +378,8 @@ BroFunc::~BroFunc() { for ( unsigned int i = 0; i < bodies.size(); ++i ) Unref(bodies[i].stmts); + + Unref(this->closure); } int BroFunc::IsPure() const @@ -411,7 +422,13 @@ Val* BroFunc::Call(val_list* args, Frame* parent) const return Flavor() == FUNC_FLAVOR_HOOK ? val_mgr->GetTrue() : 0; } + // f will hold the closure & function's values Frame* f = new Frame(frame_size, this, args); + if (this->closure) + { + assert(outer_ids); + f = new ClosureFrame(this->closure, f, this->outer_ids); + } // Hand down any trigger. 
if ( parent ) @@ -439,19 +456,18 @@ Val* BroFunc::Call(val_list* args, Frame* parent) const for ( size_t i = 0; i < bodies.size(); ++i ) { if ( sample_logger ) - sample_logger->LocationSeen( - bodies[i].stmts->GetLocationInfo()); + sample_logger->LocationSeen(bodies[i].stmts->GetLocationInfo()); Unref(result); + // Fill in the rest of the frame with the function's arguments. loop_over_list(*args, j) { Val* arg = (*args)[j]; if ( f->NthElement(j) != arg ) { - // Either not yet set, or somebody reassigned - // the frame slot. + // Either not yet set, or somebody reassigned the frame slot. Ref(arg); f->SetElement(j, arg); } @@ -468,14 +484,18 @@ Val* BroFunc::Call(val_list* args, Frame* parent) const { // Already reported, but now determine whether to unwind further. if ( Flavor() == FUNC_FLAVOR_FUNCTION ) + { + Unref(f); + Unref(result); throw; + } // Continue exec'ing remaining bodies of hooks/events. continue; } if ( f->HasDelayed() ) - { + { assert(! result); assert(parent); parent->SetDelayed(); @@ -517,7 +537,7 @@ Val* BroFunc::Call(val_list* args, Frame* parent) const (flow != FLOW_RETURN /* we fell off the end */ || ! result /* explicit return with no result */) && ! f->HasDelayed() ) - reporter->Warning("non-void function returns without a value: %s", + reporter->Warning("non-void function returning without a value: %s", Name()); if ( result && g_trace_state.DoTrace() ) @@ -529,6 +549,7 @@ Val* BroFunc::Call(val_list* args, Frame* parent) const } g_frame_stack.pop_back(); + Unref(f); return result; @@ -559,6 +580,65 @@ void BroFunc::AddBody(Stmt* new_body, id_list* new_inits, int new_frame_size, sort(bodies.begin(), bodies.end()); } +void BroFunc::AddClosure(std::shared_ptr ids, Frame* f) + { + // Order matters here. + this->SetOuterIDs(ids); + this->SetClosureFrame(f); + } + +void BroFunc::SetClosureFrame(Frame* f) + { + if (this->closure) + reporter->InternalError + ("Tried to override closure for BroFunc %s.", this->Name()); + + this->closure = f ? 
f->Clone() : nullptr; + } + +void BroFunc::ShiftOffsets(int shift, std::shared_ptr idl) + { + id_list* tmp = idl.get(); + if (! idl || shift == 0) + { + // Nothing to do here. + return; + } + + loop_over_list(*tmp, i) + { + ID* id = (*tmp)[i]; + id->SetOffset(id->Offset() + shift); + } + } + +Val* BroFunc::DoClone() + { + // A BroFunc could hold a closure. In this case a clone of it must copy this + // store a copy of this closure. + if ( ! this->closure ) + { + return Func::DoClone(); + } + else + { + BroFunc* other = new BroFunc(); + + other->bodies = this->bodies; + other->scope = this->scope; + other->kind = this->kind; + other->type = this->type; + other->name = this->name; + other->unique_id = this->unique_id; + other->unique_ids = this->unique_ids; + other->frame_size = this->frame_size; + other->closure = this->closure->Clone(); + other->outer_ids = this->outer_ids; + + return new Val(other); + } + } + void BroFunc::Describe(ODesc* d) const { d->Add(Name()); @@ -578,7 +658,8 @@ Stmt* BroFunc::AddInits(Stmt* body, id_list* inits) return body; StmtList* stmt_series = new StmtList; - stmt_series->Stmts().append(new InitStmt(inits)); + InitStmt* first = new InitStmt(inits); + stmt_series->Stmts().append(first); stmt_series->Stmts().append(body); return stmt_series; diff --git a/src/Func.h b/src/Func.h index 48e0c2e8b8..b5d195eb34 100644 --- a/src/Func.h +++ b/src/Func.h @@ -4,10 +4,13 @@ #define func_h #include +#include // std::shared_ptr, std::unique_ptr #include "BroList.h" #include "Obj.h" #include "Debug.h" +#include "Frame.h" +// #include "Val.h" class Val; class ListExpr; @@ -17,6 +20,8 @@ class Frame; class ID; class CallExpr; +struct CloneState; + class Func : public BroObj { public: @@ -62,6 +67,7 @@ public: // This (un-)serializes only a single body (as given in SerialInfo). 
bool Serialize(SerialInfo* info) const; static Func* Unserialize(UnserialInfo* info); + virtual Val* DoClone(); virtual TraversalCode Traverse(TraversalCallback* cb) const; @@ -95,9 +101,14 @@ public: int IsPure() const override; Val* Call(val_list* args, Frame* parent) const override; + void AddClosure(std::shared_ptr ids, Frame* f); void AddBody(Stmt* new_body, id_list* new_inits, int new_frame_size, int priority) override; + void fsets(); + + Val* DoClone() override; + int FrameSize() const { return frame_size; } void Describe(ODesc* d) const override; @@ -109,6 +120,23 @@ protected: DECLARE_SERIAL(BroFunc); int frame_size; + +private: + // Shifts the offsets of each id in "idl" by "shift". + static void ShiftOffsets(int shift, std::shared_ptr idl); + + // Makes a deep copy of the input frame and captures it. + void SetClosureFrame(Frame* f); + + void SetOuterIDs(std::shared_ptr ids) + { outer_ids = std::move(ids); } + + // List of the outer IDs used in the function. Shared becase other instances + // would like to use it as well. + std::shared_ptr outer_ids = nullptr; + // The frame the Func was initialized in. This is not guaranteed to be + // initialized and should be handled with care. + Frame* closure = nullptr; }; typedef Val* (*built_in_func)(Frame* frame, val_list* args); @@ -146,6 +174,18 @@ struct CallInfo { const val_list* args; }; +// Struct that collects the arguments for a Func. +// Used for BroFuncs with closures. 
+struct function_ingredients + { + ID* id; + Stmt* body; + id_list* inits; + int frame_size; + int priority; + Scope* scope; + }; + extern vector call_stack; extern std::string render_call_stack(); diff --git a/src/Obj.h b/src/Obj.h index 21730ff367..b1ac9a110d 100644 --- a/src/Obj.h +++ b/src/Obj.h @@ -207,7 +207,7 @@ inline void Error(const BroObj* o, const char* msg) inline void Ref(BroObj* o) { - if ( ++o->ref_cnt <= 1 ) + if ( ++(o->ref_cnt) <= 1 ) bad_ref(0); if ( o->ref_cnt == INT_MAX ) bad_ref(1); diff --git a/src/Stmt.cc b/src/Stmt.cc index 5960747d05..4e5b2d5050 100644 --- a/src/Stmt.cc +++ b/src/Stmt.cc @@ -960,7 +960,7 @@ Val* SwitchStmt::DoExec(Frame* f, Val* v, stmt_flow_type& flow) const if ( matching_id ) { auto cv = cast_value_to_type(v, matching_id->Type()); - f->SetElement(matching_id->Offset(), cv); + f->SetElement(matching_id, cv); } flow = FLOW_NEXT; @@ -1477,10 +1477,10 @@ Val* ForStmt::DoExec(Frame* f, Val* v, stmt_flow_type& flow) const delete k; if ( value_var ) - f->SetElement(value_var->Offset(), current_tev->Value()->Ref()); + f->SetElement(value_var, current_tev->Value()->Ref()); for ( int i = 0; i < ind_lv->Length(); i++ ) - f->SetElement((*loop_vars)[i]->Offset(), ind_lv->Index(i)->Ref()); + f->SetElement((*loop_vars)[i], ind_lv->Index(i)->Ref()); Unref(ind_lv); flow = FLOW_NEXT; @@ -1508,7 +1508,7 @@ Val* ForStmt::DoExec(Frame* f, Val* v, stmt_flow_type& flow) const // Set the loop variable to the current index, and make // another pass over the loop body. 
- f->SetElement((*loop_vars)[0]->Offset(), + f->SetElement((*loop_vars)[0], val_mgr->GetCount(i)); flow = FLOW_NEXT; ret = body->Exec(f, flow); @@ -1523,7 +1523,7 @@ Val* ForStmt::DoExec(Frame* f, Val* v, stmt_flow_type& flow) const for ( int i = 0; i < sval->Len(); ++i ) { - f->SetElement((*loop_vars)[0]->Offset(), + f->SetElement((*loop_vars)[0], new StringVal(1, (const char*) sval->Bytes() + i)); flow = FLOW_NEXT; ret = body->Exec(f, flow); @@ -2084,7 +2084,7 @@ Val* InitStmt::Exec(Frame* f, stmt_flow_type& flow) const break; } - f->SetElement(aggr->Offset(), v); + f->SetElement(aggr, v); } return 0; diff --git a/src/Val.cc b/src/Val.cc index bb9a3d1601..77c375da3a 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -103,7 +103,7 @@ Val* Val::DoClone(CloneState* state) // Functions and files. There aren't any derived classes. if ( type->Tag() == TYPE_FUNC ) // Immutable. - return Ref(); + return this->AsFunc()->DoClone(); if ( type->Tag() == TYPE_FILE ) { diff --git a/src/Val.h b/src/Val.h index 2890c4c5e8..756d54b518 100644 --- a/src/Val.h +++ b/src/Val.h @@ -32,6 +32,7 @@ #define ICMP_PORT_MASK 0x30000 class Val; +class BroFunc; class Func; class BroFile; class RE_Matcher; diff --git a/src/Var.cc b/src/Var.cc index 16ced341c1..e45378e75a 100644 --- a/src/Var.cc +++ b/src/Var.cc @@ -385,6 +385,7 @@ void begin_func(ID* id, const char* module_name, function_flavor flavor, RecordType* args = t->Args(); int num_args = args->NumFields(); + for ( int i = 0; i < num_args; ++i ) { TypeDecl* arg_i = args->FieldDecl(i); @@ -421,6 +422,7 @@ TraversalCode OuterIDBindingFinder::PreExpr(const Expr* expr) const NameExpr* e = static_cast(expr); + // TODO: Do we need to capture these as well? if ( e->Id()->IsGlobal() ) return TC_CONTINUE; @@ -431,17 +433,11 @@ TraversalCode OuterIDBindingFinder::PreExpr(const Expr* expr) return TC_CONTINUE; } -void end_func(Stmt* body) +// Gets a function's priority from its Scope's attributes. Errors if it sees any +// problems. 
+int get_func_priotity(attr_list* attrs) { - int frame_size = current_scope()->Length(); - id_list* inits = current_scope()->GetInits(); - - Scope* scope = pop_scope(); - ID* id = scope->ScopeID(); - int priority = 0; - auto attrs = scope->Attrs(); - if ( attrs ) { loop_over_list(*attrs, i) @@ -473,27 +469,63 @@ void end_func(Stmt* body) priority = v->InternalInt(); } } + return priority; + } - if ( streq(id->Name(), "anonymous-function") ) +void end_func(Stmt* body) + { + std::unique_ptr ingredients = + gather_function_ingredients(body); + pop_scope(); + + if ( streq(ingredients->id->Name(), "anonymous-function") ) { - OuterIDBindingFinder cb(scope); - body->Traverse(&cb); + OuterIDBindingFinder cb(ingredients->scope); + ingredients->body->Traverse(&cb); for ( size_t i = 0; i < cb.outer_id_references.size(); ++i ) cb.outer_id_references[i]->Error( "referencing outer function IDs not supported"); } - if ( id->HasVal() ) - id->ID_Val()->AsFunc()->AddBody(body, inits, frame_size, priority); + if ( ingredients->id->HasVal() ) + ingredients->id->ID_Val()->AsFunc()->AddBody( + ingredients->body, + ingredients->inits, + ingredients->frame_size, + ingredients->priority); else { - Func* f = new BroFunc(id, body, inits, frame_size, priority); - id->SetVal(new Val(f)); - id->SetConst(); + Func* f = new BroFunc( + ingredients->id, + ingredients->body, + ingredients->inits, + ingredients->frame_size, + ingredients->priority); + ingredients->id->SetVal(new Val(f)); + ingredients->id->SetConst(); } - id->ID_Val()->AsFunc()->SetScope(scope); + ingredients->id->ID_Val()->AsFunc()->SetScope(ingredients->scope); + } + +// Gathers all of the information from the current scope needed to build a +// function and collects it into a function_ingredients struct. 
+std::unique_ptr gather_function_ingredients(Stmt* body) + { + std::unique_ptr ingredients (new function_ingredients); + ingredients->frame_size = current_scope()->Length(); + ingredients->inits = current_scope()->GetInits(); + + ingredients->scope = current_scope(); + ingredients->id = ingredients->scope->ScopeID(); + + auto attrs = ingredients->scope->Attrs(); + + ingredients->priority = get_func_priotity(attrs); + ingredients->body = body; + + return std::move(ingredients); } Val* internal_val(const char* name) @@ -508,6 +540,19 @@ Val* internal_val(const char* name) return rval; } +std::shared_ptr gather_outer_ids(Scope* scope, Stmt* body) + { + OuterIDBindingFinder cb(scope); + body->Traverse(&cb); + + std::shared_ptr idl (new id_list); + + for ( size_t i = 0; i < cb.outer_id_references.size(); ++i ) + idl->append(cb.outer_id_references[i]->Id()); + + return std::move(idl); + } + Val* internal_const_val(const char* name) { ID* id = lookup_ID(name, GLOBAL_MODULE_NAME); diff --git a/src/Var.h b/src/Var.h index 98e0f45163..1c96ff2081 100644 --- a/src/Var.h +++ b/src/Var.h @@ -3,9 +3,12 @@ #ifndef var_h #define var_h +#include // std::unique_ptr + #include "ID.h" #include "Expr.h" #include "Type.h" +#include "Func.h" // function_ingredients class Func; class EventHandlerPtr; @@ -23,6 +26,9 @@ extern void add_type(ID* id, BroType* t, attr_list* attr); extern void begin_func(ID* id, const char* module_name, function_flavor flavor, int is_redef, FuncType* t, attr_list* attrs = nullptr); extern void end_func(Stmt* body); +extern std::unique_ptr + gather_function_ingredients(Stmt* body); +extern std::shared_ptr gather_outer_ids(Scope* scope, Stmt* body); extern Val* internal_val(const char* name); extern Val* internal_const_val(const char* name); // internal error if not const diff --git a/src/parse.y b/src/parse.y index 2861b95dc8..7765a600fb 100644 --- a/src/parse.y +++ b/src/parse.y @@ -272,7 +272,7 @@ bro: else stmts = $2; - // Any objects creates from hereon out 
should not + // Any objects creates from here on out should not // have file positions associated with them. set_location(no_location); } @@ -1218,16 +1218,42 @@ func_body: ; anonymous_function: - TOK_FUNCTION begin_func func_body - { $$ = new ConstExpr($2->ID_Val()); } + TOK_FUNCTION func_params + { + $$ = current_scope()->GenerateTemporary("lambda-function"); + begin_func($$, current_module.c_str(), FUNC_FLAVOR_FUNCTION, 0, $2); + } + + '{' + { + saved_in_init.push_back(in_init); + in_init = 0; + } + + stmt_list + { + in_init = saved_in_init.back(); + saved_in_init.pop_back(); + } + + '}' + { + // Every time a new LambdaExpr is evaluated it must return a new instance + // of a BroFunc. Here, we collect the ingredients for a function and give + // it to our LambdaExpr. + std::unique_ptr ingredients = + gather_function_ingredients($6); + std::shared_ptr outer_ids = gather_outer_ids(pop_scope(), $6); + + $$ = new LambdaExpr(std::move(ingredients), std::move(outer_ids)); + } ; begin_func: func_params { $$ = current_scope()->GenerateTemporary("anonymous-function"); - begin_func($$, current_module.c_str(), - FUNC_FLAVOR_FUNCTION, 0, $1); + begin_func($$, current_module.c_str(), FUNC_FLAVOR_FUNCTION, 0, $1); } ; diff --git a/testing/btest/Baseline/language.function-closures/out b/testing/btest/Baseline/language.function-closures/out new file mode 100644 index 0000000000..53274969c7 --- /dev/null +++ b/testing/btest/Baseline/language.function-closures/out @@ -0,0 +1,22 @@ +expect: 3 +3 +expect: 5 +5 +expect: 5 +5 +expect: T +T +expect: T +T +expect: 107 +107 +expect: 107 +107 +expect: 100 +100 +expect: 160 +160 +expect: 225 +225 +expect: 225 +225 diff --git a/testing/btest/Baseline/language.outer_param_binding/out b/testing/btest/Baseline/language.outer_param_binding/out index afdc4191cd..a0c60132e4 100644 --- a/testing/btest/Baseline/language.outer_param_binding/out +++ b/testing/btest/Baseline/language.outer_param_binding/out @@ -1,3 +1 @@ -error in 
/home/robin/bro/master/testing/btest/.tmp/language.outer_param_binding/outer_param_binding.zeek, line 16: referencing outer function IDs not supported (c) -error in /home/robin/bro/master/testing/btest/.tmp/language.outer_param_binding/outer_param_binding.zeek, line 16: referencing outer function IDs not supported (d) -error in /home/robin/bro/master/testing/btest/.tmp/language.outer_param_binding/outer_param_binding.zeek, line 17: referencing outer function IDs not supported (b) +error in /home/zekemedley/Desktop/corelight/lambda_stuff/zeek/testing/btest/.tmp/language.outer_param_binding/outer_param_binding.zeek, line 11 and /home/zekemedley/Desktop/corelight/lambda_stuff/zeek/testing/btest/.tmp/language.outer_param_binding/outer_param_binding.zeek, line 17: already defined (d) diff --git a/testing/btest/language/function-closures.zeek b/testing/btest/language/function-closures.zeek new file mode 100644 index 0000000000..928d0aa4f1 --- /dev/null +++ b/testing/btest/language/function-closures.zeek @@ -0,0 +1,77 @@ +# @TEST-EXEC: zeek -b %INPUT >out +# @TEST-EXEC: btest-diff out + +function make_count_upper (start : count) : function(step : count) : count + { + return function(step : count) : count + { return (start += step); }; + } + +event zeek_init() + { + # basic + local one = make_count_upper(1); + print "expect: 3"; + print one(2); + + # multiple instances + local two = make_count_upper(one(1)); + print "expect: 5"; + print two(1); + print "expect: 5"; + print one(1); + + # deep copies + local c = copy(one); + print "expect: T"; + print c(1) == one(1); + print "expect: T"; + print c(1) == two(2); + + # a little more complicated ... + local cat_dog = 100; + local add_n_and_m = function(n: count) : function(m : count) : function(o : count) : count + { + cat_dog += 1; # segfault here. 
+ return function(m : count) : function(o : count) : count + { return function(o : count) : count + { return n + m + o + cat_dog; }; }; + }; + + local add_m = add_n_and_m(2); + local adder = add_m(2); + + print "expect: 107"; + print adder(2); + + print "expect: 107"; + # deep copies + local ac = copy(adder); + print ac(2); + + # copies closure: + print "expect: 100"; + print cat_dog; + + # complicated - has state across calls + local two_part_adder_maker = function (begin : count) : function (base_step : count) : function ( step : count) : count + { + return function (base_step : count) : function (step : count) : count + { + return function (step : count) : count + { + return (begin += base_step + step); }; }; }; + + local base_step = two_part_adder_maker(100); + local stepper = base_step(50); + print "expect: 160"; + print stepper(10); + local twotwofive = copy(stepper); + print "expect: 225"; + print stepper(15); + + # another copy check + print "expect: 225"; + print twotwofive(15); + } + diff --git a/testing/btest/language/outer_param_binding.zeek b/testing/btest/language/outer_param_binding.zeek index d3587a7cce..1f3f32a1fc 100644 --- a/testing/btest/language/outer_param_binding.zeek +++ b/testing/btest/language/outer_param_binding.zeek @@ -12,6 +12,9 @@ function bar(b: string, c: string) f = [$x=function(a: string) : string { local x = 0; + # Fail here: we've captured the closure. + # d is already defined. + local d = 10; print x; print c, d; return cat(a, " ", b); @@ -24,4 +27,5 @@ function bar(b: string, c: string) event zeek_init() { bar("1", "20"); + bar("1", "20"); }