diff --git a/src/Expr.cc b/src/Expr.cc index 1ab82853c3..617a7637df 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -663,6 +663,9 @@ Val* BinaryExpr::Fold(Val* v1, Val* v2) const if ( it == TYPE_INTERNAL_STRING ) return StringFold(v1, v2); + if ( v1->Type()->IsSet() ) + return SetFold(v1, v2); + if ( it == TYPE_INTERNAL_ADDR ) return AddrFold(v1, v2); @@ -849,6 +852,32 @@ Val* BinaryExpr::StringFold(Val* v1, Val* v2) const return new Val(result, TYPE_BOOL); } +Val* BinaryExpr::SetFold(Val* v1, Val* v2) const + { + TableVal* tv1 = v1->AsTableVal(); + TableVal* tv2 = v2->AsTableVal(); + + if ( tag != EXPR_AND && tag != EXPR_OR && tag != EXPR_SUB ) + BadTag("BinaryExpr::SetFold"); + + // TableVal* result = new TableVal(v1->Type()->AsTableType()); + TableVal* result = v1->Clone()->AsTableVal(); + + if ( tag == EXPR_OR ) + { + if ( ! tv2->AddTo(result, false, false) ) + reporter->InternalError("set union failed to type check"); + } + + else if ( tag == EXPR_SUB ) + { + if ( ! tv2->RemoveFrom(result) ) + reporter->InternalError("set difference failed to type check"); + } + + return result; + } + Val* BinaryExpr::AddrFold(Val* v1, Val* v2) const { IPAddr a1 = v1->AsAddr(); @@ -1421,24 +1450,39 @@ SubExpr::SubExpr(Expr* arg_op1, Expr* arg_op2) if ( IsError() ) return; - TypeTag bt1 = op1->Type()->Tag(); - if ( IsVector(bt1) ) - bt1 = op1->Type()->AsVectorType()->YieldType()->Tag(); + const BroType* t1 = op1->Type(); + const BroType* t2 = op2->Type(); - TypeTag bt2 = op2->Type()->Tag(); + TypeTag bt1 = t1->Tag(); + if ( IsVector(bt1) ) + bt1 = t1->AsVectorType()->YieldType()->Tag(); + + TypeTag bt2 = t2->Tag(); if ( IsVector(bt2) ) - bt2 = op2->Type()->AsVectorType()->YieldType()->Tag(); + bt2 = t2->AsVectorType()->YieldType()->Tag(); BroType* base_result_type = 0; if ( bt1 == TYPE_TIME && bt2 == TYPE_INTERVAL ) base_result_type = base_type(bt1); + else if ( bt1 == TYPE_TIME && bt2 == TYPE_TIME ) SetType(base_type(TYPE_INTERVAL)); + else if ( bt1 == TYPE_INTERVAL && bt2 == TYPE_INTERVAL ) base_result_type = base_type(bt1); + + else if ( t1->IsSet() && t2->IsSet() ) + { + if ( same_type(t1, t2) ) + SetType(op1->Type()->Ref()); + else + ExprError("incompatible \"set\" operands"); + } + else if ( BothArithmetic(bt1, bt2) ) PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); + else ExprError("requires arithmetic operands"); @@ -1864,13 +1908,16 @@ BitExpr::BitExpr(BroExprTag arg_tag, Expr* arg_op1, Expr* arg_op2) if ( IsError() ) return; - TypeTag bt1 = op1->Type()->Tag(); - if ( IsVector(bt1) ) - bt1 = op1->Type()->AsVectorType()->YieldType()->Tag(); + const BroType* t1 = op1->Type(); + const BroType* t2 = op2->Type(); - TypeTag bt2 = op2->Type()->Tag(); + TypeTag bt1 = t1->Tag(); + if ( IsVector(bt1) ) + bt1 = t1->AsVectorType()->YieldType()->Tag(); + + TypeTag bt2 = t2->Tag(); if ( IsVector(bt2) ) - bt2 = op2->Type()->AsVectorType()->YieldType()->Tag(); + bt2 = t2->AsVectorType()->YieldType()->Tag(); if ( (bt1 == TYPE_COUNT || bt1 == TYPE_COUNTER) && (bt2 == TYPE_COUNT || bt2 == TYPE_COUNTER) ) @@ -1883,8 +1930,16 @@ BitExpr::BitExpr(BroExprTag arg_tag, Expr* arg_op1, Expr* arg_op2) SetType(base_type(TYPE_COUNT)); } + else if ( t1->IsSet() && t2->IsSet() ) + { + if ( same_type(t1, t2) ) + SetType(op1->Type()->Ref()); + else + ExprError("incompatible \"set\" operands"); + } + else - ExprError("requires \"count\" operands"); + ExprError("requires \"count\" or compatible \"set\" operands"); } IMPLEMENT_SERIAL(BitExpr, SER_BIT_EXPR); diff --git a/src/Expr.h b/src/Expr.h index 9fc9aa15ed..a8c890b675 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -329,6 +329,9 @@ protected: // Same for when the constants are strings. virtual Val* StringFold(Val* v1, Val* v2) const; + // Same for when the constants are sets. + virtual Val* SetFold(Val* v1, Val* v2) const; + // Same for when the constants are addresses or subnets. virtual Val* AddrFold(Val* v1, Val* v2) const; virtual Val* SubNetFold(Val* v1, Val* v2) const; diff --git a/src/Val.cc b/src/Val.cc index 4da4a35d48..540719ef14 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -1704,9 +1704,18 @@ int TableVal::RemoveFrom(Val* val) const HashKey* k; while ( tbl->NextEntry(k, c) ) { - Val* index = RecoverIndex(k); + // ### The following code appears to be a complete + // no-op. Commented out 8+ years after it was + // introduced. -VP 22Jun18 + // Val* index = RecoverIndex(k); + // + // Unref(index); - Unref(index); + // Not sure that this is 100% sound, since the HashKey + // comes from one table but is being used in another. + // OTOH, they are both the same type, so as long as + // we don't have hash keys that are keyed per dictionary, + // it should work ... Unref(t->Delete(k)); delete k; } @@ -1714,6 +1723,41 @@ int TableVal::RemoveFrom(Val* val) const return 1; } +TableVal* TableVal::Intersect(const TableVal* tv) const + { + TableVal* result = new TableVal(table_type); + + const PDict(TableEntryVal)* t1 = tv->AsTable(); + const PDict(TableEntryVal)* t2 = AsTable(); + const PDict(TableEntryVal)* t3 = result->AsTable(); + + // Figure out which is smaller. + if ( t1->Length() > t2->Length() ) + { // Swap. + const PDict(TableEntryVal)* t3 = t1; + t1 = t2; + t2 = t3; + } + + IterCookie* c = t1->InitForIteration(); + HashKey* k; + while ( t1->NextEntry(k, c) ) + { +//### // Here we leverage the same assumption about consistent +//### // hashes as in TableVal::RemoveFrom above. +//### if ( t2->Lookup(k) ) +//### { +//### Val* index = RecoverIndex(); +//### result-> +//### +//### Unref(index); +//### Unref(t->Delete(k)); +//### delete k; + } + + return result; + } + int TableVal::ExpandAndInit(Val* index, Val* new_val) { BroType* index_type = index->Type(); diff --git a/src/Val.h b/src/Val.h index 771ed40dd1..ef2a8eefd6 100644 --- a/src/Val.h +++ b/src/Val.h @@ -809,6 +809,12 @@ public: // Returns true if the addition typechecked, false if not. int RemoveFrom(Val* v) const override; + // Returns a new table that is the intersection of this + // table and the given table. Intersection is just done + // on index, not on yield value, so this really only makes + // sense for sets. + TableVal* Intersect(const TableVal* v) const; + // Expands any lists in the index into multiple initializations. // Returns true if the initializations typecheck, false if not. int ExpandAndInit(Val* index, Val* new_val);