Skip to content

[WIP][RISCV] Support 3-argument associative add for transformAddShlImm #86883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 43 additions & 11 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12733,20 +12733,13 @@ static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG,

// Optimize (add (shl x, c0), (shl y, c1)) ->
// (SLLI (SH*ADD x, y), c0), if c1-c0 equals [1|2|3].
static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
static SDValue transformAddShlImm(SDValue N0, SDValue N1, const SDLoc &DL,
SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
// Perform this optimization only in the zba extension.
if (!Subtarget.hasStdExtZba())
return SDValue();

// Skip for vector types and larger types.
EVT VT = N->getValueType(0);
if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen())
return SDValue();
EVT VT = N0.getValueType();

// The two operand nodes must be SHL and have no other use.
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
if (N0->getOpcode() != ISD::SHL || N1->getOpcode() != ISD::SHL ||
!N0->hasOneUse() || !N1->hasOneUse())
return SDValue();
Expand All @@ -12768,7 +12761,6 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
return SDValue();

// Build nodes.
SDLoc DL(N);
SDValue NS = (C0 < C1) ? N0->getOperand(0) : N1->getOperand(0);
SDValue NL = (C0 > C1) ? N0->getOperand(0) : N1->getOperand(0);
SDValue NA0 =
Expand All @@ -12777,6 +12769,46 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::SHL, DL, VT, NA1, DAG.getConstant(Bits, DL, VT));
}

// Entry point for the add-of-shifts combine on an ADD node. It first offers
// the node's two operands directly to the two-operand transformAddShlImm
// overload. If that fails, it looks through one level of add reassociation:
// for (add (shl ...), (add a, b)) it tries pairing the SHL with `a` and then
// with `b`, re-adding whichever inner operand was left over.
// Returns an empty SDValue when no rewrite applies.
static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
  // This optimization is only performed with the zba extension.
  if (!Subtarget.hasStdExtZba())
    return SDValue();

  // Vector types and types wider than XLen are not handled.
  const EVT VT = N->getValueType(0);
  if (VT.isVector())
    return SDValue();
  if (VT.getSizeInBits() > Subtarget.getXLen())
    return SDValue();

  // Canonicalize operand order: if either operand is a SHL, make sure one
  // ends up on the left-hand side.
  SDValue Lhs = N->getOperand(0);
  SDValue Rhs = N->getOperand(1);
  if (Lhs->getOpcode() != ISD::SHL)
    std::swap(Lhs, Rhs);

  // Direct case: let the two-operand overload inspect the pair as-is.
  const SDLoc RootDL(N);
  if (SDValue Direct = transformAddShlImm(Lhs, Rhs, RootDL, DAG, Subtarget))
    return Direct;

  // Reassociation case: requires a SHL on one side and an ADD with no other
  // use on the other side.
  const bool CanReassociate = Lhs->getOpcode() == ISD::SHL &&
                              Rhs->getOpcode() == ISD::ADD && Rhs->hasOneUse();
  if (!CanReassociate)
    return SDValue();

  // Try each operand of the inner add as the partner of the outer SHL. The
  // rebuilt outer ADD carries the inner add's location.
  const SDLoc InnerDL(Rhs);
  for (unsigned Idx = 0; Idx != 2; ++Idx) {
    SDValue Candidate = Rhs->getOperand(Idx);
    SDValue Remainder = Rhs->getOperand(1 - Idx);
    if (SDValue Fused =
            transformAddShlImm(Lhs, Candidate, RootDL, DAG, Subtarget))
      return DAG.getNode(ISD::ADD, InnerDL, VT, Fused, Remainder);
  }

  return SDValue();
}

// Combine a constant select operand into its use:
//
// (and (select cond, -1, c), x)
Expand Down
10 changes: 4 additions & 6 deletions llvm/test/CodeGen/RISCV/rv64zba.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1315,9 +1315,8 @@ define i64 @sh6_sh3_add2(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
;
; RV64ZBA-LABEL: sh6_sh3_add2:
; RV64ZBA: # %bb.0: # %entry
; RV64ZBA-NEXT: slli a1, a1, 6
; RV64ZBA-NEXT: add a0, a1, a0
; RV64ZBA-NEXT: sh3add a0, a2, a0
; RV64ZBA-NEXT: sh3add a1, a1, a2
; RV64ZBA-NEXT: sh3add a0, a1, a0
; RV64ZBA-NEXT: ret
entry:
%shl = shl i64 %z, 3
Expand Down Expand Up @@ -1360,9 +1359,8 @@ define i64 @sh6_sh3_add4(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
;
; RV64ZBA-LABEL: sh6_sh3_add4:
; RV64ZBA: # %bb.0: # %entry
; RV64ZBA-NEXT: slli a1, a1, 6
; RV64ZBA-NEXT: sh3add a0, a2, a0
; RV64ZBA-NEXT: add a0, a0, a1
; RV64ZBA-NEXT: sh3add a1, a1, a2
; RV64ZBA-NEXT: sh3add a0, a1, a0
; RV64ZBA-NEXT: ret
entry:
%shl = shl i64 %z, 3
Expand Down
Loading