diff --git a/src/script_opt/CPP/Compile.h b/src/script_opt/CPP/Compile.h index 447427204e..01f5923bfd 100644 --- a/src/script_opt/CPP/Compile.h +++ b/src/script_opt/CPP/Compile.h @@ -724,6 +724,7 @@ private: void GenTypeSwitchCase(const ID* id, int case_offset, bool is_multi); void GenValueSwitchStmt(const Expr* e, const case_list* cases); + void GenWhenStmt(const WhenStmt* w); void GenForStmt(const ForStmt* f); void GenForOverTable(const ExprPtr& tbl, const IDPtr& value_var, const IDPList* loop_vars); void GenForOverVector(const ExprPtr& tbl, const IDPList* loop_vars); diff --git a/src/script_opt/CPP/Exprs.cc b/src/script_opt/CPP/Exprs.cc index fed8ddcbe9..58120b7cd7 100644 --- a/src/script_opt/CPP/Exprs.cc +++ b/src/script_opt/CPP/Exprs.cc @@ -292,6 +292,7 @@ string CPPCompile::GenCallExpr(const CallExpr* c, GenType gt) const auto& t = c->GetType(); auto f = c->Func(); auto args_l = c->Args(); + bool is_async = c->IsInWhen(); auto gen = GenExpr(f, GEN_DONT_CARE); @@ -304,8 +305,8 @@ string CPPCompile::GenCallExpr(const CallExpr* c, GenType gt) bool is_compiled = compiled_simple_funcs.count(id_name) > 0; bool was_compiled = hashed_funcs.count(id_name) > 0; - if ( is_compiled || was_compiled ) - { + if ( ! is_async && (is_compiled || was_compiled) ) + { // Can call directly. string fname; if ( was_compiled ) @@ -340,8 +341,14 @@ string CPPCompile::GenCallExpr(const CallExpr* c, GenType gt) // Indirect call. gen = string("(") + gen + ")->AsFunc()"; + string invoke_func = is_async ? "when_invoke__CPP" : "invoke__CPP"; auto args_list = string(", {") + GenExpr(args_l, GEN_VAL_PTR) + "}"; - auto invoker = string("invoke__CPP(") + gen + args_list + ", f__CPP)"; + auto invoker = invoke_func + "(" + gen + args_list + ", f__CPP"; + + if ( is_async ) + invoker += ", (void*) &" + body_name; + + invoker += ")"; if ( IsNativeType(t) && gt != GEN_VAL_PTR ) return invoker + NativeAccessor(t); @@ -408,12 +415,17 @@ string CPPCompile::GenIndexExpr(const Expr* e, GenType gt) { auto aggr = e->GetOp1(); const auto& aggr_t = aggr->GetType(); + bool inside_when = e->AsIndexExpr()->IsInsideWhen(); string gen; + string func; if ( aggr_t->Tag() == TYPE_TABLE ) - gen = string("index_table__CPP(") + GenExpr(aggr, GEN_NATIVE) + ", {" + - GenExpr(e->GetOp2(), GEN_VAL_PTR) + "})"; + { + func = inside_when ? "when_index_table__CPP" : "index_table__CPP"; + gen = func + "(" + GenExpr(aggr, GEN_NATIVE) + ", {" + GenExpr(e->GetOp2(), GEN_VAL_PTR) + + "})"; + } else if ( aggr_t->Tag() == TYPE_VECTOR ) { @@ -426,12 +438,16 @@ string CPPCompile::GenIndexExpr(const Expr* e, GenType gt) auto& inds = op2->AsListExpr()->Exprs(); auto first = inds[0]; auto last = inds[1]; - gen = string("index_slice(") + GenExpr(aggr, GEN_VAL_PTR) + ".get(), " + + func = inside_when ? "when_index_slice__CPP" : "index_slice"; + gen = func + "(" + GenExpr(aggr, GEN_VAL_PTR) + ".get(), " + GenExpr(first, GEN_NATIVE) + ", " + GenExpr(last, GEN_NATIVE) + ")"; } else - gen = string("index_vec__CPP(") + GenExpr(aggr, GEN_NATIVE) + ", " + - GenExpr(e->GetOp2(), GEN_NATIVE) + ")"; + { + func = inside_when ? "when_index_vec__CPP" : "index_vec__CPP"; + gen = func + "(" + GenExpr(aggr, GEN_NATIVE) + ", " + GenExpr(e->GetOp2(), GEN_NATIVE) + + ")"; + } } else if ( aggr_t->Tag() == TYPE_STRING ) diff --git a/src/script_opt/CPP/Runtime.h b/src/script_opt/CPP/Runtime.h index d872c1d9e3..7763e95229 100644 --- a/src/script_opt/CPP/Runtime.h +++ b/src/script_opt/CPP/Runtime.h @@ -13,6 +13,7 @@ #include "zeek/RE.h" #include "zeek/RunState.h" #include "zeek/Scope.h" +#include "zeek/Trigger.h" #include "zeek/Val.h" #include "zeek/ZeekString.h" #include "zeek/module_util.h" diff --git a/src/script_opt/CPP/RuntimeInitSupport.cc b/src/script_opt/CPP/RuntimeInitSupport.cc index eea927d192..4758659faf 100644 --- a/src/script_opt/CPP/RuntimeInitSupport.cc +++ b/src/script_opt/CPP/RuntimeInitSupport.cc @@ -245,6 +245,16 @@ FuncValPtr lookup_func__CPP(string name, int num_bodies, vector has return make_intrusive(move(sf)); } +IDPtr find_global__CPP(const char* g) + { + auto gl = lookup_ID(g, GLOBAL_MODULE_NAME, false, false, false); + + if ( ! gl ) + reporter->CPPRuntimeError("global %s is missing", g); + + return gl; + } + RecordTypePtr get_record_type__CPP(const char* record_type_name) { IDPtr existing_type; diff --git a/src/script_opt/CPP/RuntimeInitSupport.h b/src/script_opt/CPP/RuntimeInitSupport.h index 6b5ea46014..e7998716e4 100644 --- a/src/script_opt/CPP/RuntimeInitSupport.h +++ b/src/script_opt/CPP/RuntimeInitSupport.h @@ -70,6 +70,10 @@ extern Func* lookup_bif__CPP(const char* bif); extern FuncValPtr lookup_func__CPP(std::string name, int num_bodies, std::vector h, const TypePtr& t); +// Looks for a global with the given name, generating a run-time error +// if not present. +extern IDPtr find_global__CPP(const char* g); + // Returns the record corresponding to the given name, as long as the // name is indeed a record type. Otherwise (or if the name is nil) // creates a new empty record. diff --git a/src/script_opt/CPP/RuntimeOps.cc b/src/script_opt/CPP/RuntimeOps.cc index 35dd89dc6f..a6af5cad3e 100644 --- a/src/script_opt/CPP/RuntimeOps.cc +++ b/src/script_opt/CPP/RuntimeOps.cc @@ -3,8 +3,10 @@ #include "zeek/script_opt/CPP/RuntimeOps.h" #include "zeek/EventRegistry.h" +#include "zeek/Frame.h" #include "zeek/IPAddr.h" #include "zeek/RunState.h" +#include "zeek/Trigger.h" #include "zeek/ZeekString.h" namespace zeek::detail @@ -60,6 +62,50 @@ ValPtr index_string__CPP(const StringValPtr& svp, vector indices) return index_string(svp->AsString(), index_val__CPP(move(indices)).get()); } +ValPtr when_index_table__CPP(const TableValPtr& t, vector indices) + { + auto v = index_table__CPP(t, std::move(indices)); + if ( v && IndexExprWhen::evaluating > 0 ) + IndexExprWhen::results.emplace_back(v); + return v; + } + +ValPtr when_index_vec__CPP(const VectorValPtr& vec, int index) + { + auto v = index_vec__CPP(vec, index); + if ( v && IndexExprWhen::evaluating > 0 ) + IndexExprWhen::results.emplace_back(v); + return v; + } + +ValPtr when_index_slice__CPP(VectorVal* vec, const ListVal* lv) + { + auto v = index_slice(vec, lv); + if ( v && IndexExprWhen::evaluating > 0 ) + IndexExprWhen::results.emplace_back(v); + return v; + } + +ValPtr when_invoke__CPP(Func* f, std::vector args, Frame* frame, void* caller_addr) + { + auto trigger = frame->GetTrigger(); + + if ( trigger ) + { + Val* v = trigger->Lookup(caller_addr); + if ( v ) + return {NewRef{}, v}; + } + + frame->SetTriggerAssoc(caller_addr); + + auto res = f->Invoke(&args, frame); + if ( ! res ) + throw DelayedCallException(); + + return res; + } + ValPtr set_event__CPP(IDPtr g, ValPtr v, EventHandlerPtr& gh) { g->SetVal(v); diff --git a/src/script_opt/CPP/RuntimeOps.h b/src/script_opt/CPP/RuntimeOps.h index 02df049eb9..12432f6662 100644 --- a/src/script_opt/CPP/RuntimeOps.h +++ b/src/script_opt/CPP/RuntimeOps.h @@ -33,6 +33,14 @@ extern ValPtr index_table__CPP(const TableValPtr& t, std::vector indices extern ValPtr index_vec__CPP(const VectorValPtr& vec, int index); extern ValPtr index_string__CPP(const StringValPtr& svp, std::vector indices); +// The same, but for indexing happening inside a "when" clause. +extern ValPtr when_index_table__CPP(const TableValPtr& t, std::vector indices); +extern ValPtr when_index_vec__CPP(const VectorValPtr& vec, int index); + +// For vector slices, we use the existing index_slice(), but we need a +// custom one for those occurring inside a "when" clause. +extern ValPtr when_index_slice__CPP(VectorVal* vec, const ListVal* lv); + // Calls out to the given script or BiF function. A separate function because // of the need to (1) construct the "args" vector using {} initializers, // but (2) needing to have the address of that vector. @@ -41,6 +49,20 @@ inline ValPtr invoke__CPP(Func* f, std::vector args, Frame* frame) return f->Invoke(&args, frame); } +// The same, but raises an interpreter exception if the function does +// not return a value. Used for calls inside "when" conditions. The +// last argument is the address of the calling function; we just need +// it to be distinct to the call, so we can associate a Trigger cache +// with it. +extern ValPtr when_invoke__CPP(Func* f, std::vector args, Frame* frame, void* caller_addr); + +// Thrown when a call inside a "when" delays. +class DelayedCallException : public InterpreterException + { +public: + DelayedCallException() { } + }; + // Assigns the given value to the given global. A separate function because // we also need to return the value, for use in assignment cascades. inline ValPtr set_global__CPP(IDPtr g, ValPtr v) diff --git a/src/script_opt/CPP/Stmts.cc b/src/script_opt/CPP/Stmts.cc index 1e3c89c460..8cc5a5f39c 100644 --- a/src/script_opt/CPP/Stmts.cc +++ b/src/script_opt/CPP/Stmts.cc @@ -66,6 +66,10 @@ void CPPCompile::GenStmt(const Stmt* s) GenSwitchStmt(static_cast(s)); break; + case STMT_WHEN: + GenWhenStmt(static_cast(s)); + break; + case STMT_FOR: GenForStmt(s->AsForStmt()); break; @@ -91,10 +95,6 @@ void CPPCompile::GenStmt(const Stmt* s) case STMT_FALLTHROUGH: break; - case STMT_WHEN: - ASSERT(0); - break; - default: reporter->InternalError("bad statement type in CPPCompile::GenStmt"); } @@ -163,23 +163,26 @@ void CPPCompile::GenReturnStmt(const ReturnStmt* r) { auto e = r->StmtExpr(); - if ( ! ret_type || ! e || e->GetType()->Tag() == TYPE_VOID || in_hook ) + if ( in_hook ) + Emit("return true;"); + + else if ( ! e && ret_type && ret_type->Tag() != TYPE_VOID ) + // This occurs for ExpressionlessReturnOkay() functions. + Emit("return nullptr;"); + + else if ( ! ret_type || ! e || e->GetType()->Tag() == TYPE_VOID ) + Emit("return;"); + + else { - if ( in_hook ) - Emit("return true;"); - else - Emit("return;"); + auto gt = ret_type->Tag() == TYPE_ANY ? GEN_VAL_PTR : GEN_NATIVE; + auto ret = GenExpr(e, gt); - return; + if ( e->GetType()->Tag() == TYPE_ANY ) + ret = GenericValPtrToGT(ret, ret_type, gt); + + Emit("return %s;", ret); } - - auto gt = ret_type->Tag() == TYPE_ANY ? GEN_VAL_PTR : GEN_NATIVE; - auto ret = GenExpr(e, gt); - - if ( e->GetType()->Tag() == TYPE_ANY ) - ret = GenericValPtrToGT(ret, ret_type, gt); - - Emit("return %s;", ret); } void CPPCompile::GenAddStmt(const ExprStmt* es) @@ -384,6 +387,68 @@ void CPPCompile::GenValueSwitchStmt(const Expr* e, const case_list* cases) Emit("}"); } +void CPPCompile::GenWhenStmt(const WhenStmt* w) + { + auto wi = w->Info(); + auto wl = wi->Lambda(); + + if ( ! wl ) + reporter->FatalError("cannot compile deprecated \"when\" statement"); + + auto is_return = wi->IsReturn() ? "true" : "false"; + auto timeout = wi->TimeoutExpr(); + auto timeout_val = timeout ? GenExpr(timeout, GEN_NATIVE) : "-1.0"; + auto loc = w->GetLocationInfo(); + + Emit("{ // begin a new scope for internal variables"); + + Emit("static WhenInfo* CPP__wi = nullptr;"); + Emit("static IDSet CPP__w_globals;"); + + NL(); + + Emit("if ( ! CPP__wi )"); + StartBlock(); + Emit("CPP__wi = new WhenInfo(%s);", is_return); + for ( auto& wg : wi->WhenExprGlobals() ) + Emit("CPP__w_globals.insert(find_global__CPP(\"%s\").get());", wg->Name()); + EndBlock(); + NL(); + + Emit("std::vector CPP__local_aggrs;"); + for ( auto l : wi->WhenExprLocals() ) + if ( IsAggr(l->GetType()) ) + Emit("CPP__local_aggrs.emplace_back(%s);", IDNameStr(l)); + + Emit("CPP__wi->Instantiate(%s);", GenExpr(wi->Lambda(), GEN_NATIVE)); + + // We need a new frame for the trigger to unambiguously associate + // with, in case we're called multiple times with our existing frame. + Emit("auto new_frame = make_intrusive(0, nullptr, nullptr);"); + Emit("auto curr_t = f__CPP->GetTrigger();"); + Emit("auto curr_assoc = f__CPP->GetTriggerAssoc();"); + Emit("new_frame->SetTrigger({NewRef{}, curr_t});"); + Emit("new_frame->SetTriggerAssoc(curr_assoc);"); + + Emit("auto t = new trigger::Trigger(CPP__wi, %s, CPP__w_globals, CPP__local_aggrs, " + "new_frame.get(), " + "nullptr);", + timeout_val); + + auto loc_str = util::fmt("%s:%d-%d", loc->filename, loc->first_line, loc->last_line); + Emit("t->SetName(\"%s\");", loc_str); + + if ( ret_type && ret_type->Tag() != TYPE_VOID ) + { + Emit("ValPtr retval = {NewRef{}, curr_t->Lookup(curr_assoc)};"); + Emit("if ( ! retval )"); + Emit("\tthrow DelayedCallException();"); + Emit("return %s;", GenericValPtrToGT("retval", ret_type, GEN_NATIVE)); + } + + Emit("}"); + } + void CPPCompile::GenForStmt(const ForStmt* f) { Emit("{ // begin a new scope for the internal loop vars"); diff --git a/src/script_opt/CPP/Util.cc b/src/script_opt/CPP/Util.cc index 1afbb0c4a3..510cfd7a00 100644 --- a/src/script_opt/CPP/Util.cc +++ b/src/script_opt/CPP/Util.cc @@ -38,13 +38,6 @@ string scope_prefix(int scope) bool is_CPP_compilable(const ProfileFunc* pf, const char** reason) { - if ( pf->NumWhenStmts() > 0 ) - { - if ( reason ) - *reason = "use of \"when\""; - return false; - } - auto body = pf->ProfiledBody(); if ( body && ! body->GetOptInfo()->is_free_of_conditionals ) {