diff --git a/src/Expr.cc b/src/Expr.cc index dde0575490..b073db85b6 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -1390,6 +1390,36 @@ void AddExpr::Canonicalize() { SwapOps(); } +// True if we should treat LHS += RHS as add-every-element-of-RHS-to-LHS. +// False for the alternative, add-RHS-as-one-element-to-LHS. +// +// Assumes (1) LHS has already been confirmed as a vector, (2) the +// "LHS += RHS" expression has been type-checked. + +static bool is_element_wise_vector_append(const TypePtr& lhs, const TypePtr& rhs) { + if ( ! IsVector(rhs->Tag()) ) + // Can't be add-every-element since RHS isn't even a vector. + return false; + + if ( ! same_type(lhs, rhs) ) + // Can't be add-every-element since they're different types of vectors. + return false; + + if ( lhs->Yield()->Tag() != TYPE_VECTOR ) + // LHS is not a vector-of-vector, and RHS is a vector, so + // clearly we're doing element-wise-append. + return true; + + if ( rhs->AsVectorType()->IsUnspecifiedVector() ) + // This is a vector-of-vector-of-X += vector() construct. + // It is *not* treated as element-wise-append of an empty RHS, + // instead append an empty vector to the LHS. + return false; + + // RHS is a compatible element-wise-append vector for LHS. + return true; +} + AddToExpr::AddToExpr(ExprPtr arg_op1, ExprPtr arg_op2) : BinaryExpr(EXPR_ADD_TO, std::move(arg_op1), std::move(arg_op2)) { if ( IsError() ) @@ -1426,9 +1456,9 @@ AddToExpr::AddToExpr(ExprPtr arg_op1, ExprPtr arg_op2) } else if ( IsVector(bt1) ) { - // We need the IsVector(bt2) check in the following because - // same_type() always treats "any" types as "same". - if ( IsVector(bt2) && same_type(t1, t2) ) { + // Treat += of two vectors as appending each element + // of the RHS to the LHS if types agree. + if ( is_element_wise_vector_append(t1, t2) ) { SetType(t1); return; } diff --git a/src/Expr.h b/src/Expr.h index bc0990fcec..fc22549ca3 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -713,6 +713,8 @@ public: AddToExpr(ExprPtr op1, ExprPtr op2); ValPtr Eval(Frame* f) const override; + bool IsVectorElemAppend() const { return is_vector_elem_append; } + // Optimization-related: bool IsPure() const override { return false; } ExprPtr Duplicate() override; diff --git a/src/script_opt/CPP/Exprs.cc b/src/script_opt/CPP/Exprs.cc index ed0591ed51..9cdcba0a44 100644 --- a/src/script_opt/CPP/Exprs.cc +++ b/src/script_opt/CPP/Exprs.cc @@ -481,10 +481,10 @@ string CPPCompile::GenAddToExpr(const Expr* e, GenType gt, bool top_level) { if ( t->Tag() == TYPE_VECTOR ) { auto& rt = rhs->GetType(); - if ( IsVector(rt->Tag()) && same_type(lhs->GetType(), rt) ) - add_to_func = "vector_vec_append__CPP"; - else + if ( e->AsAddToExpr()->IsVectorElemAppend() ) add_to_func = "vector_append__CPP"; + else + add_to_func = "vector_vec_append__CPP"; } else if ( t->Tag() == TYPE_PATTERN ) diff --git a/src/script_opt/Expr.cc b/src/script_opt/Expr.cc index fe6e789dc2..235a5f09f0 100644 --- a/src/script_opt/Expr.cc +++ b/src/script_opt/Expr.cc @@ -764,7 +764,7 @@ ExprPtr AddToExpr::Reduce(Reducer* c, StmtPtr& red_stmt) { red_stmt = MergeStmts(red_stmt1, red_stmt2); - if ( tag == TYPE_VECTOR && (! IsVector(op2->GetType()->Tag()) || ! same_type(t, op2->GetType())) ) { + if ( is_vector_elem_append ) { auto append = with_location_of(make_intrusive(op1->Duplicate(), op2), this); auto append_stmt = with_location_of(make_intrusive(append), this); diff --git a/testing/btest/Baseline/language.vector/out b/testing/btest/Baseline/language.vector/out index a4f4bce8d3..5887d66595 100644 --- a/testing/btest/Baseline/language.vector/out +++ b/testing/btest/Baseline/language.vector/out @@ -82,3 +82,5 @@ left shift (PASS) right shift (PASS) negative index (PASS) negative index (PASS) ++= of empty vector (PASS) ++= of empty vector (PASS) diff --git a/testing/btest/language/vector.zeek b/testing/btest/language/vector.zeek index d8bb810181..820261e0e8 100644 --- a/testing/btest/language/vector.zeek +++ b/testing/btest/language/vector.zeek @@ -229,4 +229,16 @@ event zeek_init() test_case( "negative index", v24[-1] == 5 ); test_case( "negative index", v24[-3] == 3 ); + # For a vector-of-vectors, += of an empty vector should append it as + # a single element, not all of its elements (= nothing gets appended). + local v25: vector of vector of count; + v25 += vector(); + test_case( "+= of empty vector", |v25| == 1 ); + + # OTOH, for a vector-of-non-vectors, it should result in appending + # nothing. + local v26: vector of set[count]; + v26 += vector(); + test_case( "+= of empty vector", |v26| == 0 ); + }