|
6 | 6 | //
|
7 | 7 | //===----------------------------------------------------------------------===//
|
8 | 8 |
|
| 9 | +#include <cstdint> |
| 10 | +#include <limits> |
9 | 11 | #include <utility>
|
10 | 12 |
|
11 | 13 | #include "AffineExprDetail.h"
|
@@ -645,10 +647,14 @@ mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
|
645 | 647 | static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
|
646 | 648 | auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
647 | 649 | auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
648 |
| - // Fold if both LHS, RHS are a constant. |
649 |
| - if (lhsConst && rhsConst) |
650 |
| - return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), |
651 |
| - lhs.getContext()); |
| 650 | + // Fold if both LHS, RHS are a constant and the sum does not overflow. |
| 651 | + if (lhsConst && rhsConst) { |
| 652 | + int64_t sum; |
| 653 | + if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) { |
| 654 | + return nullptr; |
| 655 | + } |
| 656 | + return getAffineConstantExpr(sum, lhs.getContext()); |
| 657 | + } |
652 | 658 |
|
653 | 659 | // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
|
654 | 660 | // If only one of them is a symbolic expressions, make it the RHS.
|
@@ -774,9 +780,13 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
|
774 | 780 | auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
|
775 | 781 | auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
|
776 | 782 |
|
777 |
| - if (lhsConst && rhsConst) |
778 |
| - return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(), |
779 |
| - lhs.getContext()); |
| 783 | + if (lhsConst && rhsConst) { |
| 784 | + int64_t product; |
| 785 | + if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) { |
| 786 | + return nullptr; |
| 787 | + } |
| 788 | + return getAffineConstantExpr(product, lhs.getContext()); |
| 789 | + } |
780 | 790 |
|
781 | 791 | if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
|
782 | 792 | return nullptr;
|
@@ -849,10 +859,16 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
|
849 | 859 | if (!rhsConst || rhsConst.getValue() < 1)
|
850 | 860 | return nullptr;
|
851 | 861 |
|
852 |
| - if (lhsConst) |
| 862 | + if (lhsConst) { |
| 863 | + // divideFloorSigned can only overflow in this case: |
| 864 | + if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() && |
| 865 | + rhsConst.getValue() == -1) { |
| 866 | + return nullptr; |
| 867 | + } |
853 | 868 | return getAffineConstantExpr(
|
854 | 869 | divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
|
855 | 870 | lhs.getContext());
|
| 871 | + } |
856 | 872 |
|
857 | 873 | // Fold floordiv of a multiply with a constant that is a multiple of the
|
858 | 874 | // divisor. Eg: (i * 128) floordiv 64 = i * 2.
|
@@ -905,10 +921,16 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
|
905 | 921 | if (!rhsConst || rhsConst.getValue() < 1)
|
906 | 922 | return nullptr;
|
907 | 923 |
|
908 |
| - if (lhsConst) |
| 924 | + if (lhsConst) { |
| 925 | + // divideCeilSigned can only overflow in this case: |
| 926 | + if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() && |
| 927 | + rhsConst.getValue() == -1) { |
| 928 | + return nullptr; |
| 929 | + } |
909 | 930 | return getAffineConstantExpr(
|
910 | 931 | divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
|
911 | 932 | lhs.getContext());
|
| 933 | + } |
912 | 934 |
|
913 | 935 | // Fold ceildiv of a multiply with a constant that is a multiple of the
|
914 | 936 | // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
|
@@ -950,9 +972,11 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
|
950 | 972 | if (!rhsConst || rhsConst.getValue() < 1)
|
951 | 973 | return nullptr;
|
952 | 974 |
|
953 |
| - if (lhsConst) |
| 975 | + if (lhsConst) { |
| 976 | + // mod never overflows. |
954 | 977 | return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
|
955 | 978 | lhs.getContext());
|
| 979 | + } |
956 | 980 |
|
957 | 981 | // Fold modulo of an expression that is known to be a multiple of a constant
|
958 | 982 | // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
|
|
0 commit comments