diff --git a/src/Stmt.cc b/src/Stmt.cc index e8c4abc9f9..47e839e189 100644 --- a/src/Stmt.cc +++ b/src/Stmt.cc @@ -1570,10 +1570,10 @@ TraversalCode NullStmt::Traverse(TraversalCallback* cb) const { HANDLE_TC_STMT_POST(tc); } -AssertStmt::AssertStmt(ExprPtr arg_cond, ExprPtr arg_msg) - : Stmt(STMT_ASSERT), cond(std::move(arg_cond)), msg(std::move(arg_msg)) { - if ( ! IsBool(cond->GetType()->Tag()) ) - cond->Error("conditional must be boolean"); +AssertStmt::AssertStmt(ExprPtr cond, ExprPtr arg_msg) + : ExprStmt(STMT_ASSERT, std::move(cond)), msg(std::move(arg_msg)) { + if ( ! IsBool(e->GetType()->Tag()) ) + e->Error("conditional must be boolean"); if ( msg && ! IsString(msg->GetType()->Tag()) ) msg->Error("message must be string"); @@ -1581,7 +1581,7 @@ AssertStmt::AssertStmt(ExprPtr arg_cond, ExprPtr arg_msg) zeek::ODesc desc; desc.SetShort(true); desc.SetQuotes(true); - cond->Describe(&desc); + e->Describe(&desc); cond_desc = desc.Description(); } @@ -1592,7 +1592,7 @@ ValPtr AssertStmt::Exec(Frame* f, StmtFlowType& flow) { static auto assertion_result_hook = id::find_func("assertion_result"); bool run_result_hook = assertion_result_hook && assertion_result_hook->HasEnabledBodies(); - auto assert_result = cond->Eval(f)->AsBool(); + auto assert_result = e->Eval(f)->AsBool(); if ( ! assert_result || run_result_hook ) { zeek::StringValPtr msg_val = zeek::val_mgr->EmptyString(); @@ -1619,7 +1619,13 @@ void AssertStmt::StmtDescribe(ODesc* d) const { auto orig_quotes = d->WantQuotes(); d->SetQuotes(true); - cond->Describe(d); + e->Describe(d); + + if ( msg_setup_stmt ) { + d->Add("{ "); + msg_setup_stmt->Describe(d); + d->Add(" }"); + } if ( msg ) { d->Add(","); @@ -1636,9 +1642,14 @@ TraversalCode AssertStmt::Traverse(TraversalCallback* cb) const { TraversalCode tc = cb->PreStmt(this); HANDLE_TC_STMT_PRE(tc); - tc = cond->Traverse(cb); + tc = e->Traverse(cb); HANDLE_TC_STMT_PRE(tc); if ( msg ) { + if ( msg_setup_stmt ) { + tc = msg_setup_stmt->Traverse(cb); + HANDLE_TC_STMT_PRE(tc); + } + tc = msg->Traverse(cb); HANDLE_TC_STMT_PRE(tc); } diff --git a/src/Stmt.h b/src/Stmt.h index d714a61d56..9fac25752e 100644 --- a/src/Stmt.h +++ b/src/Stmt.h @@ -495,15 +495,15 @@ private: bool is_directive; }; -class AssertStmt final : public Stmt { +class AssertStmt final : public ExprStmt { public: explicit AssertStmt(ExprPtr cond, ExprPtr msg = nullptr); ValPtr Exec(Frame* f, StmtFlowType& flow) override; - const auto& Cond() const { return cond; } const auto& CondDesc() const { return cond_desc; } const auto& Msg() const { return msg; } + const auto& MsgSetupStmt() const { return msg_setup_stmt; } void StmtDescribe(ODesc* d) const override; @@ -516,9 +516,12 @@ public: StmtPtr DoReduce(Reducer* c) override; private: - ExprPtr cond; std::string cond_desc; ExprPtr msg; + + // Statement to execute before evaluating "msg". Only used for script + // optimization. + StmtPtr msg_setup_stmt; }; // Helper function for reporting on asserts that either failed, or should diff --git a/src/script_opt/CPP/Stmts.cc b/src/script_opt/CPP/Stmts.cc index fadb8e989a..ace13e4e24 100644 --- a/src/script_opt/CPP/Stmts.cc +++ b/src/script_opt/CPP/Stmts.cc @@ -480,7 +480,7 @@ void CPPCompile::GenForOverString(const ExprPtr& str, const IDPList* loop_vars) } void CPPCompile::GenAssertStmt(const AssertStmt* a) { - auto& cond = a->Cond(); + auto cond = a->StmtExpr(); auto& msg = a->Msg(); Emit("{ // begin a new scope for internal \"assert\" variables"); diff --git a/src/script_opt/Stmt.cc b/src/script_opt/Stmt.cc index 8ef72c60b4..8c1f7d4bf0 100644 --- a/src/script_opt/Stmt.cc +++ b/src/script_opt/Stmt.cc @@ -1070,11 +1070,36 @@ StmtPtr InitStmt::DoReduce(Reducer* c) { return ThisPtr(); } -StmtPtr AssertStmt::Duplicate() { return SetSucc(new AssertStmt(cond->Duplicate(), msg ? msg->Duplicate() : nullptr)); } +StmtPtr AssertStmt::Duplicate() { return SetSucc(new AssertStmt(e->Duplicate(), msg ? msg->Duplicate() : nullptr)); } -bool AssertStmt::IsReduced(Reducer* c) const { return false; } +bool AssertStmt::IsReduced(Reducer* c) const { + if ( ! analysis_options.keep_asserts ) + return false; -StmtPtr AssertStmt::DoReduce(Reducer* c) { return TransformMe(make_intrusive(), c); } + return e->IsSingleton(c) && (! msg || msg->IsSingleton(c)); +} + +StmtPtr AssertStmt::DoReduce(Reducer* c) { + if ( ! analysis_options.keep_asserts ) + return TransformMe(make_intrusive(), c); + + if ( c->Optimizing() ) { + e = c->OptExpr(e); + if ( msg ) + msg = c->OptExpr(msg); + return ThisPtr(); + } + else if ( IsReduced(c) ) + return ThisPtr(); + + StmtPtr red_stmt; + e = e->ReduceToSingleton(c, red_stmt); + if ( msg ) + msg = msg->ReduceToSingleton(c, msg_setup_stmt); + + auto sl = with_location_of(make_intrusive(red_stmt, ThisPtr()), this); + return sl->Reduce(c); +} bool WhenInfo::HasUnreducedIDs(Reducer* c) const { for ( auto& cp : *cl ) { diff --git a/src/script_opt/UseDefs.cc b/src/script_opt/UseDefs.cc index e0edc9a85b..ad238e6dcb 100644 --- a/src/script_opt/UseDefs.cc +++ b/src/script_opt/UseDefs.cc @@ -287,6 +287,21 @@ UDs UseDefs::PropagateUDs(const Stmt* s, UDs succ_UDs, const Stmt* succ_stmt, bo return CreateUDs(s, uds); } + case STMT_ASSERT: { + auto a = s->AsAssertStmt(); + auto e = a->StmtExpr(); + + if ( auto msg = a->Msg().get() ) { + succ_UDs = UD_Union(succ_UDs, ExprUDs(msg)); + if ( auto msg_setup_stmt = a->MsgSetupStmt().get() ) { + succ_UDs = PropagateUDs(msg_setup_stmt, succ_UDs, succ_stmt, second_pass); + succ_stmt = msg_setup_stmt; + } + } + + return CreateUDs(s, UD_Union(succ_UDs, ExprUDs(e))); + } + case STMT_SWITCH: { auto sw_UDs = std::make_shared(); diff --git a/src/script_opt/ZAM/Stmt.cc b/src/script_opt/ZAM/Stmt.cc index 65a9d94e13..75075dfa86 100644 --- a/src/script_opt/ZAM/Stmt.cc +++ b/src/script_opt/ZAM/Stmt.cc @@ -48,6 +48,8 @@ const ZAMStmt ZAMCompiler::CompileStmt(const Stmt* s) { case STMT_WHEN: return CompileWhen(static_cast(s)); + case STMT_ASSERT: return CompileAssert(static_cast(s)); + case STMT_NULL: return EmptyStmt(); case STMT_CHECK_ANY_LEN: { @@ -1047,6 +1049,45 @@ const ZAMStmt ZAMCompiler::CompileWhen(const WhenStmt* ws) { return AddInst(z); } +const ZAMStmt ZAMCompiler::CompileAssert(const AssertStmt* as) { + auto cond = as->StmtExpr(); + + int cond_slot; + if ( cond->Tag() == EXPR_CONST ) + cond_slot = TempForConst(cond->AsConstExpr()); + else + cond_slot = FrameSlot(cond->AsNameExpr()); + + auto decision_slot = NewSlot(false); + + (void)AddInst(ZInstI(OP_SHOULD_REPORT_ASSERT_VV, decision_slot, cond_slot)); + + ZInstI z; + + // We don't have a convenient way of directly introducing a std::string + // constant, so we build one to hold it. + auto cond_desc = make_intrusive(new String(as->CondDesc())); + auto cond_desc_e = make_intrusive(cond_desc); + + if ( auto msg = as->Msg() ) { + auto& msg_setup_stmt = as->MsgSetupStmt(); + if ( msg_setup_stmt ) + (void)CompileStmt(msg_setup_stmt); + + int msg_slot; + if ( msg->Tag() == EXPR_CONST ) + msg_slot = TempForConst(msg->AsConstExpr()); + else + msg_slot = FrameSlot(msg->AsNameExpr()); + + z = ZInstI(OP_REPORT_ASSERT_WITH_MESSAGE_VVVC, decision_slot, cond_slot, msg_slot, cond_desc_e.get()); + } + else + z = ZInstI(OP_REPORT_ASSERT_VVC, decision_slot, cond_slot, cond_desc_e.get()); + + return AddInst(z); +} + const ZAMStmt ZAMCompiler::InitRecord(IDPtr id, RecordType* rt) { auto z = ZInstI(OP_INIT_RECORD_V, FrameSlot(id)); z.SetType({NewRef{}, rt}); diff --git a/src/script_opt/ZAM/Stmt.h b/src/script_opt/ZAM/Stmt.h index 8a85ed4539..242fbe8197 100644 --- a/src/script_opt/ZAM/Stmt.h +++ b/src/script_opt/ZAM/Stmt.h @@ -20,6 +20,7 @@ const ZAMStmt CompileCatchReturn(const CatchReturnStmt* cr); const ZAMStmt CompileStmts(const StmtList* sl); const ZAMStmt CompileInit(const InitStmt* is); const ZAMStmt CompileWhen(const WhenStmt* ws); +const ZAMStmt CompileAssert(const AssertStmt* ws); const ZAMStmt CompileNext() { return GenGoTo(nexts.back()); } const ZAMStmt CompileBreak() { return GenGoTo(breaks.back()); } diff --git a/src/script_opt/ZAM/Vars.h b/src/script_opt/ZAM/Vars.h index 5fe46e9fe4..affc88ec69 100644 --- a/src/script_opt/ZAM/Vars.h +++ b/src/script_opt/ZAM/Vars.h @@ -25,8 +25,8 @@ int FrameSlotIfName(const Expr* e) { return n ? FrameSlot(n->Id()) : -1; } -int FrameSlot(const NameExpr* id) { return FrameSlot(id->AsNameExpr()->Id()); } -int Frame1Slot(const NameExpr* id, ZOp op) { return Frame1Slot(id->AsNameExpr()->Id(), op); } +int FrameSlot(const NameExpr* n) { return FrameSlot(n->Id()); } +int Frame1Slot(const NameExpr* n, ZOp op) { return Frame1Slot(n->Id(), op); } int Frame1Slot(const ID* id, ZOp op) { return Frame1Slot(id, op1_flavor[op]); } int Frame1Slot(const NameExpr* n, ZAMOp1Flavor fl) { return Frame1Slot(n->Id(), fl); }