more extensive ZAM inlining & compilation of lambdas

This commit is contained in:
Vern Paxson 2023-07-13 12:03:34 -07:00
parent b9949560c6
commit 1ff490b41c
10 changed files with 156 additions and 42 deletions

View file

@ -4825,10 +4825,19 @@ ScopePtr LambdaExpr::GetScope() const
return ingredients->Scope(); return ingredients->Scope();
} }
// Replaces the lambda's body statement, delegating to the underlying
// function ingredients (whose ReplaceBody() swaps in the new StmtPtr).
// Used by script optimization to install an optimized body.
void LambdaExpr::ReplaceBody(StmtPtr new_body)
{
ingredients->ReplaceBody(std::move(new_body));
}
ValPtr LambdaExpr::Eval(Frame* f) const ValPtr LambdaExpr::Eval(Frame* f) const
{ {
auto lamb = make_intrusive<ScriptFunc>(ingredients->GetID()); auto lamb = make_intrusive<ScriptFunc>(ingredients->GetID());
// Use the primary function as the source of the frame size
// and function body, rather than the ingredients, since script
// optimization might have changed the former but not the latter.
lamb->SetFrameSize(primary_func->FrameSize());
StmtPtr body = primary_func->GetBodies()[0].stmts; StmtPtr body = primary_func->GetBodies()[0].stmts;
if ( run_state::is_parsing ) if ( run_state::is_parsing )

View file

@ -1482,6 +1482,8 @@ public:
const FunctionIngredientsPtr& Ingredients() const { return ingredients; } const FunctionIngredientsPtr& Ingredients() const { return ingredients; }
void ReplaceBody(StmtPtr new_body);
bool IsReduced(Reducer* c) const override; bool IsReduced(Reducer* c) const override;
bool HasReducedOps(Reducer* c) const override; bool HasReducedOps(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;

View file

@ -375,6 +375,7 @@ public:
const IDPtr& GetID() const { return id; } const IDPtr& GetID() const { return id; }
const StmtPtr& Body() const { return body; } const StmtPtr& Body() const { return body; }
void ReplaceBody(StmtPtr new_body) { body = std::move(new_body); }
const auto& Inits() const { return inits; } const auto& Inits() const { return inits; }
void ClearInits() { inits.clear(); } void ClearInits() { inits.clear(); }

View file

@ -2717,6 +2717,7 @@ ExprPtr InlineExpr::Reduce(Reducer* c, StmtPtr& red_stmt)
red_stmt = nullptr; red_stmt = nullptr;
auto args_list = args->Exprs(); auto args_list = args->Exprs();
auto ret_val = c->PushInlineBlock(type);
loop_over_list(args_list, i) loop_over_list(args_list, i)
{ {
@ -2730,7 +2731,6 @@ ExprPtr InlineExpr::Reduce(Reducer* c, StmtPtr& red_stmt)
red_stmt = MergeStmts(red_stmt, arg_red_stmt, assign_stmt); red_stmt = MergeStmts(red_stmt, arg_red_stmt, assign_stmt);
} }
auto ret_val = c->PushInlineBlock(type);
body = body->Reduce(c); body = body->Reduce(c);
c->PopInlineBlock(); c->PopInlineBlock();

View file

@ -170,11 +170,6 @@ ExprPtr Inliner::CheckForInlining(CallExprPtr c)
// We don't inline indirect calls. // We don't inline indirect calls.
return c; return c;
if ( c->IsInWhen() )
// Don't inline these, as doing so requires propagating
// the in-when attribute to the inlined function body.
return c;
auto n = f->AsNameExpr(); auto n = f->AsNameExpr();
auto func = n->Id(); auto func = n->Id();
@ -190,22 +185,38 @@ ExprPtr Inliner::CheckForInlining(CallExprPtr c)
if ( function->GetKind() != Func::SCRIPT_FUNC ) if ( function->GetKind() != Func::SCRIPT_FUNC )
return c; return c;
// Check for mismatches in argument count due to single-arg-of-type-any
// loophole used for variadic BiFs.
if ( function->GetType()->Params()->NumFields() == 1 && c->Args()->Exprs().size() != 1 )
return c;
auto func_vf = static_cast<ScriptFunc*>(function); auto func_vf = static_cast<ScriptFunc*>(function);
if ( inline_ables.count(func_vf) == 0 ) if ( inline_ables.count(func_vf) == 0 )
return c; return c;
if ( c->IsInWhen() )
{
// Don't inline these, as doing so requires propagating
// the in-when attribute to the inlined function body.
skipped_inlining.insert(func_vf);
return c;
}
// Check for mismatches in argument count due to single-arg-of-type-any
// loophole used for variadic BiFs. (The issue isn't calls to the
// BiFs, which won't happen here, but instead to script functions that
// are misusing/abusing the loophole.)
if ( function->GetType()->Params()->NumFields() == 1 && c->Args()->Exprs().size() != 1 )
{
skipped_inlining.insert(func_vf);
return c;
}
// We're going to inline the body, unless it's too large. // We're going to inline the body, unless it's too large.
auto body = func_vf->GetBodies()[0].stmts; // there's only 1 body auto body = func_vf->GetBodies()[0].stmts; // there's only 1 body
auto oi = body->GetOptInfo(); auto oi = body->GetOptInfo();
if ( num_stmts + oi->num_stmts + num_exprs + oi->num_exprs > MAX_INLINE_SIZE ) if ( num_stmts + oi->num_stmts + num_exprs + oi->num_exprs > MAX_INLINE_SIZE )
return nullptr; {
skipped_inlining.insert(func_vf);
return nullptr; // signals "stop inlining"
}
num_stmts += oi->num_stmts; num_stmts += oi->num_stmts;
num_exprs += oi->num_exprs; num_exprs += oi->num_exprs;

View file

@ -32,7 +32,10 @@ public:
ExprPtr CheckForInlining(CallExprPtr c); ExprPtr CheckForInlining(CallExprPtr c);
// True if the given function has been inlined. // True if the given function has been inlined.
bool WasInlined(const Func* f) { return inline_ables.count(f) > 0; } bool WasInlined(const Func* f)
{
return inline_ables.count(f) > 0 && skipped_inlining.count(f) == 0;
}
protected: protected:
// Driver routine that analyzes all of the script functions and // Driver routine that analyzes all of the script functions and
@ -49,6 +52,10 @@ protected:
// Functions that we've determined to be suitable for inlining. // Functions that we've determined to be suitable for inlining.
std::unordered_set<const Func*> inline_ables; std::unordered_set<const Func*> inline_ables;
// Functions that we didn't fully inline, so require separate
// compilation.
std::unordered_set<const Func*> skipped_inlining;
// As we do inlining for a given function, this tracks the // As we do inlining for a given function, this tracks the
// largest frame size of any inlined function. // largest frame size of any inlined function.
int max_inlined_frame_size; int max_inlined_frame_size;

View file

@ -4,19 +4,36 @@
#include "zeek/Desc.h" #include "zeek/Desc.h"
#include "zeek/Expr.h" #include "zeek/Expr.h"
#include "zeek/Func.h"
#include "zeek/ID.h" #include "zeek/ID.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/Scope.h" #include "zeek/Scope.h"
#include "zeek/Stmt.h" #include "zeek/Stmt.h"
#include "zeek/Var.h" #include "zeek/Var.h"
#include "zeek/script_opt/ExprOptInfo.h" #include "zeek/script_opt/ExprOptInfo.h"
#include "zeek/script_opt/ProfileFunc.h"
#include "zeek/script_opt/StmtOptInfo.h" #include "zeek/script_opt/StmtOptInfo.h"
#include "zeek/script_opt/TempVar.h" #include "zeek/script_opt/TempVar.h"
namespace zeek::detail namespace zeek::detail
{ {
// Constructs a Reducer for the given script function.  Seeds "tracked_ids"
// with identifiers that reduction must leave alone (i.e., not remap to new
// locals): the function's parameters and, if present, its lambda captures.
// NOTE(review): assumes current_scope() is the function's scope and that its
// OrderedVars() begin with the parameters — confirm at the call site.
Reducer::Reducer(const ScriptFunc* func)
{
auto& ft = func->GetType();
// Track the parameters so we don't remap them.
int num_params = ft->Params()->NumFields();
auto& scope_vars = current_scope()->OrderedVars();
for ( auto i = 0; i < num_params; ++i )
tracked_ids.insert(scope_vars[i].get());
// Now include any captures.
if ( ft->GetCaptures() )
for ( auto& c : *ft->GetCaptures() )
tracked_ids.insert(c.Id().get());
}
StmtPtr Reducer::Reduce(StmtPtr s) StmtPtr Reducer::Reduce(StmtPtr s)
{ {
reduction_root = std::move(s); reduction_root = std::move(s);
@ -57,10 +74,9 @@ NameExprPtr Reducer::UpdateName(NameExprPtr n)
return ne; return ne;
} }
bool Reducer::NameIsReduced(const NameExpr* n) const bool Reducer::NameIsReduced(const NameExpr* n)
{ {
auto id = n->Id(); return ID_IsReducedOrTopLevel(n->Id());
return inline_block_level == 0 || id->IsGlobal() || IsTemporary(id) || IsNewLocal(n);
} }
void Reducer::UpdateIDs(IDPList* ids) void Reducer::UpdateIDs(IDPList* ids)
@ -69,7 +85,7 @@ void Reducer::UpdateIDs(IDPList* ids)
{ {
IDPtr id = {NewRef{}, (*ids)[i]}; IDPtr id = {NewRef{}, (*ids)[i]};
if ( ! ID_IsReduced(id) ) if ( ! ID_IsReducedOrTopLevel(id) )
{ {
Unref((*ids)[i]); Unref((*ids)[i]);
(*ids)[i] = UpdateID(id).release(); (*ids)[i] = UpdateID(id).release();
@ -80,7 +96,7 @@ void Reducer::UpdateIDs(IDPList* ids)
void Reducer::UpdateIDs(std::vector<IDPtr>& ids) void Reducer::UpdateIDs(std::vector<IDPtr>& ids)
{ {
for ( auto& id : ids ) for ( auto& id : ids )
if ( ! ID_IsReduced(id) ) if ( ! ID_IsReducedOrTopLevel(id) )
id = UpdateID(id); id = UpdateID(id);
} }
@ -104,15 +120,27 @@ bool Reducer::IDsAreReduced(const std::vector<IDPtr>& ids) const
IDPtr Reducer::UpdateID(IDPtr id) IDPtr Reducer::UpdateID(IDPtr id)
{ {
if ( ID_IsReduced(id) ) if ( ID_IsReducedOrTopLevel(id) )
return id; return id;
return FindNewLocal(id); return FindNewLocal(id);
} }
// A variant of ID_IsReduced() that also *tracks* top-level identifiers:
// when we're not inside any inline block, the identifier is recorded in
// "tracked_ids" (so later inlining won't remap it) and treated as reduced.
// Inside an inline block, defers to the ordinary ID_IsReduced() check.
// Note: this is why the method is non-const, unlike ID_IsReduced().
bool Reducer::ID_IsReducedOrTopLevel(const ID* id)
{
if ( inline_block_level == 0 )
{
// Top level: remember the identifier and consider it reduced.
tracked_ids.insert(id);
return true;
}
return ID_IsReduced(id);
}
bool Reducer::ID_IsReduced(const ID* id) const bool Reducer::ID_IsReduced(const ID* id) const
{ {
return inline_block_level == 0 || id->IsGlobal() || IsTemporary(id) || IsNewLocal(id); return inline_block_level == 0 || tracked_ids.count(id) > 0 || id->IsGlobal() ||
IsTemporary(id);
} }
NameExprPtr Reducer::GenInlineBlockName(const IDPtr& id) NameExprPtr Reducer::GenInlineBlockName(const IDPtr& id)
@ -126,6 +154,8 @@ NameExprPtr Reducer::PushInlineBlock(TypePtr type)
{ {
++inline_block_level; ++inline_block_level;
block_locals.emplace_back(std::unordered_map<const ID*, IDPtr>());
if ( ! type || type->Tag() == TYPE_VOID ) if ( ! type || type->Tag() == TYPE_VOID )
return nullptr; return nullptr;
@ -133,11 +163,13 @@ NameExprPtr Reducer::PushInlineBlock(TypePtr type)
ret_id->SetType(type); ret_id->SetType(type);
ret_id->GetOptInfo()->SetTemp(); ret_id->GetOptInfo()->SetTemp();
ret_vars.insert(ret_id.get());
// Track this as a new local *if* we're in the outermost inlining // Track this as a new local *if* we're in the outermost inlining
// block. If we're recursively deeper into inlining, then this // block. If we're recursively deeper into inlining, then this
// variable will get mapped to a local anyway, so no need. // variable will get mapped to a local anyway, so no need.
if ( inline_block_level == 1 ) if ( inline_block_level == 1 )
new_locals.insert(ret_id.get()); AddNewLocal(ret_id);
return GenInlineBlockName(ret_id); return GenInlineBlockName(ret_id);
} }
@ -145,6 +177,18 @@ NameExprPtr Reducer::PushInlineBlock(TypePtr type)
void Reducer::PopInlineBlock() void Reducer::PopInlineBlock()
{ {
--inline_block_level; --inline_block_level;
for ( auto& l : block_locals.back() )
{
auto key = l.first;
auto prev = l.second;
if ( prev )
orig_to_new_locals[key] = prev;
else
orig_to_new_locals.erase(key);
}
block_locals.pop_back();
} }
bool Reducer::SameVal(const Val* v1, const Val* v2) const bool Reducer::SameVal(const Val* v1, const Val* v2) const
@ -750,6 +794,12 @@ IDPtr Reducer::FindNewLocal(const IDPtr& id)
return GenLocal(id); return GenLocal(id);
} }
// Registers a freshly created local: records it both as a new local
// introduced by reduction/optimization and as a tracked identifier,
// so the local itself will never be re-remapped later.
void Reducer::AddNewLocal(const IDPtr& l)
{
new_locals.insert(l.get());
tracked_ids.insert(l.get());
}
IDPtr Reducer::GenLocal(const IDPtr& orig) IDPtr Reducer::GenLocal(const IDPtr& orig)
{ {
if ( Optimizing() ) if ( Optimizing() )
@ -758,6 +808,9 @@ IDPtr Reducer::GenLocal(const IDPtr& orig)
if ( omitted_stmts.size() > 0 ) if ( omitted_stmts.size() > 0 )
reporter->InternalError("Generating a new local while pruning statements"); reporter->InternalError("Generating a new local while pruning statements");
// Make sure the identifier is not being re-re-mapped.
ASSERT(strchr(orig->Name(), '.') == nullptr);
char buf[8192]; char buf[8192];
int n = new_locals.size(); int n = new_locals.size();
snprintf(buf, sizeof buf, "%s.%d", orig->Name(), n); snprintf(buf, sizeof buf, "%s.%d", orig->Name(), n);
@ -769,9 +822,16 @@ IDPtr Reducer::GenLocal(const IDPtr& orig)
if ( orig->GetOptInfo()->IsTemp() ) if ( orig->GetOptInfo()->IsTemp() )
local_id->GetOptInfo()->SetTemp(); local_id->GetOptInfo()->SetTemp();
new_locals.insert(local_id.get()); IDPtr prev;
if ( orig_to_new_locals.count(orig.get()) )
prev = orig_to_new_locals[orig.get()];
AddNewLocal(local_id);
orig_to_new_locals[orig.get()] = local_id; orig_to_new_locals[orig.get()] = local_id;
if ( ! block_locals.empty() && ret_vars.count(orig.get()) == 0 )
block_locals.back()[orig.get()] = prev;
return local_id; return local_id;
} }

View file

@ -16,7 +16,7 @@ class TempVar;
class Reducer class Reducer
{ {
public: public:
Reducer() { } Reducer(const ScriptFunc* func);
StmtPtr Reduce(StmtPtr s); StmtPtr Reduce(StmtPtr s);
@ -27,7 +27,7 @@ public:
ExprPtr GenTemporaryExpr(const TypePtr& t, ExprPtr rhs); ExprPtr GenTemporaryExpr(const TypePtr& t, ExprPtr rhs);
NameExprPtr UpdateName(NameExprPtr n); NameExprPtr UpdateName(NameExprPtr n);
bool NameIsReduced(const NameExpr* n) const; bool NameIsReduced(const NameExpr* n);
void UpdateIDs(IDPList* ids); void UpdateIDs(IDPList* ids);
bool IDsAreReduced(const IDPList* ids) const; bool IDsAreReduced(const IDPList* ids) const;
@ -39,6 +39,10 @@ public:
bool ID_IsReduced(const IDPtr& id) const { return ID_IsReduced(id.get()); } bool ID_IsReduced(const IDPtr& id) const { return ID_IsReduced(id.get()); }
bool ID_IsReduced(const ID* id) const; bool ID_IsReduced(const ID* id) const;
// A version of ID_IsReduced() that tracks top-level variables, too.
bool ID_IsReducedOrTopLevel(const IDPtr& id) { return ID_IsReducedOrTopLevel(id.get()); }
bool ID_IsReducedOrTopLevel(const ID* id);
// This is called *prior* to pushing a new inline block, in // This is called *prior* to pushing a new inline block, in
// order to generate the equivalent of function parameters. // order to generate the equivalent of function parameters.
NameExprPtr GenInlineBlockName(const IDPtr& id); NameExprPtr GenInlineBlockName(const IDPtr& id);
@ -205,6 +209,8 @@ protected:
IDPtr FindNewLocal(const IDPtr& id); IDPtr FindNewLocal(const IDPtr& id);
IDPtr FindNewLocal(const NameExprPtr& n) { return FindNewLocal(n->IdPtr()); } IDPtr FindNewLocal(const NameExprPtr& n) { return FindNewLocal(n->IdPtr()); }
void AddNewLocal(const IDPtr& l);
// Generate a new local to use in lieu of the original (seen // Generate a new local to use in lieu of the original (seen
// in an inlined block). The difference is that the new // in an inlined block). The difference is that the new
// version has a distinct name and has a correct frame offset // version has a distinct name and has a correct frame offset
@ -228,6 +234,9 @@ protected:
// variable, if it corresponds to one. // variable, if it corresponds to one.
std::unordered_map<const ID*, std::shared_ptr<TempVar>> ids_to_temps; std::unordered_map<const ID*, std::shared_ptr<TempVar>> ids_to_temps;
// Identifiers that we're tracking (and don't want to replace).
IDSet tracked_ids;
// Local variables created during reduction/optimization. // Local variables created during reduction/optimization.
IDSet new_locals; IDSet new_locals;
@ -250,10 +259,18 @@ protected:
// Maps statements to replacements constructed during optimization. // Maps statements to replacements constructed during optimization.
std::unordered_map<const Stmt*, StmtPtr> replaced_stmts; std::unordered_map<const Stmt*, StmtPtr> replaced_stmts;
// Tracks return variables we've created.
IDSet ret_vars;
// Tracks whether we're inside an inline block, and if so then // Tracks whether we're inside an inline block, and if so then
// how deeply. // how deeply.
int inline_block_level = 0; int inline_block_level = 0;
// Tracks locals introduced in the current block, remembering
// their previous replacement value (per "orig_to_new_locals"),
// if any. When we pop the block, we restore the previous mapping.
std::vector<std::unordered_map<const ID*, IDPtr>> block_locals;
// Tracks how deeply we are in "bifurcation", i.e., duplicating // Tracks how deeply we are in "bifurcation", i.e., duplicating
// code for if-else cascades. We need to cap this at a certain // code for if-else cascades. We need to cap this at a certain
// depth or else we can get functions whose size blows up // depth or else we can get functions whose size blows up
@ -265,7 +282,7 @@ protected:
IDSet constant_vars; IDSet constant_vars;
// Statement at which the current reduction started. // Statement at which the current reduction started.
StmtPtr reduction_root = nullptr; StmtPtr reduction_root;
// Statement we're currently working on. // Statement we're currently working on.
const Stmt* curr_stmt = nullptr; const Stmt* curr_stmt = nullptr;

View file

@ -36,7 +36,7 @@ static ZAMCompiler* ZAM = nullptr;
static bool generating_CPP = false; static bool generating_CPP = false;
static std::string CPP_dir; // where to generate C++ code static std::string CPP_dir; // where to generate C++ code
static std::unordered_set<const ScriptFunc*> lambdas; static std::unordered_map<const ScriptFunc*, LambdaExpr*> lambdas;
static std::unordered_set<const ScriptFunc*> when_lambdas; static std::unordered_set<const ScriptFunc*> when_lambdas;
static ScriptFuncPtr global_stmts; static ScriptFuncPtr global_stmts;
@ -52,7 +52,7 @@ void analyze_lambda(LambdaExpr* l)
{ {
auto& pf = l->PrimaryFunc(); auto& pf = l->PrimaryFunc();
analyze_func(pf); analyze_func(pf);
lambdas.insert(pf.get()); lambdas[pf.get()] = l;
} }
void analyze_when_lambda(LambdaExpr* l) void analyze_when_lambda(LambdaExpr* l)
@ -185,7 +185,7 @@ static void optimize_func(ScriptFunc* f, std::shared_ptr<ProfileFunc> pf, ScopeP
push_existing_scope(scope); push_existing_scope(scope);
auto rc = std::make_shared<Reducer>(); auto rc = std::make_shared<Reducer>(f);
auto new_body = rc->Reduce(body); auto new_body = rc->Reduce(body);
if ( reporter->Errors() > 0 ) if ( reporter->Errors() > 0 )
@ -496,9 +496,7 @@ static void analyze_scripts_for_ZAM(std::unique_ptr<ProfileFuncs>& pfs)
if ( inl ) if ( inl )
{ {
for ( auto& f : funcs ) for ( auto& g : pfs->Globals() )
{
for ( const auto& g : f.Profile()->Globals() )
{ {
if ( g->GetType()->Tag() != TYPE_FUNC ) if ( g->GetType()->Tag() != TYPE_FUNC )
continue; continue;
@ -513,14 +511,14 @@ static void analyze_scripts_for_ZAM(std::unique_ptr<ProfileFuncs>& pfs)
func_used_indirectly.insert(func); func_used_indirectly.insert(func);
} }
} }
}
bool did_one = false; bool did_one = false;
for ( auto& f : funcs ) for ( auto& f : funcs )
{ {
auto func = f.Func(); auto func = f.Func();
bool is_lambda = lambdas.count(func) > 0; auto l = lambdas.find(func);
bool is_lambda = l != lambdas.end();
if ( ! analysis_options.only_funcs.empty() || ! analysis_options.only_files.empty() ) if ( ! analysis_options.only_funcs.empty() || ! analysis_options.only_files.empty() )
{ {
@ -539,6 +537,9 @@ static void analyze_scripts_for_ZAM(std::unique_ptr<ProfileFuncs>& pfs)
optimize_func(func, f.ProfilePtr(), f.Scope(), new_body); optimize_func(func, f.ProfilePtr(), f.Scope(), new_body);
f.SetBody(new_body); f.SetBody(new_body);
if ( is_lambda )
l->second->ReplaceBody(new_body);
did_one = true; did_one = true;
} }

View file

@ -1594,7 +1594,13 @@ macro WhenCall(func)
for ( auto i = 0; i < n; ++i ) for ( auto i = 0; i < n; ++i )
args.push_back(aux->ToVal(frame, i)); args.push_back(aux->ToVal(frame, i));
f->SetCall(z.call_expr); f->SetCall(z.call_expr);
/* It's possible that this function will call another that
* itself returns null because *it* is the actual blocker.
* That will set ZAM_error, which we need to ignore.
*/
auto hold_ZAM_error = ZAM_error;
vp = func->Invoke(&args, f); vp = func->Invoke(&args, f);
ZAM_error = hold_ZAM_error;
f->SetTriggerAssoc(current_assoc); f->SetTriggerAssoc(current_assoc);
if ( ! vp ) if ( ! vp )
throw ZAMDelayedCallException(); throw ZAMDelayedCallException();