[WIP][DAG] Introduce generic shl_add node [NFC] #88791

Closed
wants to merge 2 commits
Changes from all commits
7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -710,6 +710,13 @@ enum NodeType {
FSHL,
FSHR,

// Represents (ADD (SHL a, b), c) with the arguments appearing in the order
// a, b, c. 'b' must be a constant, and follows the rules for shift amount
// types described just above. This is used solely post-legalization when
// lowering MUL to target-specific instructions - e.g. LEA on x86 or
// sh1add/sh2add/sh3add on RISCV.

Collaborator:
RISCV -> RISC-V

SHL_ADD,

/// Byte Swap and Counting operators.
BSWAP,
CTTZ,
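
To make the intended semantics concrete, here is a minimal sketch of how a target might build the node once types are legal. The helper name is hypothetical; only ISD::SHL_ADD and the standard SelectionDAG calls are taken from the patch.

#include "llvm/CodeGen/SelectionDAG.h"
using namespace llvm;

// Hypothetical helper: builds shl_add(a, b, c) == (a << b) + c, where the
// shift amount b must be a constant and VT is already legal.
static SDValue buildShlAdd(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
                           SDValue A, unsigned ShAmt, SDValue C) {
  assert(DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
         "SHL_ADD is intended to be created only post-legalization");
  // Operand order is a, b, c: value to shift, constant shift amount, addend.
  return DAG.getNode(ISD::SHL_ADD, DL, VT, A,
                     DAG.getConstant(ShAmt, DL, VT), C);
}

The assertions mirror the checks added to SelectionDAG::getNode further down in the patch.
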
5 changes: 5 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -121,6 +121,10 @@ def SDTIntShiftOp : SDTypeProfile<1, 2, [ // shl, sra, srl
def SDTIntShiftDOp: SDTypeProfile<1, 3, [ // fshl, fshr
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<3>
]>;
def SDTIntShiftAddOp : SDTypeProfile<1, 3, [ // shl_add
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 3>, SDTCisInt<0>, SDTCisInt<2>,
SDTCisInt<3>
]>;
def SDTIntSatNoShOp : SDTypeProfile<1, 2, [ // ssat with no shift
SDTCisSameAs<0, 1>, SDTCisInt<2>
]>;
@@ -411,6 +415,7 @@ def rotl : SDNode<"ISD::ROTL" , SDTIntShiftOp>;
def rotr : SDNode<"ISD::ROTR" , SDTIntShiftOp>;
def fshl : SDNode<"ISD::FSHL" , SDTIntShiftDOp>;
def fshr : SDNode<"ISD::FSHR" , SDTIntShiftDOp>;
def shl_add : SDNode<"ISD::SHL_ADD" , SDTIntShiftAddOp>;
def and : SDNode<"ISD::AND" , SDTIntBinOp,
[SDNPCommutative, SDNPAssociative]>;
def or : SDNode<"ISD::OR" , SDTIntBinOp,
12 changes: 12 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3521,6 +3521,13 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
Known = KnownBits::ashr(Known, Known2, /*ShAmtNonZero=*/false,
Op->getFlags().hasExact());
break;
case ISD::SHL_ADD:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
Known = KnownBits::computeForAddSub(
true, false, false, KnownBits::shl(Known, Known2),
computeKnownBits(Op.getOperand(2), DemandedElts, Depth + 1));
break;
case ISD::FSHL:
case ISD::FSHR:
if (ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(2), DemandedElts)) {
@@ -7346,6 +7353,11 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
if (N1.getValueType() == VT)
return N1;
break;
case ISD::SHL_ADD:
assert(VT == N1.getValueType() && VT == N3.getValueType());
assert(TLI->isTypeLegal(VT) && "Created only post legalize");
assert(isa<ConstantSDNode>(N2) && "Constant shift expected");
break;
}

// Memoize node if it doesn't produce a glue result.
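
As a sanity check on the known-bits handling above, the following standalone sketch mirrors the same KnownBits::shl plus computeForAddSub combination on an invented 4-bit example; the function name and values are assumptions for illustration.

#include "llvm/Support/KnownBits.h"
using namespace llvm;

// shl_add(a, 2, 4) where a has bit 0 known one (a = ???1): a << 2 has bits
// 1:0 known zero and bit 2 known one, and adding the constant 0100 leaves
// bits 2:0 known zero (the ones in bit 2 carry into the unknown bit 3).
KnownBits knownBitsForShlAddExample() {
  KnownBits A(4);
  A.One.setBit(0);
  KnownBits ShAmt = KnownBits::makeConstant(APInt(4, 2));
  KnownBits C = KnownBits::makeConstant(APInt(4, 4));
  return KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
                                     /*NUW=*/false, KnownBits::shl(A, ShAmt),
                                     C);
}
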
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -264,6 +264,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::SRL: return "srl";
case ISD::ROTL: return "rotl";
case ISD::ROTR: return "rotr";
case ISD::SHL_ADD: return "shl_add";
case ISD::FSHL: return "fshl";
case ISD::FSHR: return "fshr";
case ISD::FADD: return "fadd";
18 changes: 10 additions & 8 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12789,10 +12789,9 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
SDLoc DL(N);
SDValue NS = (C0 < C1) ? N0->getOperand(0) : N1->getOperand(0);
SDValue NL = (C0 > C1) ? N0->getOperand(0) : N1->getOperand(0);
SDValue NA0 =
DAG.getNode(ISD::SHL, DL, VT, NL, DAG.getConstant(Diff, DL, VT));
SDValue NA1 = DAG.getNode(ISD::ADD, DL, VT, NA0, NS);
return DAG.getNode(ISD::SHL, DL, VT, NA1, DAG.getConstant(Bits, DL, VT));
SDValue SHADD =
DAG.getNode(ISD::SHL_ADD, DL, VT, NL, DAG.getConstant(Diff, DL, VT), NS);

Collaborator:
Should we use TargetConstant if it's required to be a constant?

Collaborator (Author):
Fixed over in (#89263).

return DAG.getNode(ISD::SHL, DL, VT, SHADD, DAG.getConstant(Bits, DL, VT));
}
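
The rewrite above rests on a simple shift identity. A standalone check with assumed shift amounts C0 = 2 and C1 = 5 (so Bits = 2 and Diff = 3):

#include <cstdint>

// (x << 2) + (y << 5) == ((y << 3) + x) << 2, i.e. shl(shl_add(y, 3, x), 2),
// so the inner shift-and-add can be selected as a single sh3add on RISC-V.
bool checkTransformAddShlImmIdentity(uint64_t X, uint64_t Y) {
  uint64_t Original = (X << 2) + (Y << 5);
  uint64_t Rewritten = ((Y << 3) + X) << 2;
  return Original == Rewritten;
}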

// Combine a constant select operand into its use:
Expand Down Expand Up @@ -13028,14 +13027,17 @@ static SDValue combineAddOfBooleanXor(SDNode *N, SelectionDAG &DAG) {
N0.getOperand(0));
}

static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
static SDValue performADDCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
SelectionDAG &DAG = DCI.DAG;
if (SDValue V = combineAddOfBooleanXor(N, DAG))
return V;
if (SDValue V = transformAddImmMulImm(N, DAG, Subtarget))
return V;
if (SDValue V = transformAddShlImm(N, DAG, Subtarget))
return V;
if (!DCI.isBeforeLegalize())
if (SDValue V = transformAddShlImm(N, DAG, Subtarget))
return V;
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
@@ -15894,7 +15896,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return V;
if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
return V;
return performADDCombine(N, DAG, Subtarget);
return performADDCombine(N, DCI, Subtarget);
}
case ISD::SUB: {
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
5 changes: 5 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -678,6 +678,8 @@ foreach i = {1,2,3} in {
defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
def : Pat<(XLenVT (add_like_non_imm12 (shl GPR:$rs1, (XLenVT i)), GPR:$rs2)),
(shxadd GPR:$rs1, GPR:$rs2)>;
def : Pat<(XLenVT (shl_add GPR:$rs1, (XLenVT i), GPR:$rs2)),
(shxadd GPR:$rs1, GPR:$rs2)>;

defvar pat = !cast<ComplexPattern>("sh"#i#"add_op");
// More complex cases use a ComplexPattern.
@@ -881,6 +883,9 @@ foreach i = {1,2,3} in {
defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
def : Pat<(i32 (add_like_non_imm12 (shl GPR:$rs1, (i64 i)), GPR:$rs2)),
(shxadd GPR:$rs1, GPR:$rs2)>;
def : Pat<(i32 (shl_add GPR:$rs1, (i32 i), GPR:$rs2)),
(shxadd GPR:$rs1, GPR:$rs2)>;

}
}

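
For readers unfamiliar with Zba, the instructions selected by these patterns compute a shifted add directly; a minimal sketch of their semantics, with illustrative operand names:

#include <cstdint>

// Each SHiADD computes rd = (rs1 << i) + rs2 for i in {1, 2, 3}, which is
// exactly one shl_add node, so the patterns above are a 1:1 mapping.
uint64_t sh1add(uint64_t Rs1, uint64_t Rs2) { return (Rs1 << 1) + Rs2; }
uint64_t sh2add(uint64_t Rs1, uint64_t Rs2) { return (Rs1 << 2) + Rs2; }
uint64_t sh3add(uint64_t Rs1, uint64_t Rs2) { return (Rs1 << 3) + Rs2; }
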
40 changes: 38 additions & 2 deletions llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -2519,7 +2519,6 @@ bool X86DAGToDAGISel::matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
if (N.getResNo() != 0) break;
[[fallthrough]];
case ISD::MUL:
case X86ISD::MUL_IMM:
// X*[3,5,9] -> X+X*[2,4,8]
if (AM.BaseType == X86ISelAddressMode::RegBase &&
AM.Base_Reg.getNode() == nullptr &&
@@ -2551,7 +2550,44 @@ bool X86DAGToDAGISel::matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
}
}
break;

case ISD::SHL_ADD: {
// X << [1,2,3] + Y (we should never create anything else)
auto *CN = cast<ConstantSDNode>(N.getOperand(1));
assert(CN->getZExtValue() == 1 || CN->getZExtValue() == 2 ||
CN->getZExtValue() == 3);
if (AM.BaseType == X86ISelAddressMode::RegBase &&
AM.Base_Reg.getNode() == nullptr && AM.IndexReg.getNode() == nullptr) {
AM.Scale = unsigned(2 << (CN->getZExtValue() - 1));

if (N.getOperand(0) == N.getOperand(2)) {
SDValue MulVal = N.getOperand(0);
SDValue Reg;

// Okay, we know that we have a scale by now. However, if the scaled
// value is an add of something and a constant, we can fold the
// constant into the disp field here.
if (MulVal.getNode()->getOpcode() == ISD::ADD &&
N->isOnlyUserOf(MulVal.getNode()) &&
isa<ConstantSDNode>(MulVal.getOperand(1))) {
Reg = MulVal.getOperand(0);
auto *AddVal = cast<ConstantSDNode>(MulVal.getOperand(1));
uint64_t Disp = AddVal->getSExtValue() * (AM.Scale + 1);
if (foldOffsetIntoAddress(Disp, AM))
Reg = N.getOperand(0);
} else {
Reg = N.getOperand(0);
}
AM.IndexReg = AM.Base_Reg = Reg;
return false;
}
// TODO: If N.getOperand(2) is a constant, we could try folding
// the displacement analogously to the above.
AM.IndexReg = N.getOperand(0);
AM.Base_Reg = N.getOperand(2);
return false;
}
break;
}
case ISD::SUB: {
// Given A-B, if A can be completely folded into the address and
// the index field with the index field unused, use -B as the index.
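
To illustrate what the new addressing-mode case is aiming for, here is a standalone arithmetic check of the scale and displacement computation; the shift amount and the constant 7 are assumptions for the example, not values from the patch.

#include <cstdint>

// AM.Scale = 2 << (ShAmt - 1) maps shl_add(x, 1, y) to lea (y, x, 2),
// shl_add(x, 2, y) to lea (y, x, 4) and shl_add(x, 3, y) to lea (y, x, 8).
// When both value operands are the same add-of-constant, the constant folds
// into the displacement as Disp = C * (Scale + 1), e.g.
//   shl_add(x + 7, 2, x + 7) == 5*x + 35, selectable as lea 35(x, x, 4).
bool checkShlAddAddressing(uint64_t X) {
  const unsigned ShAmt = 2;
  const unsigned Scale = 2u << (ShAmt - 1);     // 4
  const uint64_t Disp = 7 * (Scale + 1);        // 35
  uint64_t Node = ((X + 7) << ShAmt) + (X + 7); // the shl_add node
  uint64_t Lea = X + Scale * X + Disp;          // base + index*scale + disp
  return Node == Lea;
}
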
38 changes: 15 additions & 23 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -33553,7 +33553,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(BZHI)
NODE_NAME_CASE(PDEP)
NODE_NAME_CASE(PEXT)
NODE_NAME_CASE(MUL_IMM)
NODE_NAME_CASE(MOVMSK)
NODE_NAME_CASE(PTEST)
NODE_NAME_CASE(TESTP)
@@ -36845,13 +36844,6 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
Known.resetAll();
switch (Opc) {
default: break;
case X86ISD::MUL_IMM: {
KnownBits Known2;
Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
Known2 = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known = KnownBits::mul(Known, Known2);
break;
}
case X86ISD::SETCC:
Known.Zero.setBitsFrom(1);
break;
@@ -46905,12 +46897,18 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ResLo, ResHi);
}

static SDValue createMulImm(uint64_t MulAmt, SDValue N, SelectionDAG &DAG,
EVT VT, const SDLoc &DL) {
assert(MulAmt == 3 || MulAmt == 5 || MulAmt == 9);
SDValue ShAmt = DAG.getConstant(Log2_64(MulAmt - 1), DL, MVT::i8);
return DAG.getNode(ISD::SHL_ADD, DL, VT, N, ShAmt, N);

Collaborator:
Is this going to cause issue with poison? We've now increased the use count of N.

Collaborator (Author):
That... is a good question. We probably need to freeze here since we're increasing the number of uses, I had not considered that. Let me add the freeze and see if that influences codegen in practice. If it does, we may need to consider both a SHL_ADD node and a MUL359 node. I'm hoping we don't, let me investigate and report back.

Collaborator (Author):
Well, the news here is not good. Adding in Freeze in the x86 backend code causes a whole bunch of regressions that were not obvious at first glance. Interestingly, incorporating the same logic into the RISC-V specific version of this patch (#89263) doesn't seem to expose the same kind of problems - most likely because the usage is much more isolated. #89290 fixes an analogous freeze issue in code already landed, again with no visible code diff.

I think what I'd like to suggest here is that we go ahead and focus review on #89263. Once we land that, I can iterate in tree on the RISC-V specific parts, and then rebase this patch on a fully fleshed through implementation and focus it on the x86 merge. (I clearly need to track something down there.)

(For the record, the issue @dtcxzyw flagged in the RISCV specific part of this patch doesn't exist in #89263 as I focused on a different subset there. That's probably confusing for reviewers in retrospect, sorry!)

Collaborator (Author):
I investigated these differences further. Net result is one fairly obvious missed optimization, one somewhat complicated but reasonable issue with COPY elimination, and one fundamental issue. I'm going to focus on only the last.

We end up with a situation where an inserted freeze gets hoisted through a chain of computation. This is all correct and fine, but as a side effect of that hoisting, we strip nsw off an add. The net result is that we can't prove a narrow addressing sequence is equivalent to the wider form, and thus fail to be able to fold a constant base offset into the addressing mode.

I'm a bit stuck on what to do about this case, and need to give this more thought.

}
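
For concreteness, the freeze-based variant discussed in this thread would look roughly like the sketch below. It is only an illustration of the idea, not the code that was actually benchmarked, and it assumes the same X86ISelLowering.cpp context as createMulImm above, so the usual headers and using-declarations are already in scope.

// Hypothetical variant: freeze N before duplicating it so that both uses of
// the operand observe the same value even if N is undef or poison.
static SDValue createMulImmFrozen(uint64_t MulAmt, SDValue N, SelectionDAG &DAG,
                                  EVT VT, const SDLoc &DL) {
  assert(MulAmt == 3 || MulAmt == 5 || MulAmt == 9);
  SDValue Frozen = DAG.getFreeze(N);
  SDValue ShAmt = DAG.getConstant(Log2_64(MulAmt - 1), DL, MVT::i8);
  return DAG.getNode(ISD::SHL_ADD, DL, VT, Frozen, ShAmt, Frozen);
}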

static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
EVT VT, const SDLoc &DL) {

auto combineMulShlAddOrSub = [&](int Mult, int Shift, bool isAdd) {
SDValue Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
DAG.getConstant(Mult, DL, VT));
SDValue Result = createMulImm(Mult, N->getOperand(0), DAG, VT, DL);
Result = DAG.getNode(ISD::SHL, DL, VT, Result,
DAG.getConstant(Shift, DL, MVT::i8));
Result = DAG.getNode(isAdd ? ISD::ADD : ISD::SUB, DL, VT, Result,
@@ -46919,10 +46917,8 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
};

auto combineMulMulAddOrSub = [&](int Mul1, int Mul2, bool isAdd) {
SDValue Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
DAG.getConstant(Mul1, DL, VT));
Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, Result,
DAG.getConstant(Mul2, DL, VT));
SDValue Result = createMulImm(Mul1, N->getOperand(0), DAG, VT, DL);
Result = createMulImm(Mul2, Result, DAG, VT, DL);
Result = DAG.getNode(isAdd ? ISD::ADD : ISD::SUB, DL, VT, Result,
N->getOperand(0));
return Result;
@@ -46982,9 +46978,8 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1)));
SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
DAG.getConstant(ShiftAmt, DL, MVT::i8));
SDValue Shift2 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
DAG.getConstant(ScaleShift, DL, MVT::i8));
return DAG.getNode(ISD::ADD, DL, VT, Shift1, Shift2);
return DAG.getNode(ISD::SHL_ADD, DL, VT, N->getOperand(0),
DAG.getConstant(ScaleShift, DL, MVT::i8), Shift1);
}
}
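
With an assumed multiplier of 40 (so ShiftAmt = 5 for the high set bit and ScaleShift = 3 for the low one), the shl_add form above computes the same product:

#include <cstdint>

// x * 40 == (x << 3) + (x << 5), i.e. shl_add(x, 3, x << 5): one plain shift
// plus one LEA instead of two shifts and an add.
bool checkMul40(uint64_t X) {
  uint64_t Shift1 = X << 5;            // Shift1 = x << ShiftAmt
  uint64_t ShlAdd = (X << 3) + Shift1; // shl_add(x, ScaleShift, Shift1)
  return ShlAdd == X * 40;
}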

@@ -47204,8 +47199,7 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
SDValue NewMul = SDValue();
if (VT == MVT::i64 || VT == MVT::i32) {
if (AbsMulAmt == 3 || AbsMulAmt == 5 || AbsMulAmt == 9) {
NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
DAG.getConstant(AbsMulAmt, DL, VT));
NewMul = createMulImm(AbsMulAmt, N->getOperand(0), DAG, VT, DL);
if (SignMulAmt < 0)
NewMul =
DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), NewMul);
Expand Down Expand Up @@ -47243,15 +47237,13 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
NewMul = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
DAG.getConstant(Log2_64(MulAmt1), DL, MVT::i8));
else
NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
DAG.getConstant(MulAmt1, DL, VT));
NewMul = createMulImm(MulAmt1, N->getOperand(0), DAG, VT, DL);

if (isPowerOf2_64(MulAmt2))
NewMul = DAG.getNode(ISD::SHL, DL, VT, NewMul,
DAG.getConstant(Log2_64(MulAmt2), DL, MVT::i8));
else
NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, NewMul,
DAG.getConstant(MulAmt2, DL, VT));
NewMul = createMulImm(MulAmt2, NewMul, DAG, VT, DL);

// Negate the result.
if (SignMulAmt < 0)
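
For the two-factor path above, an assumed amount of 45 factors as 9 * 5, and each factor becomes a single shl_add (hence a single LEA):

#include <cstdint>

// x * 45 == ((x << 3) + x) * 5, and t * 5 == (t << 2) + t, so the product
// becomes shl_add(t, 2, t) with t = shl_add(x, 3, x): two LEAs, no imul.
bool checkMul45(uint64_t X) {
  uint64_t T = (X << 3) + X; // createMulImm(9, x): shl_add(x, 3, x)
  uint64_t R = (T << 2) + T; // createMulImm(5, t): shl_add(t, 2, t)
  return R == X * 45;
}
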
3 changes: 0 additions & 3 deletions llvm/lib/Target/X86/X86ISelLowering.h
@@ -417,9 +417,6 @@ namespace llvm {
PDEP,
PEXT,

// X86-specific multiply by immediate.
MUL_IMM,

// Vector sign bit extraction.
MOVMSK,

8 changes: 3 additions & 5 deletions llvm/lib/Target/X86/X86InstrFragments.td
@@ -284,8 +284,6 @@ def X86bzhi : SDNode<"X86ISD::BZHI", SDTIntBinOp>;
def X86pdep : SDNode<"X86ISD::PDEP", SDTIntBinOp>;
def X86pext : SDNode<"X86ISD::PEXT", SDTIntBinOp>;

def X86mul_imm : SDNode<"X86ISD::MUL_IMM", SDTIntBinOp>;

def X86DynAlloca : SDNode<"X86ISD::DYN_ALLOCA", SDT_X86DYN_ALLOCA,
[SDNPHasChain, SDNPOutGlue]>;

@@ -341,11 +339,11 @@ def X86cmpccxadd : SDNode<"X86ISD::CMPCCXADD", SDTX86Cmpccxadd,
// Define X86-specific addressing mode.
def addr : ComplexPattern<iPTR, 5, "selectAddr", [], [SDNPWantParent]>;
def lea32addr : ComplexPattern<i32, 5, "selectLEAAddr",
[add, sub, mul, X86mul_imm, shl, or, xor, frameindex],
[add, sub, mul, shl_add, shl, or, xor, frameindex],
[]>;
// In 64-bit mode 32-bit LEAs can use RIP-relative addressing.
def lea64_32addr : ComplexPattern<i32, 5, "selectLEA64_32Addr",
[add, sub, mul, X86mul_imm, shl, or, xor,
[add, sub, mul, shl_add, shl, or, xor,
frameindex, X86WrapperRIP],
[]>;

@@ -356,7 +354,7 @@ def tls32baseaddr : ComplexPattern<i32, 5, "selectTLSADDRAddr",
[tglobaltlsaddr], []>;

def lea64addr : ComplexPattern<i64, 5, "selectLEAAddr",
[add, sub, mul, X86mul_imm, shl, or, xor, frameindex,
[add, sub, mul, shl_add, shl, or, xor, frameindex,
X86WrapperRIP], []>;

def tls64addr : ComplexPattern<i64, 5, "selectTLSADDRAddr",