// Patch summary (script-optimization inlining / reduction changes):
//
//  - src/Expr.cc, src/Expr.h, src/Func.h: add LambdaExpr::ReplaceBody(),
//    which forwards to a new ReplaceBody() setter on the object held in
//    `ingredients` (presumably FunctionIngredients -- confirm against
//    Func.h).  LambdaExpr::Eval() now also sets the new function's frame
//    size from primary_func rather than leaving it to the ingredients.
//  - src/script_opt/Expr.cc: InlineExpr::Reduce() calls PushInlineBlock()
//    before the loop over the arguments instead of after it.
//  - src/script_opt/Inline.{cc,h}: the in-"when" and argument-count checks
//    move after the inline_ables lookup, and every call that is left
//    un-inlined (in-"when", argument-count mismatch, size limit) records
//    the callee in skipped_inlining.  WasInlined() additionally requires
//    that the function is not in skipped_inlining.
//  - src/script_opt/Reduce.{cc,h}: Reducer is now constructed from the
//    ScriptFunc being reduced and seeds tracked_ids with its parameters
//    and captures; identifiers seen at inline_block_level 0 are added to
//    tracked_ids as well.  block_locals records, per inline block, the
//    previous orig_to_new_locals entry for each remapped local so that
//    PopInlineBlock() can restore (or erase) it.
//  - src/script_opt/ScriptOpt.cc: `lambdas` becomes a map from a lambda's
//    primary function to its LambdaExpr so the optimized body can be
//    handed back via ReplaceBody(); the indirectly-used-function scan
//    iterates pfs->Globals() instead of each function's profile.
//  - src/script_opt/ZAM/Ops.in: WhenCall saves ZAM_error before
//    func->Invoke() and restores it afterwards.
//
// NOTE(review): template argument lists appear to have been stripped
// from this copy of the patch (e.g. `make_intrusive(...)`,
// `static_cast(function)`, `std::unordered_set inline_ables`), and the
// original line breaks are collapsed -- verify against the original
// commit before attempting to apply it.
diff --git a/src/Expr.cc b/src/Expr.cc index bf68c7bc40..9d411c4702 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -4825,10 +4825,19 @@ ScopePtr LambdaExpr::GetScope() const return ingredients->Scope(); } +void LambdaExpr::ReplaceBody(StmtPtr new_body) + { + ingredients->ReplaceBody(std::move(new_body)); + } + ValPtr LambdaExpr::Eval(Frame* f) const { auto lamb = make_intrusive(ingredients->GetID()); + // Use the primary function as the source of the frame size + // and function body, rather than the ingredients, since script + // optimization might have changed the former but not the latter. + lamb->SetFrameSize(primary_func->FrameSize()); StmtPtr body = primary_func->GetBodies()[0].stmts; if ( run_state::is_parsing ) diff --git a/src/Expr.h b/src/Expr.h index 656933ce81..712f386f17 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -1482,6 +1482,8 @@ public: const FunctionIngredientsPtr& Ingredients() const { return ingredients; } + void ReplaceBody(StmtPtr new_body); + bool IsReduced(Reducer* c) const override; bool HasReducedOps(Reducer* c) const override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; diff --git a/src/Func.h b/src/Func.h index bdf725598d..ed0e59aea8 100644 --- a/src/Func.h +++ b/src/Func.h @@ -375,6 +375,7 @@ public: const IDPtr& GetID() const { return id; } const StmtPtr& Body() const { return body; } + void ReplaceBody(StmtPtr new_body) { body = std::move(new_body); } const auto& Inits() const { return inits; } void ClearInits() { inits.clear(); } diff --git a/src/script_opt/Expr.cc b/src/script_opt/Expr.cc index 3b26e9b99a..1ecab536a7 100644 --- a/src/script_opt/Expr.cc +++ b/src/script_opt/Expr.cc @@ -2717,6 +2717,7 @@ ExprPtr InlineExpr::Reduce(Reducer* c, StmtPtr& red_stmt) red_stmt = nullptr; auto args_list = args->Exprs(); + auto ret_val = c->PushInlineBlock(type); loop_over_list(args_list, i) { @@ -2730,7 +2731,6 @@ ExprPtr InlineExpr::Reduce(Reducer* c, StmtPtr& red_stmt) red_stmt = MergeStmts(red_stmt, arg_red_stmt, assign_stmt); } 
- auto ret_val = c->PushInlineBlock(type); body = body->Reduce(c); c->PopInlineBlock(); diff --git a/src/script_opt/Inline.cc b/src/script_opt/Inline.cc index ebd7d5f0b9..217949663b 100644 --- a/src/script_opt/Inline.cc +++ b/src/script_opt/Inline.cc @@ -170,11 +170,6 @@ ExprPtr Inliner::CheckForInlining(CallExprPtr c) // We don't inline indirect calls. return c; - if ( c->IsInWhen() ) - // Don't inline these, as doing so requires propagating - // the in-when attribute to the inlined function body. - return c; - auto n = f->AsNameExpr(); auto func = n->Id(); @@ -190,22 +185,38 @@ ExprPtr Inliner::CheckForInlining(CallExprPtr c) if ( function->GetKind() != Func::SCRIPT_FUNC ) return c; - // Check for mismatches in argument count due to single-arg-of-type-any - // loophole used for variadic BiFs. - if ( function->GetType()->Params()->NumFields() == 1 && c->Args()->Exprs().size() != 1 ) - return c; - auto func_vf = static_cast(function); if ( inline_ables.count(func_vf) == 0 ) return c; + if ( c->IsInWhen() ) + { + // Don't inline these, as doing so requires propagating + // the in-when attribute to the inlined function body. + skipped_inlining.insert(func_vf); + return c; + } + + // Check for mismatches in argument count due to single-arg-of-type-any + // loophole used for variadic BiFs. (The issue isn't calls to the + // BiFs, which won't happen here, but instead to script functions that + // are misusing/abusing the loophole.) + if ( function->GetType()->Params()->NumFields() == 1 && c->Args()->Exprs().size() != 1 ) + { + skipped_inlining.insert(func_vf); + return c; + } + // We're going to inline the body, unless it's too large. 
auto body = func_vf->GetBodies()[0].stmts; // there's only 1 body auto oi = body->GetOptInfo(); if ( num_stmts + oi->num_stmts + num_exprs + oi->num_exprs > MAX_INLINE_SIZE ) - return nullptr; + { + skipped_inlining.insert(func_vf); + return nullptr; // signals "stop inlining" + } num_stmts += oi->num_stmts; num_exprs += oi->num_exprs; diff --git a/src/script_opt/Inline.h b/src/script_opt/Inline.h index 738096c919..097878ae38 100644 --- a/src/script_opt/Inline.h +++ b/src/script_opt/Inline.h @@ -32,7 +32,10 @@ public: ExprPtr CheckForInlining(CallExprPtr c); // True if the given function has been inlined. - bool WasInlined(const Func* f) { return inline_ables.count(f) > 0; } + bool WasInlined(const Func* f) + { + return inline_ables.count(f) > 0 && skipped_inlining.count(f) == 0; + } protected: // Driver routine that analyzes all of the script functions and @@ -49,6 +52,10 @@ protected: // Functions that we've determined to be suitable for inlining. std::unordered_set inline_ables; + // Functions that we didn't fully inline, so require separate + // compilation. + std::unordered_set skipped_inlining; + // As we do inlining for a given function, this tracks the // largest frame size of any inlined function. int max_inlined_frame_size; diff --git a/src/script_opt/Reduce.cc b/src/script_opt/Reduce.cc index 2e5fd902eb..b19394de0b 100644 --- a/src/script_opt/Reduce.cc +++ b/src/script_opt/Reduce.cc @@ -4,19 +4,36 @@ #include "zeek/Desc.h" #include "zeek/Expr.h" +#include "zeek/Func.h" #include "zeek/ID.h" #include "zeek/Reporter.h" #include "zeek/Scope.h" #include "zeek/Stmt.h" #include "zeek/Var.h" #include "zeek/script_opt/ExprOptInfo.h" -#include "zeek/script_opt/ProfileFunc.h" #include "zeek/script_opt/StmtOptInfo.h" #include "zeek/script_opt/TempVar.h" namespace zeek::detail { +Reducer::Reducer(const ScriptFunc* func) + { + auto& ft = func->GetType(); + + // Track the parameters so we don't remap them. 
+ int num_params = ft->Params()->NumFields(); + auto& scope_vars = current_scope()->OrderedVars(); + + for ( auto i = 0; i < num_params; ++i ) + tracked_ids.insert(scope_vars[i].get()); + + // Now include any captures. + if ( ft->GetCaptures() ) + for ( auto& c : *ft->GetCaptures() ) + tracked_ids.insert(c.Id().get()); + } + StmtPtr Reducer::Reduce(StmtPtr s) { reduction_root = std::move(s); @@ -57,10 +74,9 @@ NameExprPtr Reducer::UpdateName(NameExprPtr n) return ne; } -bool Reducer::NameIsReduced(const NameExpr* n) const +bool Reducer::NameIsReduced(const NameExpr* n) { - auto id = n->Id(); - return inline_block_level == 0 || id->IsGlobal() || IsTemporary(id) || IsNewLocal(n); + return ID_IsReducedOrTopLevel(n->Id()); } void Reducer::UpdateIDs(IDPList* ids) @@ -69,7 +85,7 @@ void Reducer::UpdateIDs(IDPList* ids) { IDPtr id = {NewRef{}, (*ids)[i]}; - if ( ! ID_IsReduced(id) ) + if ( ! ID_IsReducedOrTopLevel(id) ) { Unref((*ids)[i]); (*ids)[i] = UpdateID(id).release(); @@ -80,7 +96,7 @@ void Reducer::UpdateIDs(IDPList* ids) void Reducer::UpdateIDs(std::vector& ids) { for ( auto& id : ids ) - if ( ! ID_IsReduced(id) ) + if ( ! 
ID_IsReducedOrTopLevel(id) ) id = UpdateID(id); } @@ -104,15 +120,27 @@ bool Reducer::IDsAreReduced(const std::vector& ids) const IDPtr Reducer::UpdateID(IDPtr id) { - if ( ID_IsReduced(id) ) + if ( ID_IsReducedOrTopLevel(id) ) return id; return FindNewLocal(id); } +bool Reducer::ID_IsReducedOrTopLevel(const ID* id) + { + if ( inline_block_level == 0 ) + { + tracked_ids.insert(id); + return true; + } + + return ID_IsReduced(id); + } + bool Reducer::ID_IsReduced(const ID* id) const { - return inline_block_level == 0 || id->IsGlobal() || IsTemporary(id) || IsNewLocal(id); + return inline_block_level == 0 || tracked_ids.count(id) > 0 || id->IsGlobal() || + IsTemporary(id); } NameExprPtr Reducer::GenInlineBlockName(const IDPtr& id) @@ -126,6 +154,8 @@ NameExprPtr Reducer::PushInlineBlock(TypePtr type) { ++inline_block_level; + block_locals.emplace_back(std::unordered_map()); + if ( ! type || type->Tag() == TYPE_VOID ) return nullptr; @@ -133,11 +163,13 @@ NameExprPtr Reducer::PushInlineBlock(TypePtr type) ret_id->SetType(type); ret_id->GetOptInfo()->SetTemp(); + ret_vars.insert(ret_id.get()); + // Track this as a new local *if* we're in the outermost inlining // block. If we're recursively deeper into inlining, then this // variable will get mapped to a local anyway, so no need. 
if ( inline_block_level == 1 ) - new_locals.insert(ret_id.get()); + AddNewLocal(ret_id); return GenInlineBlockName(ret_id); } @@ -145,6 +177,18 @@ NameExprPtr Reducer::PushInlineBlock(TypePtr type) void Reducer::PopInlineBlock() { --inline_block_level; + + for ( auto& l : block_locals.back() ) + { + auto key = l.first; + auto prev = l.second; + if ( prev ) + orig_to_new_locals[key] = prev; + else + orig_to_new_locals.erase(key); + } + + block_locals.pop_back(); } bool Reducer::SameVal(const Val* v1, const Val* v2) const @@ -750,6 +794,12 @@ IDPtr Reducer::FindNewLocal(const IDPtr& id) return GenLocal(id); } +void Reducer::AddNewLocal(const IDPtr& l) + { + new_locals.insert(l.get()); + tracked_ids.insert(l.get()); + } + IDPtr Reducer::GenLocal(const IDPtr& orig) { if ( Optimizing() ) @@ -758,6 +808,9 @@ IDPtr Reducer::GenLocal(const IDPtr& orig) if ( omitted_stmts.size() > 0 ) reporter->InternalError("Generating a new local while pruning statements"); + // Make sure the identifier is not being re-re-mapped. + ASSERT(strchr(orig->Name(), '.') == nullptr); + char buf[8192]; int n = new_locals.size(); snprintf(buf, sizeof buf, "%s.%d", orig->Name(), n); @@ -769,9 +822,16 @@ IDPtr Reducer::GenLocal(const IDPtr& orig) if ( orig->GetOptInfo()->IsTemp() ) local_id->GetOptInfo()->SetTemp(); - new_locals.insert(local_id.get()); + IDPtr prev; + if ( orig_to_new_locals.count(orig.get()) ) + prev = orig_to_new_locals[orig.get()]; + + AddNewLocal(local_id); orig_to_new_locals[orig.get()] = local_id; + if ( ! 
block_locals.empty() && ret_vars.count(orig.get()) == 0 ) + block_locals.back()[orig.get()] = prev; + return local_id; } diff --git a/src/script_opt/Reduce.h b/src/script_opt/Reduce.h index 1cd7f36db6..2053a4570e 100644 --- a/src/script_opt/Reduce.h +++ b/src/script_opt/Reduce.h @@ -16,7 +16,7 @@ class TempVar; class Reducer { public: - Reducer() { } + Reducer(const ScriptFunc* func); StmtPtr Reduce(StmtPtr s); @@ -27,7 +27,7 @@ public: ExprPtr GenTemporaryExpr(const TypePtr& t, ExprPtr rhs); NameExprPtr UpdateName(NameExprPtr n); - bool NameIsReduced(const NameExpr* n) const; + bool NameIsReduced(const NameExpr* n); void UpdateIDs(IDPList* ids); bool IDsAreReduced(const IDPList* ids) const; @@ -39,6 +39,10 @@ public: bool ID_IsReduced(const IDPtr& id) const { return ID_IsReduced(id.get()); } bool ID_IsReduced(const ID* id) const; + // A version of ID_IsReduced() that tracks top-level variables, too. + bool ID_IsReducedOrTopLevel(const IDPtr& id) { return ID_IsReducedOrTopLevel(id.get()); } + bool ID_IsReducedOrTopLevel(const ID* id); + // This is called *prior* to pushing a new inline block, in // order to generate the equivalent of function parameters. NameExprPtr GenInlineBlockName(const IDPtr& id); @@ -205,6 +209,8 @@ protected: IDPtr FindNewLocal(const IDPtr& id); IDPtr FindNewLocal(const NameExprPtr& n) { return FindNewLocal(n->IdPtr()); } + void AddNewLocal(const IDPtr& l); + // Generate a new local to use in lieu of the original (seen // in an inlined block). The difference is that the new // version has a distinct name and has a correct frame offset @@ -228,6 +234,9 @@ protected: // variable, if it corresponds to one. std::unordered_map> ids_to_temps; + // Identifiers that we're tracking (and don't want to replace). + IDSet tracked_ids; + // Local variables created during reduction/optimization. IDSet new_locals; @@ -250,10 +259,18 @@ protected: // Maps statements to replacements constructed during optimization. 
std::unordered_map replaced_stmts; + // Tracks return variables we've created. + IDSet ret_vars; + // Tracks whether we're inside an inline block, and if so then // how deeply. int inline_block_level = 0; + // Tracks locals introduced in the current block, remembering + // their previous replacement value (per "orig_to_new_locals"), + // if any. When we pop the block, we restore the previous mapping. + std::vector> block_locals; + // Tracks how deeply we are in "bifurcation", i.e., duplicating // code for if-else cascades. We need to cap this at a certain // depth or else we can get functions whose size blows up @@ -265,7 +282,7 @@ protected: IDSet constant_vars; // Statement at which the current reduction started. - StmtPtr reduction_root = nullptr; + StmtPtr reduction_root; // Statement we're currently working on. const Stmt* curr_stmt = nullptr; diff --git a/src/script_opt/ScriptOpt.cc b/src/script_opt/ScriptOpt.cc index c8380f6218..67284107a1 100644 --- a/src/script_opt/ScriptOpt.cc +++ b/src/script_opt/ScriptOpt.cc @@ -36,7 +36,7 @@ static ZAMCompiler* ZAM = nullptr; static bool generating_CPP = false; static std::string CPP_dir; // where to generate C++ code -static std::unordered_set lambdas; +static std::unordered_map lambdas; static std::unordered_set when_lambdas; static ScriptFuncPtr global_stmts; @@ -52,7 +52,7 @@ void analyze_lambda(LambdaExpr* l) { auto& pf = l->PrimaryFunc(); analyze_func(pf); - lambdas.insert(pf.get()); + lambdas[pf.get()] = l; } void analyze_when_lambda(LambdaExpr* l) @@ -185,7 +185,7 @@ static void optimize_func(ScriptFunc* f, std::shared_ptr pf, ScopeP push_existing_scope(scope); - auto rc = std::make_shared(); + auto rc = std::make_shared(f); auto new_body = rc->Reduce(body); if ( reporter->Errors() > 0 ) @@ -496,22 +496,19 @@ static void analyze_scripts_for_ZAM(std::unique_ptr& pfs) if ( inl ) { - for ( auto& f : funcs ) + for ( auto& g : pfs->Globals() ) { - for ( const auto& g : f.Profile()->Globals() ) - { - if ( 
g->GetType()->Tag() != TYPE_FUNC ) - continue; + if ( g->GetType()->Tag() != TYPE_FUNC ) + continue; - auto v = g->GetVal(); - if ( ! v ) - continue; + auto v = g->GetVal(); + if ( ! v ) + continue; - auto func = v->AsFunc(); + auto func = v->AsFunc(); - if ( inl->WasInlined(func) ) - func_used_indirectly.insert(func); - } + if ( inl->WasInlined(func) ) + func_used_indirectly.insert(func); } } @@ -520,7 +517,8 @@ static void analyze_scripts_for_ZAM(std::unique_ptr& pfs) for ( auto& f : funcs ) { auto func = f.Func(); - bool is_lambda = lambdas.count(func) > 0; + auto l = lambdas.find(func); + bool is_lambda = l != lambdas.end(); if ( ! analysis_options.only_funcs.empty() || ! analysis_options.only_files.empty() ) { @@ -539,6 +537,9 @@ static void analyze_scripts_for_ZAM(std::unique_ptr& pfs) optimize_func(func, f.ProfilePtr(), f.Scope(), new_body); f.SetBody(new_body); + if ( is_lambda ) + l->second->ReplaceBody(new_body); + did_one = true; } diff --git a/src/script_opt/ZAM/Ops.in b/src/script_opt/ZAM/Ops.in index 7ab00fbe33..ac606b316f 100644 --- a/src/script_opt/ZAM/Ops.in +++ b/src/script_opt/ZAM/Ops.in @@ -1594,7 +1594,13 @@ macro WhenCall(func) for ( auto i = 0; i < n; ++i ) args.push_back(aux->ToVal(frame, i)); f->SetCall(z.call_expr); + /* It's possible that this function will call another that + * itself returns null because *it* is the actual blocker. + * That will set ZAM_error, which we need to ignore. + */ + auto hold_ZAM_error = ZAM_error; vp = func->Invoke(&args, f); + ZAM_error = hold_ZAM_error; f->SetTriggerAssoc(current_assoc); if ( ! vp ) throw ZAMDelayedCallException();