diff --git a/src/script_opt/CPP/Stmts.cc b/src/script_opt/CPP/Stmts.cc index fdee490711..ff215bceee 100644 --- a/src/script_opt/CPP/Stmts.cc +++ b/src/script_opt/CPP/Stmts.cc @@ -420,9 +420,9 @@ void CPPCompile::GenWhenStmt(const WhenStmt* w) NL(); Emit("std::vector CPP__local_aggrs;"); - for ( auto l : wi->WhenExprLocals() ) + for ( auto& l : wi->WhenExprLocals() ) if ( IsAggr(l->GetType()) ) - Emit("CPP__local_aggrs.emplace_back(%s);", IDNameStr(l)); + Emit("CPP__local_aggrs.emplace_back(%s);", IDNameStr(l.get())); Emit("CPP__wi->Instantiate(%s);", GenExpr(wi->Lambda(), GEN_NATIVE)); diff --git a/src/script_opt/ProfileFunc.cc b/src/script_opt/ProfileFunc.cc index 1da75df4f0..015c395994 100644 --- a/src/script_opt/ProfileFunc.cc +++ b/src/script_opt/ProfileFunc.cc @@ -28,7 +28,23 @@ ProfileFunc::ProfileFunc(const Func* func, const StmtPtr& body, bool _abs_rec_fi profiled_func = func; profiled_body = body.get(); abs_rec_fields = _abs_rec_fields; - Profile(func->GetType().get(), body); + + auto ft = func->GetType()->AsFuncType(); + auto& fcaps = ft->GetCaptures(); + + if ( fcaps ) + { + int offset = 0; + + for ( auto& c : *fcaps ) + { + auto cid = c.Id().get(); + captures.insert(cid); + captures_offsets[cid] = offset++; + } + } + + Profile(ft, body); } ProfileFunc::ProfileFunc(const Stmt* s, bool _abs_rec_fields) @@ -48,10 +64,15 @@ ProfileFunc::ProfileFunc(const Expr* e, bool _abs_rec_fields) { auto func = e->AsLambdaExpr(); - for ( auto oid : func->OuterIDs() ) - captures.insert(oid); + int offset = 0; - Profile(func->GetType()->AsFuncType(), func->Ingredients().Body()); + for ( auto oid : func->OuterIDs() ) + { + captures.insert(oid); + captures_offsets[oid] = offset++; + } + + Profile(func->GetType()->AsFuncType(), func->Ingredients()->Body()); } else @@ -91,9 +112,9 @@ TraversalCode ProfileFunc::PreStmt(const Stmt* s) auto w = s->AsWhenStmt(); auto wi = w->Info(); - auto wl = wi ? wi->Lambda() : nullptr; - if ( wl ) - lambdas.push_back(wl.get()); + + for ( auto wl : wi->WhenNewLocals() ) + when_locals.insert(wl); } break; @@ -171,6 +192,11 @@ TraversalCode ProfileFunc::PreExpr(const Expr* e) auto n = e->AsNameExpr(); auto id = n->Id(); + // Turns out that NameExpr's can be constructed using a + // different Type* than that of the identifier itself, + // so be sure we track the latter too. + TrackType(id->GetType()); + if ( id->IsGlobal() ) { globals.insert(id); @@ -179,30 +205,24 @@ TraversalCode ProfileFunc::PreExpr(const Expr* e) const auto& t = id->GetType(); if ( t->Tag() == TYPE_FUNC && t->AsFuncType()->Flavor() == FUNC_FLAVOR_EVENT ) events.insert(id->Name()); + + break; } - else - { - // This is a tad ugly. Unfortunately due to the - // weird way that Zeek function *declarations* work, - // there's no reliable way to get the list of - // parameters for a function *definition*, since - // they can have different names than what's present - // in the declaration. So we identify them directly, - // by knowing that they come at the beginning of the - // frame ... and being careful to avoid misconfusing - // a lambda capture with a low frame offset as a - // parameter. - if ( captures.count(id) == 0 && id->Offset() < num_params ) - params.insert(id); + // This is a tad ugly. Unfortunately due to the + // weird way that Zeek function *declarations* work, + // there's no reliable way to get the list of + // parameters for a function *definition*, since + // they can have different names than what's present + // in the declaration. So we identify them directly, + // by knowing that they come at the beginning of the + // frame ... and being careful to avoid misconfusing + // a lambda capture with a low frame offset as a + // parameter. + if ( captures.count(id) == 0 && id->Offset() < num_params ) + params.insert(id); - locals.insert(id); - } - - // Turns out that NameExpr's can be constructed using a - // different Type* than that of the identifier itself, - // so be sure we track the latter too. - TrackType(id->GetType()); + locals.insert(id); break; } @@ -350,7 +370,12 @@ TraversalCode ProfileFunc::PreExpr(const Expr* e) params.insert(i); } - // Avoid recursing into the body. + // In general, we don't want to recurse into the body. + // However, we still want to *profile* it so we can + // identify calls within it. + ProfileFunc body_pf(l->Ingredients()->Body().get(), false); + script_calls.insert(body_pf.ScriptCalls().begin(), body_pf.ScriptCalls().end()); + return TC_ABORTSTMT; } diff --git a/src/script_opt/ProfileFunc.h b/src/script_opt/ProfileFunc.h index 07d56ef39e..861a0c3a3e 100644 --- a/src/script_opt/ProfileFunc.h +++ b/src/script_opt/ProfileFunc.h @@ -100,6 +100,8 @@ public: const IDSet& Globals() const { return globals; } const IDSet& AllGlobals() const { return all_globals; } const IDSet& Locals() const { return locals; } + const IDSet& Captures() const { return captures; } + const auto& CapturesOffsets() const { return captures_offsets; } const IDSet& WhenLocals() const { return when_locals; } const IDSet& Params() const { return params; } const std::unordered_map& Assignees() const { return assignees; } @@ -208,6 +210,9 @@ protected: // If we're profiling a lambda function, this holds the captures. IDSet captures; + // This maps capture identifiers to their offsets. + std::map captures_offsets; + // Constants seen in the function. std::vector constants; @@ -224,7 +229,8 @@ protected: // The same, but in a deterministic order, with duplicates removed. std::vector ordered_types; - // Script functions that this script calls. + // Script functions that this script calls. Includes calls made + // by lambdas and when bodies, as the goal is to identify recursion. std::unordered_set script_calls; // Same for BiF's, though for them we record the corresponding global