diff --git a/src/Expr.cc b/src/Expr.cc index ed14eb6045..d3b22503cb 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -856,32 +856,55 @@ Val* BinaryExpr::SetFold(Val* v1, Val* v2) const { TableVal* tv1 = v1->AsTableVal(); TableVal* tv2 = v2->AsTableVal(); + TableVal* result; + bool res; - if ( tag != EXPR_AND && tag != EXPR_OR && tag != EXPR_SUB ) - BadTag("BinaryExpr::SetFold"); - - if ( tag == EXPR_AND ) + switch ( tag ) { + case EXPR_AND: return tv1->Intersect(tv2); - // TableVal* result = new TableVal(v1->Type()->AsTableType()); - TableVal* result = v1->Clone()->AsTableVal(); + case EXPR_OR: + // TableVal* result = new TableVal(v1->Type()->AsTableType()); + result = v1->Clone()->AsTableVal(); - if ( tag == EXPR_OR ) - { if ( ! tv2->AddTo(result, false, false) ) reporter->InternalError("set union failed to type check"); - } + return result; + + case EXPR_SUB: + result = v1->Clone()->AsTableVal(); - else if ( tag == EXPR_SUB ) - { if ( ! tv2->RemoveFrom(result) ) reporter->InternalError("set difference failed to type check"); - } + return result; - else + case EXPR_EQ: + res = tv1->EqualTo(tv2); + break; + + case EXPR_NE: + res = ! tv1->EqualTo(tv2); + break; + + case EXPR_LT: + res = tv1->IsSubsetOf(tv2) && tv1->Size() < tv2->Size(); + break; + + case EXPR_LE: + res = tv1->IsSubsetOf(tv2); + break; + + case EXPR_GE: + case EXPR_GT: + // These should't happen due to canonicalization. + reporter->InternalError("confusion over canonicalization in set comparison"); + + default: BadTag("BinaryExpr::SetFold", expr_name(tag)); + return 0; + } - return result; + return new Val(res, TYPE_BOOL); } Val* BinaryExpr::AddrFold(Val* v1, Val* v2) const @@ -1970,13 +1993,16 @@ EqExpr::EqExpr(BroExprTag arg_tag, Expr* arg_op1, Expr* arg_op2) Canonicize(); - TypeTag bt1 = op1->Type()->Tag(); - if ( IsVector(bt1) ) - bt1 = op1->Type()->AsVectorType()->YieldType()->Tag(); + const BroType* t1 = op1->Type(); + const BroType* t2 = op2->Type(); - TypeTag bt2 = op2->Type()->Tag(); + TypeTag bt1 = t1->Tag(); + if ( IsVector(bt1) ) + bt1 = t1->AsVectorType()->YieldType()->Tag(); + + TypeTag bt2 = t2->Tag(); if ( IsVector(bt2) ) - bt2 = op2->Type()->AsVectorType()->YieldType()->Tag(); + bt2 = t2->AsVectorType()->YieldType()->Tag(); if ( is_vector(op1) || is_vector(op2) ) SetType(new VectorType(base_type(TYPE_BOOL))); @@ -2006,10 +2032,20 @@ EqExpr::EqExpr(BroExprTag arg_tag, Expr* arg_op1, Expr* arg_op2) break; case TYPE_ENUM: - if ( ! same_type(op1->Type(), op2->Type()) ) + if ( ! same_type(t1, t2) ) ExprError("illegal enum comparison"); break; + case TYPE_TABLE: + if ( t1->IsSet() && t2->IsSet() ) + { + if ( ! same_type(t1, t2) ) + ExprError("incompatible sets in comparison"); + break; + } + + // FALL THROUGH + default: ExprError("illegal comparison"); } @@ -2072,13 +2108,16 @@ RelExpr::RelExpr(BroExprTag arg_tag, Expr* arg_op1, Expr* arg_op2) Canonicize(); - TypeTag bt1 = op1->Type()->Tag(); - if ( IsVector(bt1) ) - bt1 = op1->Type()->AsVectorType()->YieldType()->Tag(); + const BroType* t1 = op1->Type(); + const BroType* t2 = op2->Type(); - TypeTag bt2 = op2->Type()->Tag(); + TypeTag bt1 = t1->Tag(); + if ( IsVector(bt1) ) + bt1 = t1->AsVectorType()->YieldType()->Tag(); + + TypeTag bt2 = t2->Tag(); if ( IsVector(bt2) ) - bt2 = op2->Type()->AsVectorType()->YieldType()->Tag(); + bt2 = t2->AsVectorType()->YieldType()->Tag(); if ( is_vector(op1) || is_vector(op2) ) SetType(new VectorType(base_type(TYPE_BOOL))); @@ -2088,6 +2127,12 @@ RelExpr::RelExpr(BroExprTag arg_tag, Expr* arg_op1, Expr* arg_op2) if ( BothArithmetic(bt1, bt2) ) PromoteOps(max_type(bt1, bt2)); + else if ( t1->IsSet() && t2->IsSet() ) + { + if ( ! same_type(t1, t2) ) + ExprError("incompatible sets in comparison"); + } + else if ( bt1 != bt2 ) ExprError("operands must be of the same type"); diff --git a/src/Val.cc b/src/Val.cc index 48458254f1..02e88cd3b4 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -1727,16 +1727,16 @@ TableVal* TableVal::Intersect(const TableVal* tv) const { TableVal* result = new TableVal(table_type); + const PDict(TableEntryVal)* t0 = AsTable(); const PDict(TableEntryVal)* t1 = tv->AsTable(); - const PDict(TableEntryVal)* t2 = AsTable(); - PDict(TableEntryVal)* t3 = result->AsNonConstTable(); + PDict(TableEntryVal)* t2 = result->AsNonConstTable(); - // Figure out which is smaller. - if ( t1->Length() > t2->Length() ) + // Figure out which is smaller; assign it to t1. + if ( t1->Length() > t0->Length() ) { // Swap. const PDict(TableEntryVal)* tmp = t1; - t1 = t2; - t2 = tmp; + t1 = t0; + t0 = tmp; } IterCookie* c = t1->InitForIteration(); @@ -1745,8 +1745,8 @@ TableVal* TableVal::Intersect(const TableVal* tv) const { // Here we leverage the same assumption about consistent // hashes as in TableVal::RemoveFrom above. - if ( t2->Lookup(k) ) - t3->Insert(k, new TableEntryVal(0)); + if ( t0->Lookup(k) ) + t2->Insert(k, new TableEntryVal(0)); delete k; } @@ -1754,6 +1754,58 @@ TableVal* TableVal::Intersect(const TableVal* tv) const return result; } +bool TableVal::EqualTo(const TableVal* tv) const + { + const PDict(TableEntryVal)* t0 = AsTable(); + const PDict(TableEntryVal)* t1 = tv->AsTable(); + + if ( t0->Length() != t1->Length() ) + return false; + + IterCookie* c = t0->InitForIteration(); + HashKey* k; + while ( t0->NextEntry(k, c) ) + { + // Here we leverage the same assumption about consistent + // hashes as in TableVal::RemoveFrom above. + if ( ! t1->Lookup(k) ) + { + delete k; + return false; + } + + delete k; + } + + return true; + } + +bool TableVal::IsSubsetOf(const TableVal* tv) const + { + const PDict(TableEntryVal)* t0 = AsTable(); + const PDict(TableEntryVal)* t1 = tv->AsTable(); + + if ( t0->Length() > t1->Length() ) + return false; + + IterCookie* c = t0->InitForIteration(); + HashKey* k; + while ( t0->NextEntry(k, c) ) + { + // Here we leverage the same assumption about consistent + // hashes as in TableVal::RemoveFrom above. + if ( ! t1->Lookup(k) ) + { + delete k; + return false; + } + + delete k; + } + + return true; + } + int TableVal::ExpandAndInit(Val* index, Val* new_val) { BroType* index_type = index->Type(); diff --git a/src/Val.h b/src/Val.h index ef2a8eefd6..32ce6a0187 100644 --- a/src/Val.h +++ b/src/Val.h @@ -815,6 +815,16 @@ public: // sense for sets. TableVal* Intersect(const TableVal* v) const; + // Returns true if this set contains the same members as the + // given set. Note that comparisons are done using hash keys, + // so errors can arise for compound sets such as sets-of-sets. + // See https://bro-tracker.atlassian.net/browse/BIT-1949. + bool EqualTo(const TableVal* v) const; + + // Returns true if this set is a subset (not necessarily proper) + // of the given set. + bool IsSubsetOf(const TableVal* v) const; + // Expands any lists in the index into multiple initializations. // Returns true if the initializations typecheck, false if not. int ExpandAndInit(Val* index, Val* new_val);