diff --git a/src/script_opt/GenIDDefs.cc b/src/script_opt/GenIDDefs.cc index d12eb50afd..222272a431 100644 --- a/src/script_opt/GenIDDefs.cc +++ b/src/script_opt/GenIDDefs.cc @@ -196,12 +196,6 @@ TraversalCode GenIDDefs::PreStmt(const Stmt* s) return TC_ABORTSTMT; } - case STMT_WHEN: - { - // ### punt on these for now, need to reflect on bindings. - return TC_ABORTSTMT; - } - default: return TC_CONTINUE; } diff --git a/src/script_opt/UseDefs.cc b/src/script_opt/UseDefs.cc index 50996ef9a9..2f4d328d5c 100644 --- a/src/script_opt/UseDefs.cc +++ b/src/script_opt/UseDefs.cc @@ -301,13 +301,17 @@ UDs UseDefs::PropagateUDs(const Stmt* s, UDs succ_UDs, const Stmt* succ_stmt, bo return UseUDs(s, succ_UDs); case STMT_WHEN: - // ### Once we support compiling functions with "when" - // statements in them, we'll need to revisit this. - // For now, we don't worry about it (because the current - // "when" body semantics of deep-copy frames has different - // implications than potentially switching those shallow-copy - // frames). - return UseUDs(s, succ_UDs); + { + auto w = s->AsWhenStmt(); + auto wi = w->Info(); + auto uds = UD_Union(succ_UDs, ExprUDs(wi->Lambda().get())); + + auto timeout = wi->TimeoutExpr(); + if ( timeout ) + uds = UD_Union(uds, ExprUDs(timeout.get())); + + return CreateUDs(s, uds); + } case STMT_SWITCH: { diff --git a/src/script_opt/ZAM/Compile.h b/src/script_opt/ZAM/Compile.h index edfeff65bf..4945ae0a23 100644 --- a/src/script_opt/ZAM/Compile.h +++ b/src/script_opt/ZAM/Compile.h @@ -138,6 +138,7 @@ private: const ZAMStmt CompileCatchReturn(const CatchReturnStmt* cr); const ZAMStmt CompileStmts(const StmtList* sl); const ZAMStmt CompileInit(const InitStmt* is); + const ZAMStmt CompileWhen(const WhenStmt* ws); const ZAMStmt CompileNext() { return GenGoTo(nexts.back()); } const ZAMStmt CompileBreak() { return GenGoTo(breaks.back()); } diff --git a/src/script_opt/ZAM/Expr.cc b/src/script_opt/ZAM/Expr.cc index 6982292755..d3d387523e 100644 --- a/src/script_opt/ZAM/Expr.cc +++ b/src/script_opt/ZAM/Expr.cc @@ -283,7 +283,18 @@ const ZAMStmt ZAMCompiler::CompileAssignToIndex(const NameExpr* lhs, const Index : IndexVecIntSelectVVV(lhs, n, index); } - return const_aggr ? IndexVCL(lhs, con, indexes_expr) : IndexVVL(lhs, n, indexes_expr); + if ( rhs->IsInsideWhen() ) + { + if ( const_aggr ) + return WhenIndexVCL(lhs, con, indexes_expr); + else + return WhenIndexVVL(lhs, n, indexes_expr); + } + + if ( const_aggr ) + return IndexVCL(lhs, con, indexes_expr); + else + return IndexVVL(lhs, n, indexes_expr); } const ZAMStmt ZAMCompiler::CompileFieldLHSAssignExpr(const FieldLHSAssignExpr* e) @@ -614,26 +625,28 @@ const ZAMStmt ZAMCompiler::CompileInExpr(const NameExpr* n1, const ListExpr* l, return AddInst(z); } -const ZAMStmt ZAMCompiler::CompileIndex(const NameExpr* n1, const NameExpr* n2, const ListExpr* l) +const ZAMStmt ZAMCompiler::CompileIndex(const NameExpr* n1, const NameExpr* n2, const ListExpr* l, + bool in_when) { - return CompileIndex(n1, FrameSlot(n2), n2->GetType(), l); + return CompileIndex(n1, FrameSlot(n2), n2->GetType(), l, in_when); } -const ZAMStmt ZAMCompiler::CompileIndex(const NameExpr* n, const ConstExpr* c, const ListExpr* l) +const ZAMStmt ZAMCompiler::CompileIndex(const NameExpr* n, const ConstExpr* c, const ListExpr* l, + bool in_when) { auto tmp = TempForConst(c); - return CompileIndex(n, tmp, c->GetType(), l); + return CompileIndex(n, tmp, c->GetType(), l, in_when); } const ZAMStmt ZAMCompiler::CompileIndex(const NameExpr* n1, int n2_slot, const TypePtr& n2t, - const ListExpr* l) + const ListExpr* l, bool in_when) { ZInstI z; int n = l->Exprs().length(); auto n2tag = n2t->Tag(); - if ( n == 1 ) + if ( n == 1 && ! in_when ) { auto ind = l->Exprs()[0]; auto var_ind = ind->Tag() == EXPR_NAME; @@ -675,12 +688,28 @@ const ZAMStmt ZAMCompiler::CompileIndex(const NameExpr* n1, int n2_slot, const T if ( n3 ) { int n3_slot = FrameSlot(n3); - auto zop = is_any ? OP_INDEX_ANY_VEC_VVV : OP_INDEX_VEC_VVV; + + ZOp zop; + if ( in_when ) + zop = OP_WHEN_INDEX_VEC_VVV; + else if ( is_any ) + zop = OP_INDEX_ANY_VEC_VVV; + else + zop = OP_INDEX_VEC_VVV; + z = ZInstI(zop, Frame1Slot(n1, zop), n2_slot, n3_slot); } else { - auto zop = is_any ? OP_INDEX_ANY_VECC_VVV : OP_INDEX_VECC_VVV; + ZOp zop; + + if ( in_when ) + zop = OP_WHEN_INDEX_VECC_VVV; + else if ( is_any ) + zop = OP_INDEX_ANY_VECC_VVV; + else + zop = OP_INDEX_VECC_VVV; + z = ZInstI(zop, Frame1Slot(n1, zop), n2_slot, c); z.op_type = OP_VVV_I3; } @@ -718,13 +747,13 @@ const ZAMStmt ZAMCompiler::CompileIndex(const NameExpr* n1, int n2_slot, const T switch ( n2tag ) { case TYPE_VECTOR: - op = OP_INDEX_VEC_SLICE_VV; + op = in_when ? OP_WHEN_INDEX_VEC_SLICE_VV : OP_INDEX_VEC_SLICE_VV; z = ZInstI(op, Frame1Slot(n1, op), n2_slot); z.SetType(n2t); break; case TYPE_TABLE: - op = OP_TABLE_INDEX_VV; + op = in_when ? OP_WHEN_TABLE_INDEX_VV : OP_TABLE_INDEX_VV; z = ZInstI(op, Frame1Slot(n1, op), n2_slot); z.SetType(n1->GetType()); break; @@ -924,9 +953,10 @@ const ZAMStmt ZAMCompiler::DoCall(const CallExpr* c, const NameExpr* n) int call_case = nargs; bool indirect = ! func_id->IsGlobal() || ! func_id->GetVal(); + bool in_when = c->IsInWhen(); - if ( indirect ) - call_case = -1; // force default of CallN + if ( indirect || in_when ) + call_case = -1; // force default of some flavor of CallN auto nt = n ? n->GetType()->Tag() : TYPE_VOID; auto n_slot = n ? Frame1Slot(n, OP1_WRITE) : -1; @@ -1003,8 +1033,17 @@ const ZAMStmt ZAMCompiler::DoCall(const CallExpr* c, const NameExpr* n) break; default: - if ( indirect ) + if ( in_when ) + { + if ( indirect ) + op = OP_WHENINDCALLN_VV; + else + op = OP_WHENCALLN_V; + } + + else if ( indirect ) op = n ? OP_INDCALLN_VV : OP_INDCALLN_V; + else op = n ? OP_CALLN_V : OP_CALLN_X; break; @@ -1012,7 +1051,9 @@ const ZAMStmt ZAMCompiler::DoCall(const CallExpr* c, const NameExpr* n) if ( n ) { - op = AssignmentFlavor(op, nt); + if ( ! in_when ) + op = AssignmentFlavor(op, nt); + auto n_slot = Frame1Slot(n, OP1_WRITE); if ( indirect ) diff --git a/src/script_opt/ZAM/Ops.in b/src/script_opt/ZAM/Ops.in index f628e85f05..886e2b7ba6 100644 --- a/src/script_opt/ZAM/Ops.in +++ b/src/script_opt/ZAM/Ops.in @@ -947,11 +947,19 @@ eval EvalIndexVecIntSelect(z.c, frame[z.v2]) op Index type VVL -custom-method return CompileIndex(n1, n2, l); +custom-method return CompileIndex(n1, n2, l, false); op Index type VCL -custom-method return CompileIndex(n, c, l); +custom-method return CompileIndex(n, c, l, false); + +op WhenIndex +type VVL +custom-method return CompileIndex(n1, n2, l, true); + +op WhenIndex +type VCL +custom-method return CompileIndex(n, c, l, true); internal-op Index-Vec type VVV @@ -988,22 +996,51 @@ internal-op Index-Any-VecC type VVV eval EvalIndexAnyVec(z.v3) -internal-op Index-Vec-Slice -type VV -eval auto vec = frame[z.v2].vector_val; +macro WhenIndexResCheck() + auto& res = frame[z.v1].vector_val; + if ( res && IndexExprWhen::evaluating > 0 ) + IndexExprWhen::results.push_back({NewRef{}, res}); + +internal-op When-Index-Vec +type VVV +eval EvalIndexAnyVec(frame[z.v3].uint_val) + WhenIndexResCheck() + +internal-op When-Index-VecC +type VVV +eval EvalIndexAnyVec(z.v3) + WhenIndexResCheck() + +macro EvalVecSlice() + auto vec = frame[z.v2].vector_val; auto lv = z.aux->ToListVal(frame); auto v = index_slice(vec, lv.get()); Unref(frame[z.v1].vector_val); frame[z.v1].vector_val = v.release(); +internal-op Index-Vec-Slice +type VV +eval EvalVecSlice() + +internal-op When-Index-Vec-Slice +type VV +eval EvalVecSlice() + WhenIndexResCheck() + internal-op Table-Index type VV eval EvalTableIndex(z.aux->ToListVal(frame)) AssignV1(BuildVal(v, z.t)) +internal-op When-Table-Index +type VV +eval EvalTableIndex(z.aux->ToListVal(frame)) + if ( IndexExprWhen::evaluating > 0 ) + IndexExprWhen::results.emplace_back(v); + AssignV1(BuildVal(v, z.t)) + macro EvalTableIndex(index) - auto v2 = index; - auto v = frame[z.v2].table_val->FindOrDefault(v2); + auto v = frame[z.v2].table_val->FindOrDefault(index); if ( ! v ) { ZAM_run_time_error(z.loc, "no such index"); @@ -1536,6 +1573,49 @@ assign-val v indirect-call num-call-args n +# A call made in a "when" context. These always have assignment targets. +# To keep things simple, we just use one generic flavor (for N arguments, +# doing a less-streamlined-but-simpler Val-based assignment). +macro WhenCall(func) + if ( ! func ) + throw ZAMDelayedCallException(); + auto& lhs = frame[z.v1]; + auto trigger = f->GetTrigger(); + Val* v = trigger ? trigger->Lookup(z.call_expr) : nullptr; + ValPtr vp; + if ( v ) + vp = {NewRef{}, v}; + else + { + auto aux = z.aux; + auto current_assoc = f->GetTriggerAssoc(); + auto n = aux->n; + std::vector args; + for ( auto i = 0; i < n; ++i ) + args.push_back(aux->ToVal(frame, i)); + f->SetCall(z.call_expr); + vp = func->Invoke(&args, f); + f->SetTriggerAssoc(current_assoc); + if ( ! vp ) + throw ZAMDelayedCallException(); + } + if ( z.is_managed ) + ZVal::DeleteManagedType(lhs); + lhs = ZVal(vp, z.t); + +internal-op WhenCallN +type V +side-effects +eval WhenCall(z.func) + +internal-op WhenIndCallN +type VV +side-effects +eval auto sel = z.v2; + auto func = (sel < 0) ? z.aux->id_val->GetVal()->AsFunc() : frame[sel].AsFunc(); + WhenCall(func) + + ########## Statements ########## macro EvalScheduleArgs(time, is_delta, build_args) @@ -1965,6 +2045,35 @@ eval auto tt = cast_intrusive(z.t); Unref(frame[z.v1].table_val); frame[z.v1].table_val = t; +op When +type V +op1-read +eval BuildWhen(-1.0) + +op When-Timeout +type VV +op1-read +eval BuildWhen(frame[z.v2].double_val) + +op When-Timeout +type VC +op1-read +eval BuildWhen(z.c.double_val) + +macro BuildWhen(timeout) + auto& aux = z.aux; + auto wi = aux->wi; + FuncPtr func{NewRef{}, frame[z.v1].func_val}; + auto lambda = make_intrusive(func); + wi->Instantiate(lambda); + std::vector local_aggrs; + for ( int i = 0; i < aux->n; ++i ) + { + auto v = aux->ToVal(frame, i); + if ( v ) + local_aggrs.push_back(v); + } + new trigger::Trigger(wi, timeout, wi->WhenExprGlobals(), local_aggrs, f, z.loc); ######################################## # Internal diff --git a/src/script_opt/ZAM/Stmt.cc b/src/script_opt/ZAM/Stmt.cc index 4d8e385c14..a798347bd3 100644 --- a/src/script_opt/ZAM/Stmt.cc +++ b/src/script_opt/ZAM/Stmt.cc @@ -60,6 +60,9 @@ const ZAMStmt ZAMCompiler::CompileStmt(const Stmt* s) case STMT_INIT: return CompileInit(static_cast(s)); + case STMT_WHEN: + return CompileWhen(static_cast(s)); + case STMT_NULL: return EmptyStmt(); @@ -1093,6 +1096,60 @@ const ZAMStmt ZAMCompiler::CompileInit(const InitStmt* is) return last; } +const ZAMStmt ZAMCompiler::CompileWhen(const WhenStmt* ws) + { + auto wi = ws->Info(); + auto timeout = wi->TimeoutExpr(); + + auto lambda = NewSlot(true); + (void)BuildLambda(lambda, wi->Lambda().get()); + + std::vector local_aggr_slots; + for ( auto& l : wi->WhenExprLocals() ) + if ( IsAggr(l->GetType()->Tag()) ) + local_aggr_slots.push_back(l); + + int n = local_aggr_slots.size(); + auto aux = new ZInstAux(n); + aux->wi = wi; + + for ( auto i = 0; i < n; ++i ) + { + auto la = local_aggr_slots[i]; + aux->Add(i, FrameSlot(la), la->GetType()); + } + + ZInstI z; + + if ( timeout ) + { + if ( timeout->Tag() == EXPR_NAME ) + { + auto ns = FrameSlot(timeout->AsNameExpr()); + z = ZInstI(OP_WHEN_TIMEOUT_VV, lambda, ns); + } + else + { + ASSERT(timeout->Tag() == EXPR_CONST); + z = ZInstI(OP_WHEN_TIMEOUT_VC, lambda, timeout->AsConstExpr()); + } + } + + else + z = ZInstI(OP_WHEN_V, lambda); + + z.aux = aux; + + if ( ws->IsReturn() ) + { + (void)AddInst(z); + z = ZInstI(OP_RETURN_C); + z.c = ZVal(); + } + + return AddInst(z); + } + const ZAMStmt ZAMCompiler::InitRecord(IDPtr id, RecordType* rt) { auto z = ZInstI(OP_INIT_RECORD_V, FrameSlot(id)); diff --git a/src/script_opt/ZAM/ZBody.cc b/src/script_opt/ZAM/ZBody.cc index fcb831b351..b9678e06c2 100644 --- a/src/script_opt/ZAM/ZBody.cc +++ b/src/script_opt/ZAM/ZBody.cc @@ -31,6 +31,11 @@ namespace zeek::detail using std::vector; +// Thrown when a call inside a "when" delays. +class ZAMDelayedCallException : public InterpreterException + { + }; + static bool did_init = false; // Count of how often each type of ZOP executed, and how much CPU it