From be71a42f4c9cdde69b74f18203db062dbc18dea2 Mon Sep 17 00:00:00 2001 From: Jon Siwek Date: Wed, 16 Jan 2013 16:17:17 -0600 Subject: [PATCH] Add "fallthrough" keyword, require a flow statement to end case blocks. Case blocks in switch statements now must end in a break, return, or fallthrough statement to give best mix of safety, readability, and flexibility. The new fallthrough keyword explicitly allows control to be passed to the next case block in a switch statement. Addresses #754. --- src/SerialTypes.h | 1 + src/Stmt.cc | 83 +++++++++++++++++-- src/Stmt.h | 18 +++- src/StmtEnums.h | 4 +- src/parse.y | 12 ++- src/scan.l | 1 + testing/btest/core/leaks/switch-statement.bro | 32 ++++--- testing/btest/language/switch-statement.bro | 32 ++++--- 8 files changed, 150 insertions(+), 33 deletions(-) diff --git a/src/SerialTypes.h b/src/SerialTypes.h index e103c1c40e..723badab1e 100644 --- a/src/SerialTypes.h +++ b/src/SerialTypes.h @@ -171,6 +171,7 @@ SERIAL_STMT(EVENT_BODY_LIST, 16) SERIAL_STMT(INIT_STMT, 17) SERIAL_STMT(NULL_STMT, 18) SERIAL_STMT(WHEN_STMT, 19) +SERIAL_STMT(FALLTHROUGH_STMT, 20) #define SERIAL_TYPE(name, val) SERIAL_CONST(name, val, BRO_TYPE) SERIAL_TYPE(BRO_TYPE, 1) diff --git a/src/Stmt.cc b/src/Stmt.cc index 3e37256338..cc506db985 100644 --- a/src/Stmt.cc +++ b/src/Stmt.cc @@ -23,7 +23,7 @@ const char* stmt_name(BroStmtTag t) "print", "event", "expr", "if", "when", "switch", "for", "next", "break", "return", "add", "delete", "list", "bodylist", - "", + "", "fallthrough", "null", }; @@ -584,6 +584,29 @@ bool IfStmt::DoUnserialize(UnserialInfo* info) return s2 != 0; } +static BroStmtTag get_last_stmt_tag(const Stmt* stmt) + { + if ( ! stmt ) return STMT_NULL; + + if ( stmt->Tag() != STMT_LIST ) return stmt->Tag(); + + const StmtList* stmts = stmt->AsStmtList(); + int len = stmts->Stmts().length(); + + if ( len == 0 ) return STMT_LIST; + + return get_last_stmt_tag(stmts->Stmts()[len - 1]); + } + +Case::Case(ListExpr* c, Stmt* arg_s) + : cases(simplify_expr_list(c, SIMPLIFY_GENERAL)), s(arg_s) + { + BroStmtTag t = get_last_stmt_tag(Body()); + + if ( t != STMT_BREAK && t != STMT_FALLTHROUGH && t != STMT_RETURN ) + Error("case block must end in break/fallthrough/return statement"); + } + Case::~Case() { Unref(cases); @@ -701,9 +724,6 @@ SwitchStmt::SwitchStmt(Expr* index, case_list* arg_cases) : const Case* c = (*cases)[i]; const ListExpr* le = c->Cases(); - if ( ! c->Body() || c->Body()->AsStmtList()->Stmts().length() == 0 ) - c->Error("empty case label body does nothing"); - if ( le ) { if ( ! le->Type()->AsTypeList()->AllMatch(e->Type(), false) ) @@ -798,12 +818,18 @@ Val* SwitchStmt::DoExec(Frame* f, Val* v, stmt_flow_type& flow) const if ( matching_label_idx == -1 ) return 0; - const Case* c = (*cases)[matching_label_idx]; + for ( int i = matching_label_idx; i < cases->length(); ++i ) + { + const Case* c = (*cases)[i]; - flow = FLOW_NEXT; - rval = c->Body()->Exec(f, flow); + flow = FLOW_NEXT; + rval = c->Body()->Exec(f, flow); - if ( flow == FLOW_BREAK ) + if ( flow == FLOW_BREAK || flow == FLOW_RETURN ) + break; + } + + if ( flow != FLOW_RETURN ) flow = FLOW_NEXT; return rval; @@ -1461,6 +1487,47 @@ bool BreakStmt::DoUnserialize(UnserialInfo* info) return true; } +Val* FallthroughStmt::Exec(Frame* /* f */, stmt_flow_type& flow) const + { + RegisterAccess(); + flow = FLOW_FALLTHROUGH; + return 0; + } + +int FallthroughStmt::IsPure() const + { + return 1; + } + +void FallthroughStmt::Describe(ODesc* d) const + { + Stmt::Describe(d); + Stmt::DescribeDone(d); + } + +TraversalCode FallthroughStmt::Traverse(TraversalCallback* cb) const + { + TraversalCode tc = cb->PreStmt(this); + HANDLE_TC_STMT_PRE(tc); + + tc = cb->PostStmt(this); + HANDLE_TC_STMT_POST(tc); + } + +IMPLEMENT_SERIAL(FallthroughStmt, SER_FALLTHROUGH_STMT); + +bool FallthroughStmt::DoSerialize(SerialInfo* info) const + { + DO_SERIALIZE(SER_FALLTHROUGH_STMT, Stmt); + return true; + } + +bool FallthroughStmt::DoUnserialize(UnserialInfo* info) + { + DO_UNSERIALIZE(Stmt); + return true; + } + ReturnStmt::ReturnStmt(Expr* arg_e) : ExprStmt(STMT_RETURN, arg_e) { Scope* s = current_scope(); diff --git a/src/Stmt.h b/src/Stmt.h index 32be7f33fc..32b90b4190 100644 --- a/src/Stmt.h +++ b/src/Stmt.h @@ -195,8 +195,7 @@ protected: class Case : public BroObj { public: - Case(ListExpr* c, Stmt* arg_s) : - cases(simplify_expr_list(c,SIMPLIFY_GENERAL)), s(arg_s) { } + Case(ListExpr* c, Stmt* arg_s); ~Case(); const ListExpr* Cases() const { return cases; } @@ -371,6 +370,21 @@ protected: DECLARE_SERIAL(BreakStmt); }; +class FallthroughStmt : public Stmt { +public: + FallthroughStmt() : Stmt(STMT_FALLTHROUGH) { } + + Val* Exec(Frame* f, stmt_flow_type& flow) const; + int IsPure() const; + + void Describe(ODesc* d) const; + + TraversalCode Traverse(TraversalCallback* cb) const; + +protected: + DECLARE_SERIAL(FallthroughStmt); +}; + class ReturnStmt : public ExprStmt { public: ReturnStmt(Expr* e); diff --git a/src/StmtEnums.h b/src/StmtEnums.h index f431e3fea1..1114816a93 100644 --- a/src/StmtEnums.h +++ b/src/StmtEnums.h @@ -16,6 +16,7 @@ typedef enum { STMT_ADD, STMT_DELETE, STMT_LIST, STMT_EVENT_BODY_LIST, STMT_INIT, + STMT_FALLTHROUGH, STMT_NULL #define NUM_STMTS (int(STMT_NULL) + 1) } BroStmtTag; @@ -24,7 +25,8 @@ typedef enum { FLOW_NEXT, // continue on to next statement FLOW_LOOP, // go to top of loop FLOW_BREAK, // break out of loop - FLOW_RETURN // return from function + FLOW_RETURN, // return from function + FLOW_FALLTHROUGH// fall through to next switch case } stmt_flow_type; extern const char* stmt_name(BroStmtTag t); diff --git a/src/parse.y b/src/parse.y index 090786647e..7ce1174595 100644 --- a/src/parse.y +++ b/src/parse.y @@ -8,8 +8,8 @@ %token TOK_ATENDIF TOK_ATELSE TOK_ATIF TOK_ATIFDEF TOK_ATIFNDEF %token TOK_BOOL TOK_BREAK TOK_CASE TOK_CONST %token TOK_CONSTANT TOK_COPY TOK_COUNT TOK_COUNTER TOK_DEFAULT TOK_DELETE -%token TOK_DOUBLE TOK_ELSE TOK_ENUM TOK_EVENT TOK_EXPORT TOK_FILE TOK_FOR -%token TOK_FUNCTION TOK_GLOBAL TOK_HOOK TOK_ID TOK_IF TOK_INT +%token TOK_DOUBLE TOK_ELSE TOK_ENUM TOK_EVENT TOK_EXPORT TOK_FALLTHROUGH +%token TOK_FILE TOK_FOR TOK_FUNCTION TOK_GLOBAL TOK_HOOK TOK_ID TOK_IF TOK_INT %token TOK_INTERVAL TOK_LIST TOK_LOCAL TOK_MODULE %token TOK_NEXT TOK_OF TOK_OPAQUE TOK_PATTERN TOK_PATTERN_TEXT %token TOK_PORT TOK_PRINT TOK_RECORD TOK_REDEF @@ -1436,6 +1436,14 @@ stmt: brofiler.AddStmt($$); } + | TOK_FALLTHROUGH ';' opt_no_test + { + set_location(@1, @2); + $$ = new FallthroughStmt; + if ( ! $3 ) + brofiler.AddStmt($$); + } + | TOK_RETURN ';' opt_no_test { set_location(@1, @2); diff --git a/src/scan.l b/src/scan.l index efcd273e36..ffbc125728 100644 --- a/src/scan.l +++ b/src/scan.l @@ -282,6 +282,7 @@ else return TOK_ELSE; enum return TOK_ENUM; event return TOK_EVENT; export return TOK_EXPORT; +fallthrough return TOK_FALLTHROUGH; file return TOK_FILE; for return TOK_FOR; function return TOK_FUNCTION; diff --git a/testing/btest/core/leaks/switch-statement.bro b/testing/btest/core/leaks/switch-statement.bro index 6fbdb0d54a..845915ae8a 100644 --- a/testing/btest/core/leaks/switch-statement.bro +++ b/testing/btest/core/leaks/switch-statement.bro @@ -148,18 +148,19 @@ function switch_empty(v: count): string return "n/a"; } -function switch_break(v: count): string +function switch_fallthrough(v: count): string { local rval = ""; switch ( v ) { case 1: rval += "test"; + fallthrough; case 2: rval += "testing"; - break; - rval += "ERROR"; + fallthrough; case 3: rval += "tested"; + break; } return rval + "return"; } @@ -170,12 +171,16 @@ function switch_default(v: count): string switch ( v ) { case 1: rval += "1"; + fallthrough; case 2: rval += "2"; + break; case 3: rval += "3"; + fallthrough; default: rval += "d"; + break; } return rval + "r"; } @@ -186,13 +191,16 @@ function switch_default_placement(v: count): string switch ( v ) { case 1: rval += "1"; + fallthrough; default: rval += "d"; + fallthrough; case 2: rval += "2"; break; case 3: rval += "3"; + break; } return rval + "r"; } @@ -252,17 +260,17 @@ event new_connection(c: connection) test_switch( switch_subnet([fe80::1]/96) , "[fe80::0]" ); test_switch( switch_subnet(192.168.1.100/16) , "192.168.0.0/16" ); test_switch( switch_empty(2) , "n/a" ); - test_switch( switch_break(1) , "testreturn" ); - test_switch( switch_break(2) , "testingreturn" ); - test_switch( switch_break(3) , "testedreturn" ); - test_switch( switch_default(1) , "1r" ); + test_switch( switch_fallthrough(1) , "testtestingtestedreturn" ); + test_switch( switch_fallthrough(2) , "testingtestedreturn" ); + test_switch( switch_fallthrough(3) , "testedreturn" ); + test_switch( switch_default(1) , "12r" ); test_switch( switch_default(2) , "2r" ); - test_switch( switch_default(3) , "3r" ); + test_switch( switch_default(3) , "3dr" ); test_switch( switch_default(4) , "dr" ); - test_switch( switch_default_placement(1) , "1r" ); + test_switch( switch_default_placement(1) , "1d2r" ); test_switch( switch_default_placement(2) , "2r" ); test_switch( switch_default_placement(3) , "3r" ); - test_switch( switch_default_placement(4) , "dr" ); + test_switch( switch_default_placement(4) , "d2r" ); local v = vector(0,1,2,3,4,5,6,7,9,10); local expect: string; @@ -272,12 +280,16 @@ event new_connection(c: connection) switch ( v[i] ) { case 1, 2: expect = "1,2"; + break; case 3, 4, 5: expect = "3,4,5"; + break; case 6, 7, 8, 9: expect = "6,7,8,9"; + break; default: expect = "n/a"; + break; } test_switch( switch_case_list(v[i]) , expect ); } diff --git a/testing/btest/language/switch-statement.bro b/testing/btest/language/switch-statement.bro index dcf2a4c041..152b14f87d 100644 --- a/testing/btest/language/switch-statement.bro +++ b/testing/btest/language/switch-statement.bro @@ -143,18 +143,19 @@ function switch_empty(v: count): string return "n/a"; } -function switch_break(v: count): string +function switch_fallthrough(v: count): string { local rval = ""; switch ( v ) { case 1: rval += "test"; + fallthrough; case 2: rval += "testing"; - break; - rval += "ERROR"; + fallthrough; case 3: rval += "tested"; + break; } return rval + "return"; } @@ -165,12 +166,16 @@ function switch_default(v: count): string switch ( v ) { case 1: rval += "1"; + fallthrough; case 2: rval += "2"; + break; case 3: rval += "3"; + fallthrough; default: rval += "d"; + break; } return rval + "r"; } @@ -181,13 +186,16 @@ function switch_default_placement(v: count): string switch ( v ) { case 1: rval += "1"; + fallthrough; default: rval += "d"; + fallthrough; case 2: rval += "2"; break; case 3: rval += "3"; + break; } return rval + "r"; } @@ -247,17 +255,17 @@ event bro_init() test_switch( switch_subnet([fe80::1]/96) , "[fe80::0]" ); test_switch( switch_subnet(192.168.1.100/16) , "192.168.0.0/16" ); test_switch( switch_empty(2) , "n/a" ); - test_switch( switch_break(1) , "testreturn" ); - test_switch( switch_break(2) , "testingreturn" ); - test_switch( switch_break(3) , "testedreturn" ); - test_switch( switch_default(1) , "1r" ); + test_switch( switch_fallthrough(1) , "testtestingtestedreturn" ); + test_switch( switch_fallthrough(2) , "testingtestedreturn" ); + test_switch( switch_fallthrough(3) , "testedreturn" ); + test_switch( switch_default(1) , "12r" ); test_switch( switch_default(2) , "2r" ); - test_switch( switch_default(3) , "3r" ); + test_switch( switch_default(3) , "3dr" ); test_switch( switch_default(4) , "dr" ); - test_switch( switch_default_placement(1) , "1r" ); + test_switch( switch_default_placement(1) , "1d2r" ); test_switch( switch_default_placement(2) , "2r" ); test_switch( switch_default_placement(3) , "3r" ); - test_switch( switch_default_placement(4) , "dr" ); + test_switch( switch_default_placement(4) , "d2r" ); local v = vector(0,1,2,3,4,5,6,7,9,10); local expect: string; @@ -267,12 +275,16 @@ event bro_init() switch ( v[i] ) { case 1, 2: expect = "1,2"; + break; case 3, 4, 5: expect = "3,4,5"; + break; case 6, 7, 8, 9: expect = "6,7,8,9"; + break; default: expect = "n/a"; + break; } test_switch( switch_case_list(v[i]) , expect ); }