
Commit 29afbd5 (1 parent: a3e2f0a)

[RISCV] Add DAG combine to convert (iX ctpop (bitcast (vXi1 A))) into vcpop.m. (#117062)
This only handles the simplest case, where vXi1 is a legal vector type. If the vector type isn't legal, we need to go through type legalization, and the pattern becomes much harder to recognize after that: the ctpop may get expanded because Zbb isn't enabled, the bitcast may become a bitcast+extractelt, the ctpop may be split into multiple ctpops and adds, and so on.
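
For reference, the simplest instance of the pattern, lifted from the tests added below; with -mattr=+v the <16 x i1> type is legal, so the whole sequence selects to a single vcpop.m:

  define i16 @test_v16i1(<16 x i1> %x) {
    %a = bitcast <16 x i1> %x to i16
    %b = call i16 @llvm.ctpop.i16(i16 %a)
    ret i16 %b
  }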

File tree

4 files changed: +386 −69 lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+51 −1)
@@ -1527,7 +1527,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                        ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
                        ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
                        ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
-                       ISD::INSERT_VECTOR_ELT, ISD::ABS});
+                       ISD::INSERT_VECTOR_ELT, ISD::ABS, ISD::CTPOP});
   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
   if (Subtarget.useRVVForFixedLengthVectors())
@@ -17055,6 +17055,52 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
   return Val;
 }
 
+// Convert
+//   (iX ctpop (bitcast (vXi1 A)))
+// ->
+//   (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
+// FIXME: It's complicated to match all the variations of this after type
+// legalization so we only handle the pre-type legalization pattern, but that
+// requires the fixed vector type to be legal.
+static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
+                                         const RISCVSubtarget &Subtarget) {
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalarInteger())
+    return SDValue();
+
+  SDValue Src = N->getOperand(0);
+
+  // Peek through zero_extend. It doesn't change the count.
+  if (Src.getOpcode() == ISD::ZERO_EXTEND)
+    Src = Src.getOperand(0);
+
+  if (Src.getOpcode() != ISD::BITCAST)
+    return SDValue();
+
+  Src = Src.getOperand(0);
+  EVT SrcEVT = Src.getValueType();
+  if (!SrcEVT.isSimple())
+    return SDValue();
+
+  MVT SrcMVT = SrcEVT.getSimpleVT();
+  // Make sure the input is an i1 vector.
+  if (!SrcMVT.isVector() || SrcMVT.getVectorElementType() != MVT::i1)
+    return SDValue();
+
+  if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget))
+    return SDValue();
+
+  MVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget);
+  Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
+
+  SDLoc DL(N);
+  auto [Mask, VL] = getDefaultVLOps(SrcMVT, ContainerVT, DL, DAG, Subtarget);
+
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue Pop = DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Src, Mask, VL);
+  return DAG.getZExtOrTrunc(Pop, DL, VT);
+}
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -18023,6 +18069,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
 
     return SDValue();
   }
+  case ISD::CTPOP:
+    if (SDValue V = combineScalarCTPOPToVCPOP(N, DAG, Subtarget))
+      return V;
+    break;
   }
 
   return SDValue();
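
The zero_extend peek-through isn't exercised by the tests added in this commit; a minimal sketch of an input it should also catch (hypothetical function name, assumes -mattr=+v so <8 x i1> is a legal fixed vector type):

  define i32 @popcount_mask_zext(<8 x i1> %m) {
    ; The zext doesn't change the population count, so the combine looks
    ; through it and the ctpop should still become a single vcpop.m.
    %bits = bitcast <8 x i1> %m to i8
    %wide = zext i8 %bits to i32
    %cnt = call i32 @llvm.ctpop.i32(i32 %wide)
    ret i32 %cnt
  }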
(new test file, +265 −0)
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zbb | FileCheck %s --check-prefixes=CHECK,RV32
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zbb | FileCheck %s --check-prefixes=CHECK,RV64

define i2 @test_v2i1(<2 x i1> %x) {
; CHECK-LABEL: test_v2i1:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
; CHECK-NEXT:    vcpop.m a0, v0
; CHECK-NEXT:    ret
entry:
  %a = bitcast <2 x i1> %x to i2
  %b = call i2 @llvm.ctpop.i2(i2 %a)
  ret i2 %b
}

define i4 @test_v4i1(<4 x i1> %x) {
; CHECK-LABEL: test_v4i1:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
; CHECK-NEXT:    vcpop.m a0, v0
; CHECK-NEXT:    ret
entry:
  %a = bitcast <4 x i1> %x to i4
  %b = call i4 @llvm.ctpop.i4(i4 %a)
  ret i4 %b
}

define i8 @test_v8i1(<8 x i1> %x) {
; CHECK-LABEL: test_v8i1:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT:    vcpop.m a0, v0
; CHECK-NEXT:    ret
entry:
  %a = bitcast <8 x i1> %x to i8
  %b = call i8 @llvm.ctpop.i8(i8 %a)
  ret i8 %b
}

define i16 @test_v16i1(<16 x i1> %x) {
; CHECK-LABEL: test_v16i1:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT:    vcpop.m a0, v0
; CHECK-NEXT:    ret
entry:
  %a = bitcast <16 x i1> %x to i16
  %b = call i16 @llvm.ctpop.i16(i16 %a)
  ret i16 %b
}

define i32 @test_v32i1(<32 x i1> %x) {
; CHECK-LABEL: test_v32i1:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    li a0, 32
; CHECK-NEXT:    vsetvli zero, a0, e8, m2, ta, ma
; CHECK-NEXT:    vcpop.m a0, v0
; CHECK-NEXT:    ret
entry:
  %a = bitcast <32 x i1> %x to i32
  %b = call i32 @llvm.ctpop.i32(i32 %a)
  ret i32 %b
}

define i64 @test_v64i1(<64 x i1> %x) {
; RV32-LABEL: test_v64i1:
; RV32:       # %bb.0: # %entry
; RV32-NEXT:    li a0, 64
; RV32-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
; RV32-NEXT:    vcpop.m a0, v0
; RV32-NEXT:    li a1, 0
; RV32-NEXT:    ret
;
; RV64-LABEL: test_v64i1:
; RV64:       # %bb.0: # %entry
; RV64-NEXT:    li a0, 64
; RV64-NEXT:    vsetvli zero, a0, e8, m4, ta, ma
; RV64-NEXT:    vcpop.m a0, v0
; RV64-NEXT:    ret
entry:
  %a = bitcast <64 x i1> %x to i64
  %b = call i64 @llvm.ctpop.i64(i64 %a)
  ret i64 %b
}

define i128 @test_v128i1(<128 x i1> %x) {
; RV32-LABEL: test_v128i1:
; RV32:       # %bb.0: # %entry
; RV32-NEXT:    li a1, 128
; RV32-NEXT:    vsetvli zero, a1, e8, m8, ta, ma
; RV32-NEXT:    vcpop.m a1, v0
; RV32-NEXT:    sw a1, 0(a0)
; RV32-NEXT:    sw zero, 4(a0)
; RV32-NEXT:    sw zero, 8(a0)
; RV32-NEXT:    sw zero, 12(a0)
; RV32-NEXT:    ret
;
; RV64-LABEL: test_v128i1:
; RV64:       # %bb.0: # %entry
; RV64-NEXT:    li a0, 128
; RV64-NEXT:    vsetvli zero, a0, e8, m8, ta, ma
; RV64-NEXT:    vcpop.m a0, v0
; RV64-NEXT:    li a1, 0
; RV64-NEXT:    ret
entry:
  %a = bitcast <128 x i1> %x to i128
  %b = call i128 @llvm.ctpop.i128(i128 %a)
  ret i128 %b
}

define i32 @test_trunc_v128i1(<128 x i1> %x) {
; CHECK-LABEL: test_trunc_v128i1:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    li a0, 128
; CHECK-NEXT:    vsetvli zero, a0, e8, m8, ta, ma
; CHECK-NEXT:    vcpop.m a0, v0
; CHECK-NEXT:    ret
entry:
  %a = bitcast <128 x i1> %x to i128
  %b = call i128 @llvm.ctpop.i128(i128 %a)
  %c = trunc i128 %b to i32
  ret i32 %c
}

define i256 @test_v256i1(<256 x i1> %x) {
; RV32-LABEL: test_v256i1:
; RV32:       # %bb.0: # %entry
; RV32-NEXT:    vsetivli zero, 1, e64, m1, ta, ma
; RV32-NEXT:    vslidedown.vi v9, v0, 1
; RV32-NEXT:    li a1, 32
; RV32-NEXT:    vslidedown.vi v10, v8, 1
; RV32-NEXT:    vmv.x.s a2, v0
; RV32-NEXT:    vmv.x.s a3, v8
; RV32-NEXT:    vsrl.vx v11, v9, a1
; RV32-NEXT:    vsrl.vx v12, v0, a1
; RV32-NEXT:    vmv.x.s a4, v9
; RV32-NEXT:    vsrl.vx v9, v10, a1
; RV32-NEXT:    vsrl.vx v8, v8, a1
; RV32-NEXT:    vmv.x.s a1, v10
; RV32-NEXT:    cpop a3, a3
; RV32-NEXT:    cpop a2, a2
; RV32-NEXT:    vmv.x.s a5, v11
; RV32-NEXT:    vmv.x.s a6, v12
; RV32-NEXT:    vmv.x.s a7, v9
; RV32-NEXT:    vmv.x.s t0, v8
; RV32-NEXT:    cpop a1, a1
; RV32-NEXT:    cpop a4, a4
; RV32-NEXT:    cpop t0, t0
; RV32-NEXT:    cpop a7, a7
; RV32-NEXT:    cpop a6, a6
; RV32-NEXT:    cpop a5, a5
; RV32-NEXT:    add a3, a3, t0
; RV32-NEXT:    add a1, a1, a7
; RV32-NEXT:    add a2, a2, a6
; RV32-NEXT:    add a4, a4, a5
; RV32-NEXT:    add a5, a3, a1
; RV32-NEXT:    add a6, a2, a4
; RV32-NEXT:    add a1, a6, a5
; RV32-NEXT:    sltu a3, a5, a3
; RV32-NEXT:    sltu a4, a6, a2
; RV32-NEXT:    sltu a2, a1, a6
; RV32-NEXT:    add a3, a4, a3
; RV32-NEXT:    add a3, a3, a2
; RV32-NEXT:    beq a3, a4, .LBB8_2
; RV32-NEXT:  # %bb.1: # %entry
; RV32-NEXT:    sltu a2, a3, a4
; RV32-NEXT:  .LBB8_2: # %entry
; RV32-NEXT:    sw zero, 16(a0)
; RV32-NEXT:    sw zero, 20(a0)
; RV32-NEXT:    sw zero, 24(a0)
; RV32-NEXT:    sw zero, 28(a0)
; RV32-NEXT:    sw a1, 0(a0)
; RV32-NEXT:    sw a3, 4(a0)
; RV32-NEXT:    sw a2, 8(a0)
; RV32-NEXT:    sw zero, 12(a0)
; RV32-NEXT:    ret
;
; RV64-LABEL: test_v256i1:
; RV64:       # %bb.0: # %entry
; RV64-NEXT:    vsetivli zero, 1, e64, m1, ta, ma
; RV64-NEXT:    vslidedown.vi v9, v0, 1
; RV64-NEXT:    vmv.x.s a1, v0
; RV64-NEXT:    vslidedown.vi v10, v8, 1
; RV64-NEXT:    vmv.x.s a2, v8
; RV64-NEXT:    vmv.x.s a3, v9
; RV64-NEXT:    vmv.x.s a4, v10
; RV64-NEXT:    cpop a2, a2
; RV64-NEXT:    cpop a1, a1
; RV64-NEXT:    cpop a4, a4
; RV64-NEXT:    cpop a3, a3
; RV64-NEXT:    add a2, a2, a4
; RV64-NEXT:    add a1, a1, a3
; RV64-NEXT:    add a2, a1, a2
; RV64-NEXT:    sltu a1, a2, a1
; RV64-NEXT:    sd a2, 0(a0)
; RV64-NEXT:    sd a1, 8(a0)
; RV64-NEXT:    sd zero, 16(a0)
; RV64-NEXT:    sd zero, 24(a0)
; RV64-NEXT:    ret
entry:
  %a = bitcast <256 x i1> %x to i256
  %b = call i256 @llvm.ctpop.i256(i256 %a)
  ret i256 %b
}

define i32 @test_trunc_v256i1(<256 x i1> %x) {
; RV32-LABEL: test_trunc_v256i1:
; RV32:       # %bb.0: # %entry
; RV32-NEXT:    vsetivli zero, 1, e64, m1, ta, ma
; RV32-NEXT:    vslidedown.vi v9, v0, 1
; RV32-NEXT:    li a0, 32
; RV32-NEXT:    vslidedown.vi v10, v8, 1
; RV32-NEXT:    vmv.x.s a1, v0
; RV32-NEXT:    vmv.x.s a2, v8
; RV32-NEXT:    vsrl.vx v11, v9, a0
; RV32-NEXT:    vsrl.vx v12, v0, a0
; RV32-NEXT:    vmv.x.s a3, v9
; RV32-NEXT:    vsrl.vx v9, v10, a0
; RV32-NEXT:    vsrl.vx v8, v8, a0
; RV32-NEXT:    vmv.x.s a0, v10
; RV32-NEXT:    cpop a2, a2
; RV32-NEXT:    cpop a1, a1
; RV32-NEXT:    vmv.x.s a4, v11
; RV32-NEXT:    vmv.x.s a5, v12
; RV32-NEXT:    vmv.x.s a6, v9
; RV32-NEXT:    vmv.x.s a7, v8
; RV32-NEXT:    cpop a0, a0
; RV32-NEXT:    cpop a3, a3
; RV32-NEXT:    cpop a7, a7
; RV32-NEXT:    cpop a6, a6
; RV32-NEXT:    cpop a5, a5
; RV32-NEXT:    cpop a4, a4
; RV32-NEXT:    add a2, a2, a7
; RV32-NEXT:    add a0, a0, a6
; RV32-NEXT:    add a1, a1, a5
; RV32-NEXT:    add a3, a3, a4
; RV32-NEXT:    add a0, a2, a0
; RV32-NEXT:    add a1, a1, a3
; RV32-NEXT:    add a0, a1, a0
; RV32-NEXT:    ret
;
; RV64-LABEL: test_trunc_v256i1:
; RV64:       # %bb.0: # %entry
; RV64-NEXT:    vsetivli zero, 1, e64, m1, ta, ma
; RV64-NEXT:    vslidedown.vi v9, v0, 1
; RV64-NEXT:    vmv.x.s a0, v0
; RV64-NEXT:    vslidedown.vi v10, v8, 1
; RV64-NEXT:    vmv.x.s a1, v8
; RV64-NEXT:    vmv.x.s a2, v9
; RV64-NEXT:    vmv.x.s a3, v10
; RV64-NEXT:    cpop a1, a1
; RV64-NEXT:    cpop a0, a0
; RV64-NEXT:    cpop a3, a3
; RV64-NEXT:    cpop a2, a2
; RV64-NEXT:    add a1, a1, a3
; RV64-NEXT:    add a0, a0, a2
; RV64-NEXT:    add a0, a0, a1
; RV64-NEXT:    ret
entry:
  %a = bitcast <256 x i1> %x to i256
  %b = call i256 @llvm.ctpop.i256(i256 %a)
  %c = trunc i256 %b to i32
  ret i32 %c
}

llvm/test/CodeGen/RISCV/rvv/compressstore.ll (+8 −14)
@@ -453,20 +453,17 @@ define void @test_compresstore_v128i16(ptr %p, <128 x i1> %mask, <128 x i16> %da
 ; RV64-NEXT:    vsetvli zero, a1, e16, m8, ta, ma
 ; RV64-NEXT:    vcompress.vm v24, v8, v0
 ; RV64-NEXT:    vcpop.m a2, v0
-; RV64-NEXT:    vsetvli zero, a2, e16, m8, ta, ma
-; RV64-NEXT:    vse16.v v24, (a0)
 ; RV64-NEXT:    vsetivli zero, 8, e8, m1, ta, ma
 ; RV64-NEXT:    vslidedown.vi v8, v0, 8
-; RV64-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
-; RV64-NEXT:    vmv.x.s a2, v0
 ; RV64-NEXT:    vsetvli zero, a1, e16, m8, ta, ma
-; RV64-NEXT:    vcompress.vm v24, v16, v8
+; RV64-NEXT:    vcompress.vm v0, v16, v8
 ; RV64-NEXT:    vcpop.m a1, v8
-; RV64-NEXT:    cpop a2, a2
+; RV64-NEXT:    vsetvli zero, a2, e16, m8, ta, ma
+; RV64-NEXT:    vse16.v v24, (a0)
 ; RV64-NEXT:    slli a2, a2, 1
 ; RV64-NEXT:    add a0, a0, a2
 ; RV64-NEXT:    vsetvli zero, a1, e16, m8, ta, ma
-; RV64-NEXT:    vse16.v v24, (a0)
+; RV64-NEXT:    vse16.v v0, (a0)
 ; RV64-NEXT:    ret
 ;
 ; RV32-LABEL: test_compresstore_v128i16:
@@ -673,20 +670,17 @@ define void @test_compresstore_v64i32(ptr %p, <64 x i1> %mask, <64 x i32> %data)
 ; RV32-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
 ; RV32-NEXT:    vcompress.vm v24, v8, v0
 ; RV32-NEXT:    vcpop.m a2, v0
-; RV32-NEXT:    vsetvli zero, a2, e32, m8, ta, ma
-; RV32-NEXT:    vse32.v v24, (a0)
 ; RV32-NEXT:    vsetivli zero, 4, e8, mf2, ta, ma
 ; RV32-NEXT:    vslidedown.vi v8, v0, 4
-; RV32-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
-; RV32-NEXT:    vmv.x.s a2, v0
 ; RV32-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
-; RV32-NEXT:    vcompress.vm v24, v16, v8
+; RV32-NEXT:    vcompress.vm v0, v16, v8
 ; RV32-NEXT:    vcpop.m a1, v8
-; RV32-NEXT:    cpop a2, a2
+; RV32-NEXT:    vsetvli zero, a2, e32, m8, ta, ma
+; RV32-NEXT:    vse32.v v24, (a0)
 ; RV32-NEXT:    slli a2, a2, 2
 ; RV32-NEXT:    add a0, a0, a2
 ; RV32-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
-; RV32-NEXT:    vse32.v v24, (a0)
+; RV32-NEXT:    vse32.v v0, (a0)
 ; RV32-NEXT:    ret
 entry:
   tail call void @llvm.masked.compressstore.v64i32(<64 x i32> %data, ptr align 4 %p, <64 x i1> %mask)
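
The updates above fall out of the new combine: the compress-store expansion takes the popcount of the low half of the mask to advance the store pointer, and that count now comes straight from the existing vcpop.m result instead of being recomputed with vmv.x.s plus a scalar cpop. A minimal sketch of the same idiom in user IR (hypothetical function name, assumes -mattr=+v,+zbb):

  define ptr @advance_by_popcount(ptr %p, <64 x i1> %mask) {
    ; Advance %p by the number of set mask bits; the ctpop-of-bitcast
    ; should now lower to vcpop.m rather than vmv.x.s + cpop.
    %bits = bitcast <64 x i1> %mask to i64
    %cnt = call i64 @llvm.ctpop.i64(i64 %bits)
    %q = getelementptr i16, ptr %p, i64 %cnt
    ret ptr %q
  }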
