diff --git a/src/Stmt.h b/src/Stmt.h index b6161f2fd3..03a0a54726 100644 --- a/src/Stmt.h +++ b/src/Stmt.h @@ -191,6 +191,7 @@ public: protected: friend class ZAMCompiler; + friend class CPPCompile; int DefaultCaseIndex() const { return default_case_idx; } const auto& ValueMap() const { return case_label_value_map; } diff --git a/src/script_opt/CPP/Compile.h b/src/script_opt/CPP/Compile.h index 11d8808a75..0489be6e83 100644 --- a/src/script_opt/CPP/Compile.h +++ b/src/script_opt/CPP/Compile.h @@ -718,7 +718,11 @@ private: void GenAddStmt(const ExprStmt* es); void GenDeleteStmt(const ExprStmt* es); void GenEventStmt(const EventStmt* ev); + void GenSwitchStmt(const SwitchStmt* sw); + void GenTypeSwitchStmt(const Expr* e, const case_list* cases); + void GenTypeSwitchCase(const ID* id, int case_offset, bool is_multi); + void GenValueSwitchStmt(const Expr* e, const case_list* cases); void GenForStmt(const ForStmt* f); void GenForOverTable(const ExprPtr& tbl, const IDPtr& value_var, const IDPList* loop_vars); diff --git a/src/script_opt/CPP/README.md b/src/script_opt/CPP/README.md index 0187be4397..f34dbd079e 100644 --- a/src/script_opt/CPP/README.md +++ b/src/script_opt/CPP/README.md @@ -176,9 +176,6 @@ an extensible record (i.e., fields added using `redef`). * The compiler will not compile bodies that include "when" statements This is fairly involved to fix. -* The compiler will not compile bodies that include "type" switches. -This is not hard to fix. - * If a lambda generates an event that is not otherwise referred to, that event will not be registered upon instantiating the lambda. This is not particularly difficult to fix. diff --git a/src/script_opt/CPP/Stmts.cc b/src/script_opt/CPP/Stmts.cc index f62e72b4cb..1e3c89c460 100644 --- a/src/script_opt/CPP/Stmts.cc +++ b/src/script_opt/CPP/Stmts.cc @@ -233,6 +233,104 @@ void CPPCompile::GenSwitchStmt(const SwitchStmt* sw) auto e = sw->StmtExpr(); auto cases = sw->Cases(); + if ( sw->TypeMap()->empty() ) + GenValueSwitchStmt(e, cases); + else + GenTypeSwitchStmt(e, cases); + } + +void CPPCompile::GenTypeSwitchStmt(const Expr* e, const case_list* cases) + { + // Start a scoping block so we avoid naming conflicts if a function + // has multiple type switches. + Emit("{"); + Emit("static std::vector CPP__switch_types ="); + StartBlock(); + + for ( const auto& c : *cases ) + { + auto tc = c->TypeCases(); + if ( tc ) + for ( auto id : *tc ) + Emit(Fmt(TypeOffset(id->GetType())) + ","); + } + EndBlock(true); + + NL(); + + Emit("ValPtr CPP__sw_val = %s;", GenExpr(e, GEN_VAL_PTR)); + Emit("auto& CPP__sw_val_t = CPP__sw_val->GetType();"); + Emit("int CPP__sw_type_ind = 0;"); + + Emit("for ( auto CPP__st : CPP__switch_types )"); + StartBlock(); + Emit("if ( can_cast_value_to_type(CPP__sw_val.get(), CPP__Type__[CPP__st].get()) )"); + Emit("\tbreak;"); + Emit("++CPP__sw_type_ind;"); + EndBlock(); + + Emit("switch ( CPP__sw_type_ind ) {"); + + ++break_level; + + int case_offset = 0; + + for ( const auto& c : *cases ) + { + auto tc = c->TypeCases(); + if ( tc ) + { + bool is_multi = tc->size() > 1; + for ( auto id : *tc ) + GenTypeSwitchCase(id, case_offset++, is_multi); + } + else + Emit("default:"); + + StartBlock(); + GenStmt(c->Body()); + EndBlock(); + } + + --break_level; + + Emit("}"); // end the switch + Emit("}"); // end the scoping block + } + +void CPPCompile::GenTypeSwitchCase(const ID* id, int case_offset, bool is_multi) + { + Emit("case %s:", Fmt(case_offset)); + + if ( ! id->Name() ) + // No assignment, we're done. + return; + + // It's an assignment case. If it's a collection of multiple cases, + // assign to the variable only for this particular case. + IndentUp(); + + if ( is_multi ) + { + Emit("if ( CPP__sw_type_ind == %s )", Fmt(case_offset)); + IndentUp(); + } + + auto targ_val = "CPP__sw_val.get()"; + auto targ_type = string("CPP__Type__[CPP__switch_types[") + Fmt(case_offset) + "]].get()"; + + auto cast = string("cast_value_to_type(") + targ_val + ", " + targ_type + ")"; + + Emit("%s = %s;", LocalName(id), GenericValPtrToGT(cast, id->GetType(), GEN_NATIVE)); + + IndentDown(); + + if ( is_multi ) + IndentDown(); + } + +void CPPCompile::GenValueSwitchStmt(const Expr* e, const case_list* cases) + { auto e_it = e->GetType()->InternalType(); bool is_int = e_it == TYPE_INTERNAL_INT; bool is_uint = e_it == TYPE_INTERNAL_UNSIGNED; diff --git a/src/script_opt/CPP/Util.cc b/src/script_opt/CPP/Util.cc index 8c6e6d0a91..1afbb0c4a3 100644 --- a/src/script_opt/CPP/Util.cc +++ b/src/script_opt/CPP/Util.cc @@ -45,13 +45,6 @@ bool is_CPP_compilable(const ProfileFunc* pf, const char** reason) return false; } - if ( pf->TypeSwitches().size() > 0 ) - { - if ( reason ) - *reason = "use of type-based \"switch\""; - return false; - } - auto body = pf->ProfiledBody(); if ( body && ! body->GetOptInfo()->is_free_of_conditionals ) {