introduce internal "mask" expression node to avoid mixed-type overloading of "/"

This commit is contained in:
Vern Paxson 2023-09-26 14:39:26 -07:00
parent b53a025b1e
commit 434a7e059d
6 changed files with 55 additions and 13 deletions

View file

@ -48,6 +48,7 @@ const char* expr_name(ExprTag t)
"-=",
"*",
"/",
"/", // mask operator
"%",
"&",
"|",
@ -1918,14 +1919,27 @@ DivideExpr::DivideExpr(ExprPtr arg_op1, ExprPtr arg_op2)
else if ( BothArithmetic(bt1, bt2) )
PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2));
else if ( bt1 == TYPE_ADDR && ! is_vector(op2) && (bt2 == TYPE_COUNT || bt2 == TYPE_INT) )
SetType(base_type(TYPE_SUBNET));
else
ExprError("requires arithmetic operands");
}
ValPtr DivideExpr::AddrFold(Val* v1, Val* v2) const
MaskExpr::MaskExpr(ExprPtr arg_op1, ExprPtr arg_op2)
: BinaryExpr(EXPR_MASK, std::move(arg_op1), std::move(arg_op2))
{
if ( IsError() )
return;
TypeTag bt1, bt2;
if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) )
return;
if ( bt1 == TYPE_ADDR && ! is_vector(op2) && (bt2 == TYPE_COUNT || bt2 == TYPE_INT) )
SetType(base_type(TYPE_SUBNET));
else
ExprError("requires address LHS and count/int RHS");
}
ValPtr MaskExpr::AddrFold(Val* v1, Val* v2) const
{
uint32_t mask;

View file

@ -52,6 +52,7 @@ enum ExprTag : int
EXPR_REMOVE_FROM,
EXPR_TIMES,
EXPR_DIVIDE,
EXPR_MASK,
EXPR_MOD,
EXPR_AND,
EXPR_OR,
@ -819,6 +820,15 @@ public:
ExprPtr Duplicate() override;
bool WillTransform(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
};
class MaskExpr final : public BinaryExpr
{
public:
MaskExpr(ExprPtr op1, ExprPtr op2);
// Optimization-related:
ExprPtr Duplicate() override;
protected:
ValPtr AddrFold(Val* v1, Val* v2) const override;

View file

@ -633,7 +633,10 @@ expr:
| expr '/' expr
{
set_location(@1, @3);
$$ = new DivideExpr({AdoptRef{}, $1}, {AdoptRef{}, $3});
if ( $1->GetType()->Tag() == TYPE_ADDR )
$$ = new MaskExpr({AdoptRef{}, $1}, {AdoptRef{}, $3});
else
$$ = new DivideExpr({AdoptRef{}, $1}, {AdoptRef{}, $3});
}
| expr '%' expr

View file

@ -80,6 +80,7 @@ string CPPCompile::GenExpr(const Expr* e, GenType gt, bool top_level)
case EXPR_TIMES:
return GenBinary(e, gt, "*", "mul");
case EXPR_DIVIDE:
case EXPR_MASK: // later code will split into addr masking
return GenBinary(e, gt, "/", "div");
case EXPR_MOD:
return GenBinary(e, gt, "%", "mod");
@ -1060,7 +1061,7 @@ string CPPCompile::GenBinaryAddr(const Expr* e, GenType gt, const char* op)
{
auto v1 = GenExpr(e->GetOp1(), GEN_DONT_CARE) + "->AsAddr()";
if ( e->Tag() == EXPR_DIVIDE )
if ( e->Tag() == EXPR_MASK )
{
auto gen = string("addr_mask__CPP(") + v1 + ", " + GenExpr(e->GetOp2(), GEN_NATIVE) + ")";

View file

@ -253,6 +253,7 @@ bool Expr::IsFieldAssignable(const Expr* e) const
case EXPR_SUB:
case EXPR_TIMES:
case EXPR_DIVIDE:
case EXPR_MASK:
case EXPR_MOD:
case EXPR_AND:
case EXPR_OR:
@ -1091,20 +1092,24 @@ ExprPtr DivideExpr::Duplicate()
bool DivideExpr::WillTransform(Reducer* c) const
{
return GetType()->Tag() != TYPE_SUBNET && op2->IsOne();
return op2->IsOne();
}
ExprPtr DivideExpr::Reduce(Reducer* c, StmtPtr& red_stmt)
{
if ( GetType()->Tag() != TYPE_SUBNET )
{
if ( op2->IsOne() )
return op1->ReduceToSingleton(c, red_stmt);
}
if ( op2->IsOne() )
return op1->ReduceToSingleton(c, red_stmt);
return BinaryExpr::Reduce(c, red_stmt);
}
ExprPtr MaskExpr::Duplicate()
{
auto op1_d = op1->Duplicate();
auto op2_d = op2->Duplicate();
return SetSucc(new MaskExpr(op1_d, op2_d));
}
ExprPtr ModExpr::Duplicate()
{
auto op1_d = op1->Duplicate();

View file

@ -490,7 +490,16 @@ eval-pre if ( $2 == 0 )
break;
}
eval $1 / $2
#
binary-expr-op Mask
op-type I
vector
### Note that this first "eval" is a dummy - we'll never generate code
### that uses it because "Mask" expressions don't have LHS operands of
### type "int". We could omit this if we modified Gen-ZAM to understand
### that an op-type of 'X' for a binary-expr-op means "skip the usual case
### of two operands of the same type".
eval $1 / $2
eval-mixed A I auto mask = static_cast<uint32_t>($2);
auto a = $1->AsAddr();
if ( a.GetFamily() == IPv4 && mask > 32 )