Reviewed revision of the lambda-captures patch.

The patch text was re-flowed onto separate lines.  Template arguments
that were missing from it (e.g. std::set<const ID*>,
std::unordered_map<std::string, int>, make_intrusive<ScriptFunc>) and
the exact indentation are a best-effort reconstruction - confirm them
against the tree before applying.  The "index" blob hashes are those
of the unrevised patch.

Changes relative to the unrevised patch:
  * CheckCaptures() detects repeated captures even when unused, and
    reports an unused capture only once.
  * CreateCaptures() no longer takes an extra, never-released reference
    on captured values that aren't cloned.
  * CreateCaptures()/SetCaptures() release previously held capture state
    rather than leaking it.
  * LambdaExpr::capture_by_ref has an in-class initializer.

diff --git a/src/Expr.cc b/src/Expr.cc
index 825f331a8f..d6c6cabe1b 100644
--- a/src/Expr.cc
+++ b/src/Expr.cc
@@ -4329,6 +4329,8 @@ LambdaExpr::LambdaExpr(std::unique_ptr<function_ingredients> arg_ing,
 
 	SetType(ingredients->id->GetType());
 
+	CheckCaptures();
+
 	// Install a dummy version of the function globally for use only
 	// when broker provides a closure.
 	auto dummy_func = make_intrusive<ScriptFunc>(
@@ -4375,6 +4377,77 @@ LambdaExpr::LambdaExpr(std::unique_ptr<function_ingredients> arg_ing,
 	id->SetConst();
 	}
 
+// Validates the lambda's [] captures against the outer identifiers used
+// in its body: flags captures listed more than once, outer identifiers
+// that are used but not captured, and captures that are never used.
+// Without [] captures, falls back to the deprecated by-reference
+// semantics (with a warning) if any outer identifiers are used.
+void LambdaExpr::CheckCaptures()
+	{
+	auto ft = type->AsFuncType();
+	auto captures = ft->GetCaptures();
+
+	capture_by_ref = false;
+
+	if ( ! captures )
+		{
+		if ( outer_ids.size() > 0 )
+			{
+			reporter->Warning("use of outer identifiers in lambdas without [] captures is deprecated: %s%s",
+			                  outer_ids.size() > 1 ? "e.g., " : "",
+			                  outer_ids[0]->Name());
+			capture_by_ref = true;
+			}
+
+		return;
+		}
+
+	std::set<const ID*> capture_is_seen;
+	std::set<const ID*> outer_is_matched;
+	std::set<const ID*> capture_is_matched;
+
+	for ( auto c : *captures )
+		{
+		auto cid = c->id.get();
+
+		if ( ! cid )
+			// This happens for undefined/inappropriate
+			// identifiers listed in captures.  There's
+			// already been an error message.
+			continue;
+
+		// Track every listed capture, matched or not, so that a
+		// repeated capture is caught even when it's unused.
+		if ( ! capture_is_seen.insert(cid).second )
+			{
+			ExprError(util::fmt("%s listed multiple times in capture", cid->Name()));
+			continue;
+			}
+
+		for ( auto id : outer_ids )
+			if ( cid == id )
+				{
+				outer_is_matched.insert(id);
+				capture_is_matched.insert(cid);
+				break;
+				}
+		}
+
+	for ( auto id : outer_ids )
+		if ( outer_is_matched.count(id) == 0 )
+			ExprError(util::fmt("%s is used inside lambda but not captured", id->Name()));
+
+	for ( auto c : *captures )
+		{
+		auto cid = c->id.get();
+
+		// A failed insert() means the capture was either matched
+		// or already reported, so each one is reported only once.
+		if ( cid && capture_is_matched.insert(cid).second )
+			ExprError(util::fmt("%s is captured but not used inside lambda", cid->Name()));
+		}
+	}
+
 Scope* LambdaExpr::GetScope() const
 	{
 	return ingredients->scope.get();
@@ -4389,7 +4462,10 @@ ValPtr LambdaExpr::Eval(Frame* f) const
 		ingredients->frame_size,
 		ingredients->priority);
 
-	lamb->AddClosure(outer_ids, f);
+	if ( capture_by_ref )
+		lamb->AddClosure(outer_ids, f);
+	else
+		lamb->CreateCaptures(f);
 
 	// Set name to corresponding dummy func.
 	// Allows for lookups by the receiver.
@@ -5060,7 +5136,9 @@ ExprPtr check_and_promote_expr(Expr* const e, zeek::Type* t)
 			IntrusivePtr{NewRef{}, e},
 			IntrusivePtr{NewRef{}, t->AsVectorType()});
 
-	t->Error("type clash", e);
+	if ( t->Tag() != TYPE_ERROR && et->Tag() != TYPE_ERROR )
+		t->Error("type clash", e);
+
 	return nullptr;
 	}
 
diff --git a/src/Expr.h b/src/Expr.h
index 87aa2e0dd6..4374edc1d3 100644
--- a/src/Expr.h
+++ b/src/Expr.h
@@ -847,9 +847,13 @@ protected:
 	void ExprDescribe(ODesc* d) const override;
 
 private:
+	void CheckCaptures();
+
 	std::unique_ptr<function_ingredients> ingredients;
 
 	IDPList outer_ids;
+	bool capture_by_ref = false; // if true, use deprecated reference semantics
+
 	std::string my_name;
 };
 
diff --git a/src/Func.cc b/src/Func.cc
index 0062dab5ad..9aff1dee42 100644
--- a/src/Func.cc
+++ b/src/Func.cc
@@ -319,6 +319,9 @@ ScriptFunc::~ScriptFunc()
 	{
 	if ( ! weak_closure_ref )
 		Unref(closure);
+
+	delete captures_frame;
+	delete captures_offset_mapping;
 	}
 
 bool ScriptFunc::IsPure() const
@@ -472,6 +475,69 @@ ValPtr ScriptFunc::Invoke(zeek::Args* args, Frame* parent) const
 	return result;
 	}
 
+void ScriptFunc::CreateCaptures(Frame* f)
+	{
+	auto captures = type->GetCaptures();
+
+	if ( ! captures )
+		return;
+
+	// Release any state left from an earlier call; otherwise the
+	// assignments below would leak it.
+	delete captures_frame;
+	delete captures_offset_mapping;
+
+	// Create a private Frame to hold the values of captured variables,
+	// and a mapping from those variables to their offsets in the Frame.
+	captures_frame = new Frame(captures->size(), this, nullptr);
+	captures_offset_mapping = new OffsetMap;
+
+	int offset = 0;
+	for ( auto c : *captures )
+		{
+		auto cid = c->id;
+		auto v = f->GetElementByID(cid);
+
+		if ( v )
+			{
+			if ( c->deep_copy || ! v->Modifiable() )
+				v = v->Clone();
+
+			// No explicit Ref() for a value that isn't cloned: it
+			// would be one reference more than the Clone() case
+			// hands over, and nothing would ever release it.
+			captures_frame->SetElement(offset, v);
+			}
+
+		(*captures_offset_mapping)[cid->Name()] = offset;
+		++offset;
+		}
+	}
+
+void ScriptFunc::SetCaptures(Frame* f)
+	{
+	auto captures = type->GetCaptures();
+	ASSERT(captures);
+
+	// Release what's being replaced (taking care if "f" is the frame
+	// already held), since this function owns its captures state.
+	if ( captures_frame != f )
+		delete captures_frame;
+
+	delete captures_offset_mapping;
+
+	captures_frame = f;
+	captures_offset_mapping = new OffsetMap;
+
+	int offset = 0;
+	for ( auto c : *captures )
+		{
+		auto cid = c->id;
+		(*captures_offset_mapping)[cid->Name()] = offset;
+		++offset;
+		}
+	}
+
 void ScriptFunc::AddBody(StmtPtr new_body,
                          const std::vector<IDPtr>& new_inits,
                          size_t new_frame_size, int priority)
diff --git a/src/Func.h b/src/Func.h
index 72c529942e..3338b4de71 100644
--- a/src/Func.h
+++ b/src/Func.h
@@ -161,7 +161,39 @@ public:
 	ValPtr Invoke(zeek::Args* args, Frame* parent) const override;
 
 	/**
-	 * Adds adds a closure to the function. Closures are cloned and
+	 * Creates a separate frame for captures and initializes its
+	 * elements.  The list of captures comes from the ScriptFunc's
+	 * type, so doesn't need to be passed in, just the frame to
+	 * use in evaluating the identifiers.
+	 *
+	 * @param f  the frame used for evaluating the captured identifiers
+	 */
+	void CreateCaptures(Frame* f);
+
+	/**
+	 * Returns the frame associated with this function for tracking
+	 * captures, or nullptr if there isn't one.
+	 *
+	 * @return internal frame kept by the function for persisting captures
+	 */
+	Frame* GetCapturesFrame() const	{ return captures_frame; }
+
+	// Same definition as in Frame.h.
+	using OffsetMap = std::unordered_map<std::string, int>;
+
+	/**
+	 * Returns the mapping of captures to slots in the captures frame.
+	 *
+	 * @return pointer to mapping of captures to slots
+	 */
+	const OffsetMap* GetCapturesOffsetMap() const
+		{ return captures_offset_mapping; }
+
+	// The following "Closure" methods implement the deprecated
+	// capture-by-reference functionality.
+
+	/**
+	 * Adds a closure to the function. Closures are cloned and
 	 * future calls to ScriptFunc methods will not modify *f*.
 	 *
 	 * @param ids IDs that are captured by the closure.
@@ -218,14 +250,34 @@ protected:
 	 */
 	void SetClosureFrame(Frame* f);
 
+	/**
+	 * Uses the given frame for captures, and generates the
+	 * mapping from captured variables to offsets in the frame.
+	 * The function takes ownership of the frame, deleting it on
+	 * destruction.
+	 *
+	 * @param f  the frame holding the values of capture variables
+	 */
+	void SetCaptures(Frame* f);
+
 private:
 	size_t frame_size;
 
 	// List of the outer IDs used in the function.
 	IDPList outer_ids;
+
+	// The following is used for deprecated capture-by-reference
+	// closures:
 	// The frame the ScriptFunc was initialized in.
 	Frame* closure = nullptr;
 	bool weak_closure_ref = false;
+
+	// Used for capture-by-copy closures.  These persist over the
+	// function's lifetime, providing quasi-globals that maintain
+	// state across individual calls to the function.
+	Frame* captures_frame = nullptr;
+
+	OffsetMap* captures_offset_mapping = nullptr;
 };
 
 using built_in_func = BifReturnVal (*)(Frame* frame, const Args* args);