diff --git a/src/Expr.cc b/src/Expr.cc index 8cd3bb5200..951fe0fe1b 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -1872,13 +1872,46 @@ BitExpr::BitExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) ExprError("requires \"count\" or compatible \"set\" operands"); } -EqExpr::EqExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) { +CmpExpr::CmpExpr(ExprTag tag, ExprPtr _op1, ExprPtr _op2) : BinaryExpr(tag, std::move(_op1), std::move(_op2)) { if ( IsError() ) return; Canonicalize(); + if ( is_vector(op1) ) + SetType(make_intrusive(base_type(TYPE_BOOL))); + else + SetType(base_type(TYPE_BOOL)); +} + +void CmpExpr::Canonicalize() { + if ( tag == EXPR_EQ || tag == EXPR_NE ) { + if ( op2->GetType()->Tag() == TYPE_PATTERN ) + SwapOps(); + + else if ( op1->GetType()->Tag() == TYPE_PATTERN ) + ; + + else if ( expr_greater(op2.get(), op1.get()) ) + SwapOps(); + } + + else if ( tag == EXPR_GT ) { + SwapOps(); + tag = EXPR_LT; + } + + else if ( tag == EXPR_GE ) { + SwapOps(); + tag = EXPR_LE; + } +} + +EqExpr::EqExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) + : CmpExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; + const auto& t1 = op1->GetType(); const auto& t2 = op2->GetType(); @@ -1886,11 +1919,6 @@ EqExpr::EqExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) return; - if ( is_vector(op1) ) - SetType(make_intrusive(base_type(TYPE_BOOL))); - else - SetType(base_type(TYPE_BOOL)); - if ( BothArithmetic(bt1, bt2) ) PromoteOps(max_type(bt1, bt2)); @@ -1936,17 +1964,6 @@ EqExpr::EqExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) ExprError("type clash in comparison"); } -void EqExpr::Canonicalize() { - if ( op2->GetType()->Tag() == TYPE_PATTERN ) - SwapOps(); - - else if ( op1->GetType()->Tag() == TYPE_PATTERN ) - ; - - else if ( expr_greater(op2.get(), op1.get()) ) - SwapOps(); -} - ValPtr EqExpr::Fold(Val* v1, Val* v2) const { if ( op1->GetType()->Tag() == TYPE_PATTERN ) { auto re = v1->As(); @@ -1971,12 +1988,10 @@ bool EqExpr::InvertSense() { } RelExpr::RelExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) { + : CmpExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) { if ( IsError() ) return; - Canonicalize(); - const auto& t1 = op1->GetType(); const auto& t2 = op2->GetType(); @@ -1984,11 +1999,6 @@ RelExpr::RelExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) return; - if ( is_vector(op1) ) - SetType(make_intrusive(base_type(TYPE_BOOL))); - else - SetType(base_type(TYPE_BOOL)); - if ( BothArithmetic(bt1, bt2) ) PromoteOps(max_type(bt1, bt2)); @@ -2004,18 +2014,6 @@ RelExpr::RelExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) ExprError("illegal comparison"); } -void RelExpr::Canonicalize() { - if ( tag == EXPR_GT ) { - SwapOps(); - tag = EXPR_LT; - } - - else if ( tag == EXPR_GE ) { - SwapOps(); - tag = EXPR_LE; - } -} - bool RelExpr::InvertSense() { switch ( tag ) { case EXPR_LT: tag = EXPR_GE; return true; diff --git a/src/Expr.h b/src/Expr.h index 2f9ee0d722..5e8abbab4a 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -822,14 +822,28 @@ public: ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; }; -class EqExpr final : public BinaryExpr { +// Intermediary class for comparison operators. Not directly instantiated. +class CmpExpr : public BinaryExpr { +protected: + CmpExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); + + void Canonicalize() override; + + bool WillTransform(Reducer* c) const override; + bool WillTransformInConditional(Reducer* c) const override; + bool IsReduced(Reducer* c) const override; + ExprPtr TransformToConditional(Reducer* c, StmtPtr& red_stmt) override; + + bool IsHasElementsTest() const; + ExprPtr BuildHasElementsTest() const; +}; + +class EqExpr final : public CmpExpr { public: EqExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); - void Canonicalize() override; // Optimization-related: ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; bool InvertSense() override; @@ -837,14 +851,12 @@ protected: ValPtr Fold(Val* v1, Val* v2) const override; }; -class RelExpr final : public BinaryExpr { +class RelExpr final : public CmpExpr { public: RelExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); - void Canonicalize() override; // Optimization-related: ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; bool InvertSense() override; }; diff --git a/src/script_opt/Expr.cc b/src/script_opt/Expr.cc index e552d898e3..0a917d5106 100644 --- a/src/script_opt/Expr.cc +++ b/src/script_opt/Expr.cc @@ -1144,15 +1144,94 @@ ExprPtr BitExpr::Reduce(Reducer* c, StmtPtr& red_stmt) { return BinaryExpr::Reduce(c, red_stmt); } +bool CmpExpr::WillTransform(Reducer* c) const { + if ( IsHasElementsTest() ) + return true; + return GetType()->Tag() == TYPE_BOOL && same_singletons(op1, op2); +} + +bool CmpExpr::WillTransformInConditional(Reducer* c) const { return WillTransform(c); } + +bool CmpExpr::IsReduced(Reducer* c) const { + if ( IsHasElementsTest() ) + return NonReduced(this); + return true; +} + +static std::map has_elements_swap_tag = { + {EXPR_EQ, EXPR_EQ}, {EXPR_NE, EXPR_NE}, {EXPR_LT, EXPR_GT}, + {EXPR_LE, EXPR_GE}, {EXPR_GE, EXPR_LE}, {EXPR_GT, EXPR_LT}, +}; + +bool CmpExpr::IsHasElementsTest() const { + static std::set rel_tags = {EXPR_EQ, EXPR_NE, EXPR_LT, EXPR_LE, EXPR_GE, EXPR_GT}; + + auto t = Tag(); // note, we may invert t below + if ( rel_tags.count(t) == 0 ) + return false; + + auto op1 = GetOp1(); + auto op2 = GetOp2(); + + ASSERT(op1 && op2); + + if ( op1->Tag() != EXPR_SIZE && op2->Tag() != EXPR_SIZE ) + return false; + + if ( ! op1->IsZero() && ! op1->IsOne() && ! op2->IsZero() && ! op2->IsOne() ) + return false; + + if ( op1->Tag() == EXPR_CONST ) { + t = has_elements_swap_tag[t]; + std::swap(op1, op2); + } + + auto op1_t = op1->GetOp1()->GetType()->Tag(); + if ( op1_t != TYPE_TABLE && op1_t != TYPE_VECTOR ) + return false; + + static std::map zero_req = { + {EXPR_EQ, true}, {EXPR_NE, true}, {EXPR_LT, false}, {EXPR_LE, true}, {EXPR_GE, false}, {EXPR_GT, true}, + }; + + return zero_req[t] ? op2->IsZero() : op2->IsOne(); +} + +ExprPtr CmpExpr::TransformToConditional(Reducer* c, StmtPtr& red_stmt) { return BuildHasElementsTest(); } + +ExprPtr CmpExpr::BuildHasElementsTest() const { + auto t = Tag(); + auto op1 = GetOp1(); + auto op2 = GetOp2(); + + if ( op1->Tag() == EXPR_CONST ) { + t = has_elements_swap_tag[t]; + std::swap(op1, op2); + } + + ExprPtr he = + with_location_of(make_intrusive(ScriptOptBuiltinExpr::HAS_ELEMENTS, op1->GetOp1()), this); + + static std::map has_elements = { + {EXPR_EQ, false}, {EXPR_NE, true}, {EXPR_LT, false}, {EXPR_LE, false}, {EXPR_GE, true}, {EXPR_GT, true}, + }; + + if ( ! has_elements[t] ) + he = with_location_of(make_intrusive(he), this); + + return he; +} + ExprPtr EqExpr::Duplicate() { auto op1_d = op1->Duplicate(); auto op2_d = op2->Duplicate(); return SetSucc(new EqExpr(tag, op1_d, op2_d)); } -bool EqExpr::WillTransform(Reducer* c) const { return GetType()->Tag() == TYPE_BOOL && same_singletons(op1, op2); } - ExprPtr EqExpr::Reduce(Reducer* c, StmtPtr& red_stmt) { + if ( IsHasElementsTest() ) + return BuildHasElementsTest()->Reduce(c, red_stmt); + if ( GetType()->Tag() == TYPE_BOOL && same_singletons(op1, op2) ) { bool t = Tag() == EXPR_EQ; auto res = with_location_of(make_intrusive(val_mgr->Bool(t)), this); @@ -1168,9 +1247,10 @@ ExprPtr RelExpr::Duplicate() { return SetSucc(new RelExpr(tag, op1_d, op2_d)); } -bool RelExpr::WillTransform(Reducer* c) const { return GetType()->Tag() == TYPE_BOOL && same_singletons(op1, op2); } - ExprPtr RelExpr::Reduce(Reducer* c, StmtPtr& red_stmt) { + if ( IsHasElementsTest() ) + return BuildHasElementsTest()->Reduce(c, red_stmt); + if ( GetType()->Tag() == TYPE_BOOL ) { if ( same_singletons(op1, op2) ) { bool t = Tag() == EXPR_GE || Tag() == EXPR_LE;