unify functionality across EqExpr and RelExpr classes

This commit is contained in:
Vern Paxson 2024-04-08 18:19:27 -04:00 committed by Tim Wojtulewicz
parent 1b838ca91d
commit d15d4a6e08
3 changed files with 138 additions and 48 deletions

View file

@ -1872,13 +1872,46 @@ BitExpr::BitExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2)
ExprError("requires \"count\" or compatible \"set\" operands");
}
EqExpr::EqExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2)
: BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) {
CmpExpr::CmpExpr(ExprTag tag, ExprPtr _op1, ExprPtr _op2) : BinaryExpr(tag, std::move(_op1), std::move(_op2)) {
if ( IsError() )
return;
Canonicalize();
if ( is_vector(op1) )
SetType(make_intrusive<VectorType>(base_type(TYPE_BOOL)));
else
SetType(base_type(TYPE_BOOL));
}
void CmpExpr::Canonicalize() {
if ( tag == EXPR_EQ || tag == EXPR_NE ) {
if ( op2->GetType()->Tag() == TYPE_PATTERN )
SwapOps();
else if ( op1->GetType()->Tag() == TYPE_PATTERN )
;
else if ( expr_greater(op2.get(), op1.get()) )
SwapOps();
}
else if ( tag == EXPR_GT ) {
SwapOps();
tag = EXPR_LT;
}
else if ( tag == EXPR_GE ) {
SwapOps();
tag = EXPR_LE;
}
}
EqExpr::EqExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2)
: CmpExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) {
if ( IsError() )
return;
const auto& t1 = op1->GetType();
const auto& t2 = op2->GetType();
@ -1886,11 +1919,6 @@ EqExpr::EqExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2)
if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) )
return;
if ( is_vector(op1) )
SetType(make_intrusive<VectorType>(base_type(TYPE_BOOL)));
else
SetType(base_type(TYPE_BOOL));
if ( BothArithmetic(bt1, bt2) )
PromoteOps(max_type(bt1, bt2));
@ -1936,17 +1964,6 @@ EqExpr::EqExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2)
ExprError("type clash in comparison");
}
void EqExpr::Canonicalize() {
if ( op2->GetType()->Tag() == TYPE_PATTERN )
SwapOps();
else if ( op1->GetType()->Tag() == TYPE_PATTERN )
;
else if ( expr_greater(op2.get(), op1.get()) )
SwapOps();
}
ValPtr EqExpr::Fold(Val* v1, Val* v2) const {
if ( op1->GetType()->Tag() == TYPE_PATTERN ) {
auto re = v1->As<PatternVal*>();
@ -1971,12 +1988,10 @@ bool EqExpr::InvertSense() {
}
RelExpr::RelExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2)
: BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) {
: CmpExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) {
if ( IsError() )
return;
Canonicalize();
const auto& t1 = op1->GetType();
const auto& t2 = op2->GetType();
@ -1984,11 +1999,6 @@ RelExpr::RelExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2)
if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) )
return;
if ( is_vector(op1) )
SetType(make_intrusive<VectorType>(base_type(TYPE_BOOL)));
else
SetType(base_type(TYPE_BOOL));
if ( BothArithmetic(bt1, bt2) )
PromoteOps(max_type(bt1, bt2));
@ -2004,18 +2014,6 @@ RelExpr::RelExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2)
ExprError("illegal comparison");
}
void RelExpr::Canonicalize() {
if ( tag == EXPR_GT ) {
SwapOps();
tag = EXPR_LT;
}
else if ( tag == EXPR_GE ) {
SwapOps();
tag = EXPR_LE;
}
}
bool RelExpr::InvertSense() {
switch ( tag ) {
case EXPR_LT: tag = EXPR_GE; return true;

View file

@ -822,14 +822,28 @@ public:
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
};
class EqExpr final : public BinaryExpr {
// Intermediary class for comparison operators. Not directly instantiated.
class CmpExpr : public BinaryExpr {
protected:
CmpExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
void Canonicalize() override;
bool WillTransform(Reducer* c) const override;
bool WillTransformInConditional(Reducer* c) const override;
bool IsReduced(Reducer* c) const override;
ExprPtr TransformToConditional(Reducer* c, StmtPtr& red_stmt) override;
bool IsHasElementsTest() const;
ExprPtr BuildHasElementsTest() const;
};
class EqExpr final : public CmpExpr {
public:
EqExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
void Canonicalize() override;
// Optimization-related:
ExprPtr Duplicate() override;
bool WillTransform(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
bool InvertSense() override;
@ -837,14 +851,12 @@ protected:
ValPtr Fold(Val* v1, Val* v2) const override;
};
class RelExpr final : public BinaryExpr {
class RelExpr final : public CmpExpr {
public:
RelExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
void Canonicalize() override;
// Optimization-related:
ExprPtr Duplicate() override;
bool WillTransform(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
bool InvertSense() override;
};

View file

@ -1144,15 +1144,94 @@ ExprPtr BitExpr::Reduce(Reducer* c, StmtPtr& red_stmt) {
return BinaryExpr::Reduce(c, red_stmt);
}
bool CmpExpr::WillTransform(Reducer* c) const {
if ( IsHasElementsTest() )
return true;
return GetType()->Tag() == TYPE_BOOL && same_singletons(op1, op2);
}
bool CmpExpr::WillTransformInConditional(Reducer* c) const { return WillTransform(c); }
bool CmpExpr::IsReduced(Reducer* c) const {
if ( IsHasElementsTest() )
return NonReduced(this);
return true;
}
static std::map<ExprTag, ExprTag> has_elements_swap_tag = {
{EXPR_EQ, EXPR_EQ}, {EXPR_NE, EXPR_NE}, {EXPR_LT, EXPR_GT},
{EXPR_LE, EXPR_GE}, {EXPR_GE, EXPR_LE}, {EXPR_GT, EXPR_LT},
};
bool CmpExpr::IsHasElementsTest() const {
static std::set<ExprTag> rel_tags = {EXPR_EQ, EXPR_NE, EXPR_LT, EXPR_LE, EXPR_GE, EXPR_GT};
auto t = Tag(); // note, we may invert t below
if ( rel_tags.count(t) == 0 )
return false;
auto op1 = GetOp1();
auto op2 = GetOp2();
ASSERT(op1 && op2);
if ( op1->Tag() != EXPR_SIZE && op2->Tag() != EXPR_SIZE )
return false;
if ( ! op1->IsZero() && ! op1->IsOne() && ! op2->IsZero() && ! op2->IsOne() )
return false;
if ( op1->Tag() == EXPR_CONST ) {
t = has_elements_swap_tag[t];
std::swap(op1, op2);
}
auto op1_t = op1->GetOp1()->GetType()->Tag();
if ( op1_t != TYPE_TABLE && op1_t != TYPE_VECTOR )
return false;
static std::map<ExprTag, bool> zero_req = {
{EXPR_EQ, true}, {EXPR_NE, true}, {EXPR_LT, false}, {EXPR_LE, true}, {EXPR_GE, false}, {EXPR_GT, true},
};
return zero_req[t] ? op2->IsZero() : op2->IsOne();
}
ExprPtr CmpExpr::TransformToConditional(Reducer* c, StmtPtr& red_stmt) { return BuildHasElementsTest(); }
ExprPtr CmpExpr::BuildHasElementsTest() const {
auto t = Tag();
auto op1 = GetOp1();
auto op2 = GetOp2();
if ( op1->Tag() == EXPR_CONST ) {
t = has_elements_swap_tag[t];
std::swap(op1, op2);
}
ExprPtr he =
with_location_of(make_intrusive<ScriptOptBuiltinExpr>(ScriptOptBuiltinExpr::HAS_ELEMENTS, op1->GetOp1()), this);
static std::map<ExprTag, bool> has_elements = {
{EXPR_EQ, false}, {EXPR_NE, true}, {EXPR_LT, false}, {EXPR_LE, false}, {EXPR_GE, true}, {EXPR_GT, true},
};
if ( ! has_elements[t] )
he = with_location_of(make_intrusive<NotExpr>(he), this);
return he;
}
ExprPtr EqExpr::Duplicate() {
auto op1_d = op1->Duplicate();
auto op2_d = op2->Duplicate();
return SetSucc(new EqExpr(tag, op1_d, op2_d));
}
bool EqExpr::WillTransform(Reducer* c) const { return GetType()->Tag() == TYPE_BOOL && same_singletons(op1, op2); }
ExprPtr EqExpr::Reduce(Reducer* c, StmtPtr& red_stmt) {
if ( IsHasElementsTest() )
return BuildHasElementsTest()->Reduce(c, red_stmt);
if ( GetType()->Tag() == TYPE_BOOL && same_singletons(op1, op2) ) {
bool t = Tag() == EXPR_EQ;
auto res = with_location_of(make_intrusive<ConstExpr>(val_mgr->Bool(t)), this);
@ -1168,9 +1247,10 @@ ExprPtr RelExpr::Duplicate() {
return SetSucc(new RelExpr(tag, op1_d, op2_d));
}
bool RelExpr::WillTransform(Reducer* c) const { return GetType()->Tag() == TYPE_BOOL && same_singletons(op1, op2); }
ExprPtr RelExpr::Reduce(Reducer* c, StmtPtr& red_stmt) {
if ( IsHasElementsTest() )
return BuildHasElementsTest()->Reduce(c, red_stmt);
if ( GetType()->Tag() == TYPE_BOOL ) {
if ( same_singletons(op1, op2) ) {
bool t = Tag() == EXPR_GE || Tag() == EXPR_LE;