diff --git a/src/Expr.cc b/src/Expr.cc index 1ec357e945..2b1b56fffb 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -282,7 +282,7 @@ Val* NameExpr::Eval(Frame* f) const v = id->ID_Val(); else if ( f ) - v = f->NthElement(id->Offset()); + v = f->GetElement(id); else // No frame - evaluating for Simplify() purposes @@ -316,7 +316,7 @@ void NameExpr::Assign(Frame* f, Val* v, Opcode op) if ( id->IsGlobal() ) id->SetVal(v, op); else - f->SetElement(id->Offset(), v); + f->SetElement(id, v); } int NameExpr::IsPure() const @@ -4472,6 +4472,7 @@ FlattenExpr::FlattenExpr(Expr* arg_op) SetType(tl); } + Val* FlattenExpr::Fold(Val* v) const { RecordVal* rv = v->AsRecordVal(); @@ -4991,6 +4992,52 @@ bool CallExpr::DoUnserialize(UnserialInfo* info) return args != 0; } +LambdaExpr::LambdaExpr(std::unique_ptr ingredients, + std::shared_ptr outer_ids) + { + this->ingredients = std::move(ingredients); + this->outer_ids = std::move(outer_ids); + SetType(this->ingredients->id->Type()->Ref()); + } + +Val* LambdaExpr::Eval(Frame* f) const + { + BroFunc* lamb = new BroFunc( + ingredients->id, + ingredients->body, + ingredients->inits, + ingredients->frame_size, + ingredients->priority); + + lamb->AddClosure(outer_ids, f); + + ingredients->id->SetVal((new Val(lamb))->Ref()); + ingredients->id->SetConst(); + ingredients->id->ID_Val()->AsFunc()->SetScope(ingredients->scope); + + return ingredients->id->ID_Val(); + } + +void LambdaExpr::ExprDescribe(ODesc* d) const + { + d->Add("Lambda Expression"); + d->Add("{"); + ingredients->body->Describe(d); + d->Add("}"); + } + +TraversalCode LambdaExpr::Traverse(TraversalCallback* cb) const + { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = ingredients->body->Traverse(cb); + HANDLE_TC_EXPR_POST(tc); + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); + } + EventExpr::EventExpr(const char* arg_name, ListExpr* arg_args) : Expr(EXPR_EVENT) { diff --git a/src/Expr.h b/src/Expr.h index e268f07648..6843c04318 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -12,6 +12,10 @@ #include "Debug.h" #include "EventHandler.h" #include "TraverseTypes.h" +#include "Func.h" // function_ingredients + +#include // std::shared_ptr +#include // std::move typedef enum { EXPR_ANY = -1, @@ -62,6 +66,8 @@ class AssignExpr; class CallExpr; class EventExpr; +struct function_ingredients; + class Expr : public BroObj { public: @@ -997,7 +1003,7 @@ public: protected: friend class Expr; - CallExpr() { func = 0; args = 0; } + CallExpr() { func = 0; args = 0; } void ExprDescribe(ODesc* d) const override; @@ -1007,6 +1013,32 @@ protected: ListExpr* args; }; +/* + Class to handle the creation of anonymous functions with closures. + + Facts: + - LambdaExpr creates a new BroFunc on every call to Eval. + - LambdaExpr must be given all the information to create a BroFunc on + construction except for the closure. + - The closure for created BroFuncs is the frame that the LambdaExpr is + evaluated in. +*/ +class LambdaExpr : public Expr { +public: + LambdaExpr(std::unique_ptr ingredients, + std::shared_ptr outer_ids); + + Val* Eval(Frame* f) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; + +protected: + void ExprDescribe(ODesc* d) const override; + +private: + std::unique_ptr ingredients; + std::shared_ptr outer_ids; +}; + class EventExpr : public Expr { public: EventExpr(const char* name, ListExpr* args); diff --git a/src/Frame.cc b/src/Frame.cc index d065fb440a..3640964829 100644 --- a/src/Frame.cc +++ b/src/Frame.cc @@ -7,6 +7,8 @@ #include "Func.h" #include "Trigger.h" +#include + vector g_frame_stack; Frame::Frame(int arg_size, const BroFunc* func, const val_list* fn_args) @@ -27,12 +29,54 @@ Frame::Frame(int arg_size, const BroFunc* func, const val_list* fn_args) Clear(); } +Frame::Frame(const Frame* other) + { + this->size = other->size; + this->frame = other->frame; + this->function = other->function; + this->func_args = other->func_args; + + this->next_stmt = 0; + this->break_before_next_stmt = false; + this->break_on_return = false; + this->delayed = false; + + // We need to Ref this because the + // destructor will Unref. + if ( other->trigger ) + Ref(other->trigger); + + this->trigger = other->trigger; + this->call = other->call; + } + Frame::~Frame() { Unref(trigger); Release(); } +void Frame::SetElement(int n, Val* v) + { + Unref(frame[n]); + frame[n] = v; + } +void Frame::SetElement(const ID* id, Val* v) + { + SetElement(id->Offset(), v); + } + + +Val* Frame::GetElement(ID* id) const + { + return this->frame[id->Offset()]; + } + +void Frame::AddElement(ID* id, Val* v) + { + this->SetElement(id, v); + } + void Frame::Reset(int startIdx) { for ( int i = startIdx; i < size; ++i ) @@ -109,3 +153,101 @@ void Frame::ClearTrigger() Unref(trigger); trigger = 0; } + +ClosureFrame::ClosureFrame(Frame* closure, Frame* not_closure, + std::shared_ptr outer_ids) : Frame(not_closure) + { + assert(closure); + + this->closure = closure; + Ref(this->closure); + this->body = not_closure; + Ref(this->body); + + // To clone a ClosureFrame we null outer_ids and then copy + // the set over directly, hence the check. + if (outer_ids) + { + // Install the closure IDs + id_list* tmp = outer_ids.get(); + loop_over_list(*tmp, i) + { + ID* id = (*tmp)[i]; + if (id) + this->closure_elements.insert(id->Name()); + } + } + } + +ClosureFrame::~ClosureFrame() + { + Unref(this->closure); + Unref(this->body); + } + +Val* ClosureFrame::GetElement(ID* id) const + { + if (this->closure_elements.find(id->Name()) != this->closure_elements.end()) + return ClosureFrame::GatherFromClosure(this, id); + + return this->NthElement(id->Offset()); + } + +void ClosureFrame::SetElement(const ID* id, Val* v) + { + if (this->closure_elements.find(id->Name()) != this->closure_elements.end()) + ClosureFrame::SetInClosure(this, id, v); + else + this->Frame::SetElement(id->Offset(), v); + } + +Frame* ClosureFrame::Clone() + { + Frame* new_closure = this->closure->Clone(); + Frame* new_regular = this->body->Clone(); + + ClosureFrame* cf = new ClosureFrame(new_closure, new_regular, nullptr); + cf->closure_elements = this->closure_elements; + return cf; + } + + +// Each ClosureFrame knows all of the outer IDs that are used inside of it. This is known at +// parse time. These leverage that. If frame_1 encloses frame_2 then the location of a lookup +// for an outer id in frame_2 can be determined by checking if that id is also an outer id in +// frame_2. If it is not, then frame_2 owns the id and the lookup is done there, otherwise, +// go deeper. + +// Note the useage of dynamic_cast. + + +Val* ClosureFrame::GatherFromClosure(const Frame* start, const ID* id) + { + const ClosureFrame* conductor = dynamic_cast(start); + + auto closure_contains = [] (const ClosureFrame* cf, const ID* i) + { return cf->closure_elements.find(i->Name()) != cf->closure_elements.end(); }; + + if ( ! conductor ) + return start->NthElement(id->Offset()); + + if (closure_contains(conductor, id)) + return ClosureFrame::GatherFromClosure(conductor->closure, id); + + return conductor->NthElement(id->Offset()); + } + +void ClosureFrame::SetInClosure(Frame* start, const ID* id, Val* val) + { + ClosureFrame* conductor = dynamic_cast(start); + + auto closure_contains = [] (const ClosureFrame* cf, const ID* i) + { return cf->closure_elements.find(i->Name()) != cf->closure_elements.end(); }; + + if ( ! conductor ) + start->SetElement(id->Offset(), val); + else if (closure_contains(conductor, id)) + ClosureFrame::SetInClosure(conductor->closure, id, val); + else + conductor->Frame::SetElement(id->Offset(), val); + } diff --git a/src/Frame.h b/src/Frame.h index 1469543e10..93d9677844 100644 --- a/src/Frame.h +++ b/src/Frame.h @@ -4,25 +4,31 @@ #define frame_h #include -using namespace std; +#include +#include #include "Val.h" +using namespace std; + class BroFunc; class Trigger; class CallExpr; +class Val; class Frame : public BroObj { public: Frame(int size, const BroFunc* func, const val_list *fn_args); + // The constructed frame becomes a view of the input frame. No copying is done. + Frame(const Frame* other); ~Frame() override; - Val* NthElement(int n) { return frame[n]; } - void SetElement(int n, Val* v) - { - Unref(frame[n]); - frame[n] = v; - } + Val* NthElement(int n) const { return frame[n]; } + void SetElement(int n, Val* v); + virtual void SetElement(const ID* id, Val* v); + + virtual Val* GetElement(ID* id) const; + void AddElement(ID* id, Val* v); void Reset(int startIdx); void Release(); @@ -49,7 +55,9 @@ public: bool BreakOnReturn() const { return break_on_return; } // Deep-copies values. - Frame* Clone(); + virtual Frame* Clone(); + // Only deep-copies values corresponding to requested IDs. + Frame* SelectiveClone(id_list* selection); // If the frame is run in the context of a trigger condition evaluation, // the trigger needs to be registered. @@ -82,6 +90,61 @@ protected: bool delayed; }; +/* +Class that allows for lookups in both a closure frame and a regular frame +according to a list of IDs passed into the constructor. + +Facts: + - A ClosureFrame is created from two frames: a closure and a regular frame. + - ALL operations except GetElement operations operate on the regular frame. + - A ClosureFrame requires a list of outside ID's captured by the closure. + - Get operations on those IDs will be performed on the closure frame. + +ClosureFrame allows functions that generate functions to be passed between +different sized frames and still properly capture their closures. It also allows for +cleaner handling of closures. +*/ +class ClosureFrame : public Frame { +public: + ClosureFrame(Frame* closure, Frame* body, + std::shared_ptr outer_ids); + ~ClosureFrame() override; + Val* GetElement(ID* id) const override; + void SetElement(const ID* id, Val* v) override; + Frame* Clone() override; + +private: + Frame* closure; + Frame* body; + + // Searches this frame and all sub-frame's closures for a value corresponding + // to the id. + static Val* GatherFromClosure(const Frame* start, const ID* id); + // Moves through the closure frames and associates val with id. + static void SetInClosure(Frame* start, const ID* id, Val* val); + + // Hashes c style strings. The strings need to be null-terminated. + struct const_char_hasher { + size_t operator()(const char* in) const + { + // http://www.cse.yorku.ca/~oz/hash.html + size_t h = 5381; + int c; + + while ((c = *in++)) + h = ((h << 5) + h) + c; + + return h; + } + }; + + // NOTE: In a perfect world this would be best done with a trie or bloom + // filter. We only need to check if things are NOT in the closure. + // In reality though the size of a closure is small enough that operatons are + // fairly quick anyway. + std::unordered_set closure_elements; +}; + extern vector g_frame_stack; #endif diff --git a/src/Func.cc b/src/Func.cc index 90515a0f8f..16debd0fb8 100644 --- a/src/Func.cc +++ b/src/Func.cc @@ -232,6 +232,15 @@ bool Func::DoUnserialize(UnserialInfo* info) return true; } +Val* Func::DoClone() + { + // By default, ok just to return a reference. Func does not have any "state". + // That is different across instances. + Val* v = new Val(this); + Ref(v); + return v; + } + void Func::DescribeDebug(ODesc* d, const val_list* args) const { d->Add(Name()); @@ -369,6 +378,8 @@ BroFunc::~BroFunc() { for ( unsigned int i = 0; i < bodies.size(); ++i ) Unref(bodies[i].stmts); + + Unref(this->closure); } int BroFunc::IsPure() const @@ -411,7 +422,13 @@ Val* BroFunc::Call(val_list* args, Frame* parent) const return Flavor() == FUNC_FLAVOR_HOOK ? val_mgr->GetTrue() : 0; } + // f will hold the closure & function's values Frame* f = new Frame(frame_size, this, args); + if (this->closure) + { + assert(outer_ids); + f = new ClosureFrame(this->closure, f, this->outer_ids); + } // Hand down any trigger. if ( parent ) @@ -439,19 +456,18 @@ Val* BroFunc::Call(val_list* args, Frame* parent) const for ( size_t i = 0; i < bodies.size(); ++i ) { if ( sample_logger ) - sample_logger->LocationSeen( - bodies[i].stmts->GetLocationInfo()); + sample_logger->LocationSeen(bodies[i].stmts->GetLocationInfo()); Unref(result); + // Fill in the rest of the frame with the function's arguments. loop_over_list(*args, j) { Val* arg = (*args)[j]; if ( f->NthElement(j) != arg ) { - // Either not yet set, or somebody reassigned - // the frame slot. + // Either not yet set, or somebody reassigned the frame slot. Ref(arg); f->SetElement(j, arg); } @@ -468,14 +484,18 @@ Val* BroFunc::Call(val_list* args, Frame* parent) const { // Already reported, but now determine whether to unwind further. if ( Flavor() == FUNC_FLAVOR_FUNCTION ) + { + Unref(f); + Unref(result); throw; + } // Continue exec'ing remaining bodies of hooks/events. continue; } if ( f->HasDelayed() ) - { + { assert(! result); assert(parent); parent->SetDelayed(); @@ -517,7 +537,7 @@ Val* BroFunc::Call(val_list* args, Frame* parent) const (flow != FLOW_RETURN /* we fell off the end */ || ! result /* explicit return with no result */) && ! f->HasDelayed() ) - reporter->Warning("non-void function returns without a value: %s", + reporter->Warning("non-void function returning without a value: %s", Name()); if ( result && g_trace_state.DoTrace() ) @@ -529,6 +549,7 @@ Val* BroFunc::Call(val_list* args, Frame* parent) const } g_frame_stack.pop_back(); + Unref(f); return result; @@ -559,6 +580,65 @@ void BroFunc::AddBody(Stmt* new_body, id_list* new_inits, int new_frame_size, sort(bodies.begin(), bodies.end()); } +void BroFunc::AddClosure(std::shared_ptr ids, Frame* f) + { + // Order matters here. + this->SetOuterIDs(ids); + this->SetClosureFrame(f); + } + +void BroFunc::SetClosureFrame(Frame* f) + { + if (this->closure) + reporter->InternalError + ("Tried to override closure for BroFunc %s.", this->Name()); + + this->closure = f ? f->Clone() : nullptr; + } + +void BroFunc::ShiftOffsets(int shift, std::shared_ptr idl) + { + id_list* tmp = idl.get(); + if (! idl || shift == 0) + { + // Nothing to do here. + return; + } + + loop_over_list(*tmp, i) + { + ID* id = (*tmp)[i]; + id->SetOffset(id->Offset() + shift); + } + } + +Val* BroFunc::DoClone() + { + // A BroFunc could hold a closure. In this case a clone of it must copy this + // store a copy of this closure. + if ( ! this->closure ) + { + return Func::DoClone(); + } + else + { + BroFunc* other = new BroFunc(); + + other->bodies = this->bodies; + other->scope = this->scope; + other->kind = this->kind; + other->type = this->type; + other->name = this->name; + other->unique_id = this->unique_id; + other->unique_ids = this->unique_ids; + other->frame_size = this->frame_size; + other->closure = this->closure->Clone(); + other->outer_ids = this->outer_ids; + + return new Val(other); + } + } + void BroFunc::Describe(ODesc* d) const { d->Add(Name()); @@ -578,7 +658,8 @@ Stmt* BroFunc::AddInits(Stmt* body, id_list* inits) return body; StmtList* stmt_series = new StmtList; - stmt_series->Stmts().append(new InitStmt(inits)); + InitStmt* first = new InitStmt(inits); + stmt_series->Stmts().append(first); stmt_series->Stmts().append(body); return stmt_series; diff --git a/src/Func.h b/src/Func.h index 48e0c2e8b8..b5d195eb34 100644 --- a/src/Func.h +++ b/src/Func.h @@ -4,10 +4,13 @@ #define func_h #include +#include // std::shared_ptr, std::unique_ptr #include "BroList.h" #include "Obj.h" #include "Debug.h" +#include "Frame.h" +// #include "Val.h" class Val; class ListExpr; @@ -17,6 +20,8 @@ class Frame; class ID; class CallExpr; +struct CloneState; + class Func : public BroObj { public: @@ -62,6 +67,7 @@ public: // This (un-)serializes only a single body (as given in SerialInfo). bool Serialize(SerialInfo* info) const; static Func* Unserialize(UnserialInfo* info); + virtual Val* DoClone(); virtual TraversalCode Traverse(TraversalCallback* cb) const; @@ -95,9 +101,14 @@ public: int IsPure() const override; Val* Call(val_list* args, Frame* parent) const override; + void AddClosure(std::shared_ptr ids, Frame* f); void AddBody(Stmt* new_body, id_list* new_inits, int new_frame_size, int priority) override; + void fsets(); + + Val* DoClone() override; + int FrameSize() const { return frame_size; } void Describe(ODesc* d) const override; @@ -109,6 +120,23 @@ protected: DECLARE_SERIAL(BroFunc); int frame_size; + +private: + // Shifts the offsets of each id in "idl" by "shift". + static void ShiftOffsets(int shift, std::shared_ptr idl); + + // Makes a deep copy of the input frame and captures it. + void SetClosureFrame(Frame* f); + + void SetOuterIDs(std::shared_ptr ids) + { outer_ids = std::move(ids); } + + // List of the outer IDs used in the function. Shared becase other instances + // would like to use it as well. + std::shared_ptr outer_ids = nullptr; + // The frame the Func was initialized in. This is not guaranteed to be + // initialized and should be handled with care. + Frame* closure = nullptr; }; typedef Val* (*built_in_func)(Frame* frame, val_list* args); @@ -146,6 +174,18 @@ struct CallInfo { const val_list* args; }; +// Struct that collects the arguments for a Func. +// Used for BroFuncs with closures. +struct function_ingredients + { + ID* id; + Stmt* body; + id_list* inits; + int frame_size; + int priority; + Scope* scope; + }; + extern vector call_stack; extern std::string render_call_stack(); diff --git a/src/Obj.h b/src/Obj.h index 21730ff367..b1ac9a110d 100644 --- a/src/Obj.h +++ b/src/Obj.h @@ -207,7 +207,7 @@ inline void Error(const BroObj* o, const char* msg) inline void Ref(BroObj* o) { - if ( ++o->ref_cnt <= 1 ) + if ( ++(o->ref_cnt) <= 1 ) bad_ref(0); if ( o->ref_cnt == INT_MAX ) bad_ref(1); diff --git a/src/Stmt.cc b/src/Stmt.cc index 5960747d05..4e5b2d5050 100644 --- a/src/Stmt.cc +++ b/src/Stmt.cc @@ -960,7 +960,7 @@ Val* SwitchStmt::DoExec(Frame* f, Val* v, stmt_flow_type& flow) const if ( matching_id ) { auto cv = cast_value_to_type(v, matching_id->Type()); - f->SetElement(matching_id->Offset(), cv); + f->SetElement(matching_id, cv); } flow = FLOW_NEXT; @@ -1477,10 +1477,10 @@ Val* ForStmt::DoExec(Frame* f, Val* v, stmt_flow_type& flow) const delete k; if ( value_var ) - f->SetElement(value_var->Offset(), current_tev->Value()->Ref()); + f->SetElement(value_var, current_tev->Value()->Ref()); for ( int i = 0; i < ind_lv->Length(); i++ ) - f->SetElement((*loop_vars)[i]->Offset(), ind_lv->Index(i)->Ref()); + f->SetElement((*loop_vars)[i], ind_lv->Index(i)->Ref()); Unref(ind_lv); flow = FLOW_NEXT; @@ -1508,7 +1508,7 @@ Val* ForStmt::DoExec(Frame* f, Val* v, stmt_flow_type& flow) const // Set the loop variable to the current index, and make // another pass over the loop body. - f->SetElement((*loop_vars)[0]->Offset(), + f->SetElement((*loop_vars)[0], val_mgr->GetCount(i)); flow = FLOW_NEXT; ret = body->Exec(f, flow); @@ -1523,7 +1523,7 @@ Val* ForStmt::DoExec(Frame* f, Val* v, stmt_flow_type& flow) const for ( int i = 0; i < sval->Len(); ++i ) { - f->SetElement((*loop_vars)[0]->Offset(), + f->SetElement((*loop_vars)[0], new StringVal(1, (const char*) sval->Bytes() + i)); flow = FLOW_NEXT; ret = body->Exec(f, flow); @@ -2084,7 +2084,7 @@ Val* InitStmt::Exec(Frame* f, stmt_flow_type& flow) const break; } - f->SetElement(aggr->Offset(), v); + f->SetElement(aggr, v); } return 0; diff --git a/src/Val.cc b/src/Val.cc index bb9a3d1601..77c375da3a 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -103,7 +103,7 @@ Val* Val::DoClone(CloneState* state) // Functions and files. There aren't any derived classes. if ( type->Tag() == TYPE_FUNC ) // Immutable. - return Ref(); + return this->AsFunc()->DoClone(); if ( type->Tag() == TYPE_FILE ) { diff --git a/src/Val.h b/src/Val.h index 2890c4c5e8..756d54b518 100644 --- a/src/Val.h +++ b/src/Val.h @@ -32,6 +32,7 @@ #define ICMP_PORT_MASK 0x30000 class Val; +class BroFunc; class Func; class BroFile; class RE_Matcher; diff --git a/src/Var.cc b/src/Var.cc index 16ced341c1..e45378e75a 100644 --- a/src/Var.cc +++ b/src/Var.cc @@ -385,6 +385,7 @@ void begin_func(ID* id, const char* module_name, function_flavor flavor, RecordType* args = t->Args(); int num_args = args->NumFields(); + for ( int i = 0; i < num_args; ++i ) { TypeDecl* arg_i = args->FieldDecl(i); @@ -421,6 +422,7 @@ TraversalCode OuterIDBindingFinder::PreExpr(const Expr* expr) const NameExpr* e = static_cast(expr); + // TODO: Do we need to capture these as well? if ( e->Id()->IsGlobal() ) return TC_CONTINUE; @@ -431,17 +433,11 @@ TraversalCode OuterIDBindingFinder::PreExpr(const Expr* expr) return TC_CONTINUE; } -void end_func(Stmt* body) +// Gets a function's priority from its Scope's attributes. Errors if it sees any +// problems. +int get_func_priotity(attr_list* attrs) { - int frame_size = current_scope()->Length(); - id_list* inits = current_scope()->GetInits(); - - Scope* scope = pop_scope(); - ID* id = scope->ScopeID(); - int priority = 0; - auto attrs = scope->Attrs(); - if ( attrs ) { loop_over_list(*attrs, i) @@ -473,27 +469,63 @@ void end_func(Stmt* body) priority = v->InternalInt(); } } + return priority; + } - if ( streq(id->Name(), "anonymous-function") ) +void end_func(Stmt* body) + { + std::unique_ptr ingredients = + gather_function_ingredients(body); + pop_scope(); + + if ( streq(ingredients->id->Name(), "anonymous-function") ) { - OuterIDBindingFinder cb(scope); - body->Traverse(&cb); + OuterIDBindingFinder cb(ingredients->scope); + ingredients->body->Traverse(&cb); for ( size_t i = 0; i < cb.outer_id_references.size(); ++i ) cb.outer_id_references[i]->Error( "referencing outer function IDs not supported"); } - if ( id->HasVal() ) - id->ID_Val()->AsFunc()->AddBody(body, inits, frame_size, priority); + if ( ingredients->id->HasVal() ) + ingredients->id->ID_Val()->AsFunc()->AddBody( + ingredients->body, + ingredients->inits, + ingredients->frame_size, + ingredients->priority); else { - Func* f = new BroFunc(id, body, inits, frame_size, priority); - id->SetVal(new Val(f)); - id->SetConst(); + Func* f = new BroFunc( + ingredients->id, + ingredients->body, + ingredients->inits, + ingredients->frame_size, + ingredients->priority); + ingredients->id->SetVal(new Val(f)); + ingredients->id->SetConst(); } - id->ID_Val()->AsFunc()->SetScope(scope); + ingredients->id->ID_Val()->AsFunc()->SetScope(ingredients->scope); + } + +// Gathers all of the information from the current scope needed to build a +// function and collects it into a function_ingredients struct. +std::unique_ptr gather_function_ingredients(Stmt* body) + { + std::unique_ptr ingredients (new function_ingredients); + ingredients->frame_size = current_scope()->Length(); + ingredients->inits = current_scope()->GetInits(); + + ingredients->scope = current_scope(); + ingredients->id = ingredients->scope->ScopeID(); + + auto attrs = ingredients->scope->Attrs(); + + ingredients->priority = get_func_priotity(attrs); + ingredients->body = body; + + return std::move(ingredients); } Val* internal_val(const char* name) @@ -508,6 +540,19 @@ Val* internal_val(const char* name) return rval; } +std::shared_ptr gather_outer_ids(Scope* scope, Stmt* body) + { + OuterIDBindingFinder cb(scope); + body->Traverse(&cb); + + std::shared_ptr idl (new id_list); + + for ( size_t i = 0; i < cb.outer_id_references.size(); ++i ) + idl->append(cb.outer_id_references[i]->Id()); + + return std::move(idl); + } + Val* internal_const_val(const char* name) { ID* id = lookup_ID(name, GLOBAL_MODULE_NAME); diff --git a/src/Var.h b/src/Var.h index 98e0f45163..1c96ff2081 100644 --- a/src/Var.h +++ b/src/Var.h @@ -3,9 +3,12 @@ #ifndef var_h #define var_h +#include // std::unique_ptr + #include "ID.h" #include "Expr.h" #include "Type.h" +#include "Func.h" // function_ingredients class Func; class EventHandlerPtr; @@ -23,6 +26,9 @@ extern void add_type(ID* id, BroType* t, attr_list* attr); extern void begin_func(ID* id, const char* module_name, function_flavor flavor, int is_redef, FuncType* t, attr_list* attrs = nullptr); extern void end_func(Stmt* body); +extern std::unique_ptr + gather_function_ingredients(Stmt* body); +extern std::shared_ptr gather_outer_ids(Scope* scope, Stmt* body); extern Val* internal_val(const char* name); extern Val* internal_const_val(const char* name); // internal error if not const diff --git a/src/parse.y b/src/parse.y index 2861b95dc8..7765a600fb 100644 --- a/src/parse.y +++ b/src/parse.y @@ -272,7 +272,7 @@ bro: else stmts = $2; - // Any objects creates from hereon out should not + // Any objects creates from here on out should not // have file positions associated with them. set_location(no_location); } @@ -1218,16 +1218,42 @@ func_body: ; anonymous_function: - TOK_FUNCTION begin_func func_body - { $$ = new ConstExpr($2->ID_Val()); } + TOK_FUNCTION func_params + { + $$ = current_scope()->GenerateTemporary("lambda-function"); + begin_func($$, current_module.c_str(), FUNC_FLAVOR_FUNCTION, 0, $2); + } + + '{' + { + saved_in_init.push_back(in_init); + in_init = 0; + } + + stmt_list + { + in_init = saved_in_init.back(); + saved_in_init.pop_back(); + } + + '}' + { + // Every time a new LambdaExpr is evaluated it must return a new instance + // of a BroFunc. Here, we collect the ingredients for a function and give + // it to our LambdaExpr. + std::unique_ptr ingredients = + gather_function_ingredients($6); + std::shared_ptr outer_ids = gather_outer_ids(pop_scope(), $6); + + $$ = new LambdaExpr(std::move(ingredients), std::move(outer_ids)); + } ; begin_func: func_params { $$ = current_scope()->GenerateTemporary("anonymous-function"); - begin_func($$, current_module.c_str(), - FUNC_FLAVOR_FUNCTION, 0, $1); + begin_func($$, current_module.c_str(), FUNC_FLAVOR_FUNCTION, 0, $1); } ; diff --git a/testing/btest/Baseline/language.function-closures/out b/testing/btest/Baseline/language.function-closures/out new file mode 100644 index 0000000000..53274969c7 --- /dev/null +++ b/testing/btest/Baseline/language.function-closures/out @@ -0,0 +1,22 @@ +expect: 3 +3 +expect: 5 +5 +expect: 5 +5 +expect: T +T +expect: T +T +expect: 107 +107 +expect: 107 +107 +expect: 100 +100 +expect: 160 +160 +expect: 225 +225 +expect: 225 +225 diff --git a/testing/btest/Baseline/language.outer_param_binding/out b/testing/btest/Baseline/language.outer_param_binding/out index afdc4191cd..a0c60132e4 100644 --- a/testing/btest/Baseline/language.outer_param_binding/out +++ b/testing/btest/Baseline/language.outer_param_binding/out @@ -1,3 +1 @@ -error in /home/robin/bro/master/testing/btest/.tmp/language.outer_param_binding/outer_param_binding.zeek, line 16: referencing outer function IDs not supported (c) -error in /home/robin/bro/master/testing/btest/.tmp/language.outer_param_binding/outer_param_binding.zeek, line 16: referencing outer function IDs not supported (d) -error in /home/robin/bro/master/testing/btest/.tmp/language.outer_param_binding/outer_param_binding.zeek, line 17: referencing outer function IDs not supported (b) +error in /home/zekemedley/Desktop/corelight/lambda_stuff/zeek/testing/btest/.tmp/language.outer_param_binding/outer_param_binding.zeek, line 11 and /home/zekemedley/Desktop/corelight/lambda_stuff/zeek/testing/btest/.tmp/language.outer_param_binding/outer_param_binding.zeek, line 17: already defined (d) diff --git a/testing/btest/language/function-closures.zeek b/testing/btest/language/function-closures.zeek new file mode 100644 index 0000000000..928d0aa4f1 --- /dev/null +++ b/testing/btest/language/function-closures.zeek @@ -0,0 +1,77 @@ +# @TEST-EXEC: zeek -b %INPUT >out +# @TEST-EXEC: btest-diff out + +function make_count_upper (start : count) : function(step : count) : count + { + return function(step : count) : count + { return (start += step); }; + } + +event zeek_init() + { + # basic + local one = make_count_upper(1); + print "expect: 3"; + print one(2); + + # multiple instances + local two = make_count_upper(one(1)); + print "expect: 5"; + print two(1); + print "expect: 5"; + print one(1); + + # deep copies + local c = copy(one); + print "expect: T"; + print c(1) == one(1); + print "expect: T"; + print c(1) == two(2); + + # a little more complicated ... + local cat_dog = 100; + local add_n_and_m = function(n: count) : function(m : count) : function(o : count) : count + { + cat_dog += 1; # segfault here. + return function(m : count) : function(o : count) : count + { return function(o : count) : count + { return n + m + o + cat_dog; }; }; + }; + + local add_m = add_n_and_m(2); + local adder = add_m(2); + + print "expect: 107"; + print adder(2); + + print "expect: 107"; + # deep copies + local ac = copy(adder); + print ac(2); + + # copies closure: + print "expect: 100"; + print cat_dog; + + # complicated - has state across calls + local two_part_adder_maker = function (begin : count) : function (base_step : count) : function ( step : count) : count + { + return function (base_step : count) : function (step : count) : count + { + return function (step : count) : count + { + return (begin += base_step + step); }; }; }; + + local base_step = two_part_adder_maker(100); + local stepper = base_step(50); + print "expect: 160"; + print stepper(10); + local twotwofive = copy(stepper); + print "expect: 225"; + print stepper(15); + + # another copy check + print "expect: 225"; + print twotwofive(15); + } + diff --git a/testing/btest/language/outer_param_binding.zeek b/testing/btest/language/outer_param_binding.zeek index d3587a7cce..1f3f32a1fc 100644 --- a/testing/btest/language/outer_param_binding.zeek +++ b/testing/btest/language/outer_param_binding.zeek @@ -12,6 +12,9 @@ function bar(b: string, c: string) f = [$x=function(a: string) : string { local x = 0; + # Fail here: we've captured the closure. + # d is already defined. + local d = 10; print x; print c, d; return cat(a, " ", b); @@ -24,4 +27,5 @@ function bar(b: string, c: string) event zeek_init() { bar("1", "20"); + bar("1", "20"); }