Skip to content

Commit 5221634

Browse files
authored
Do not trigger UB during AffineExpr parsing. (#96896)
Currently, parsing expressions that are undefined will trigger UB during compilation (e.g. `9223372036854775807 * 2`). This change instead leaves the expressions as they were. This change is an NFC for compilations that did not previously involve UB.
1 parent fcffb2c commit 5221634

File tree

3 files changed

+84
-12
lines changed

3 files changed

+84
-12
lines changed

llvm/include/llvm/Support/MathExtras.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,8 @@ inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
435435
}
436436

437437
/// Returns the integer ceil(Numerator / Denominator). Signed version.
438-
/// Guaranteed to never overflow.
438+
/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
439+
/// is -1.
439440
inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
440441
assert(Denominator && "Division by zero");
441442
if (!Numerator)
@@ -448,7 +449,8 @@ inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
448449
}
449450

450451
/// Returns the integer floor(Numerator / Denominator). Signed version.
451-
/// Guaranteed to never overflow.
452+
/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
453+
/// is -1.
452454
inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
453455
assert(Denominator && "Division by zero");
454456
if (!Numerator)

mlir/lib/IR/AffineExpr.cpp

+34-10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <cstdint>
10+
#include <limits>
911
#include <utility>
1012

1113
#include "AffineExprDetail.h"
@@ -645,10 +647,14 @@ mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
645647
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
646648
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
647649
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
648-
// Fold if both LHS, RHS are a constant.
649-
if (lhsConst && rhsConst)
650-
return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
651-
lhs.getContext());
650+
// Fold if both LHS, RHS are a constant and the sum does not overflow.
651+
if (lhsConst && rhsConst) {
652+
int64_t sum;
653+
if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
654+
return nullptr;
655+
}
656+
return getAffineConstantExpr(sum, lhs.getContext());
657+
}
652658

653659
// Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
654660
// If only one of them is a symbolic expressions, make it the RHS.
@@ -774,9 +780,13 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
774780
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
775781
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
776782

777-
if (lhsConst && rhsConst)
778-
return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
779-
lhs.getContext());
783+
if (lhsConst && rhsConst) {
784+
int64_t product;
785+
if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) {
786+
return nullptr;
787+
}
788+
return getAffineConstantExpr(product, lhs.getContext());
789+
}
780790

781791
if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
782792
return nullptr;
@@ -849,10 +859,16 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
849859
if (!rhsConst || rhsConst.getValue() < 1)
850860
return nullptr;
851861

852-
if (lhsConst)
862+
if (lhsConst) {
863+
// divideFloorSigned can only overflow in this case:
864+
if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() &&
865+
rhsConst.getValue() == -1) {
866+
return nullptr;
867+
}
853868
return getAffineConstantExpr(
854869
divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
855870
lhs.getContext());
871+
}
856872

857873
// Fold floordiv of a multiply with a constant that is a multiple of the
858874
// divisor. Eg: (i * 128) floordiv 64 = i * 2.
@@ -905,10 +921,16 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
905921
if (!rhsConst || rhsConst.getValue() < 1)
906922
return nullptr;
907923

908-
if (lhsConst)
924+
if (lhsConst) {
925+
// divideCeilSigned can only overflow in this case:
926+
if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() &&
927+
rhsConst.getValue() == -1) {
928+
return nullptr;
929+
}
909930
return getAffineConstantExpr(
910931
divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
911932
lhs.getContext());
933+
}
912934

913935
// Fold ceildiv of a multiply with a constant that is a multiple of the
914936
// divisor. Eg: (i * 128) ceildiv 64 = i * 2.
@@ -950,9 +972,11 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
950972
if (!rhsConst || rhsConst.getValue() < 1)
951973
return nullptr;
952974

953-
if (lhsConst)
975+
if (lhsConst) {
976+
// mod never overflows.
954977
return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
955978
lhs.getContext());
979+
}
956980

957981
// Fold modulo of an expression that is known to be a multiple of a constant
958982
// to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)

mlir/unittests/IR/AffineExprTest.cpp

+46
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <cstdint>
10+
#include <limits>
11+
912
#include "mlir/IR/AffineExpr.h"
1013
#include "mlir/IR/Builders.h"
1114
#include "gtest/gtest.h"
@@ -30,3 +33,46 @@ TEST(AffineExprTest, constructFromBinaryOperators) {
3033
ASSERT_EQ(product.getKind(), AffineExprKind::Mul);
3134
ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod);
3235
}
36+
37+
TEST(AffineExprTest, constantFolding) {
38+
MLIRContext ctx;
39+
OpBuilder b(&ctx);
40+
auto cn1 = b.getAffineConstantExpr(-1);
41+
auto c0 = b.getAffineConstantExpr(0);
42+
auto c1 = b.getAffineConstantExpr(1);
43+
auto c2 = b.getAffineConstantExpr(2);
44+
auto c3 = b.getAffineConstantExpr(3);
45+
auto c6 = b.getAffineConstantExpr(6);
46+
auto cmax = b.getAffineConstantExpr(std::numeric_limits<int64_t>::max());
47+
auto cmin = b.getAffineConstantExpr(std::numeric_limits<int64_t>::min());
48+
49+
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Add, c1, c2), c3);
50+
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Mul, c2, c3), c6);
51+
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c2), c1);
52+
ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c2), c2);
53+
54+
// Test division by zero:
55+
auto c3ceildivc0 = getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c0);
56+
ASSERT_EQ(c3ceildivc0.getKind(), AffineExprKind::CeilDiv);
57+
58+
auto c3floordivc0 = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c0);
59+
ASSERT_EQ(c3floordivc0.getKind(), AffineExprKind::FloorDiv);
60+
61+
auto c3modc0 = getAffineBinaryOpExpr(AffineExprKind::Mod, c3, c0);
62+
ASSERT_EQ(c3modc0.getKind(), AffineExprKind::Mod);
63+
64+
// Test overflow:
65+
auto cmaxplusc1 = getAffineBinaryOpExpr(AffineExprKind::Add, cmax, c1);
66+
ASSERT_EQ(cmaxplusc1.getKind(), AffineExprKind::Add);
67+
68+
auto cmaxtimesc2 = getAffineBinaryOpExpr(AffineExprKind::Mul, cmax, c2);
69+
ASSERT_EQ(cmaxtimesc2.getKind(), AffineExprKind::Mul);
70+
71+
auto cminceildivcn1 =
72+
getAffineBinaryOpExpr(AffineExprKind::CeilDiv, cmin, cn1);
73+
ASSERT_EQ(cminceildivcn1.getKind(), AffineExprKind::CeilDiv);
74+
75+
auto cminfloordivcn1 =
76+
getAffineBinaryOpExpr(AffineExprKind::FloorDiv, cmin, cn1);
77+
ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
78+
}

0 commit comments

Comments
 (0)