@@ -52,6 +52,253 @@ static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
 }
 
+/// Returns true if the lowering option is a vector shuffle based approach.
+static bool isShuffleLike(VectorTransposeLowering lowering) {
+  return lowering == VectorTransposeLowering::Shuffle1D ||
+         lowering == VectorTransposeLowering::Shuffle16x16;
+}
+
+/// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
+/// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
+/// create the mask for `numBits` bits vector. The `numBits` have to be a
+/// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
+/// 512, there should be 16 elements in the final result. It constructs the
+/// below mask to get the unpack elements.
+///   [0, 1, 16, 17,
+///    0+4, 1+4, 16+4, 17+4,
+///    0+8, 1+8, 16+8, 17+8,
+///    0+12, 1+12, 16+12, 17+12]
+static SmallVector<int64_t>
+getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
+  assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
+  int numElem = numBits / 32;
+  SmallVector<int64_t> res;
+  for (int i = 0; i < numElem; i += 4)
+    for (int64_t v : vals)
+      res.push_back(v + i);
+  return res;
+}
+
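
For reference, a minimal standalone sketch of the expansion described above, using plain std containers instead of MLIR types (the helper name unpackPerm is local to the sketch): for vals = {0, 1, 16, 17} and a 512-bit vector it produces the 16-entry unpack pattern listed in the doc comment.

#include <cassert>
#include <vector>

// Mirrors getUnpackShufflePermFor128Lane with std::vector; illustrative only.
static std::vector<long> unpackPerm(const std::vector<long> &vals, int numBits) {
  assert(numBits % 128 == 0 && "numBits must be a multiple of 128");
  std::vector<long> res;
  for (int i = 0; i < numBits / 32; i += 4) // one 128-bit lane per step
    for (long v : vals)
      res.push_back(v + i);
  return res;
}

int main() {
  std::vector<long> perm = unpackPerm({0, 1, 16, 17}, 512);
  std::vector<long> expected = {0, 1, 16, 17, 4,  5,  20, 21,
                                8, 9, 24, 25, 12, 13, 28, 29};
  assert(perm == expected); // matches the mask shown in the doc comment
  return 0;
}
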
+/// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
+/// example, if it is targeting 512 bit vector, returns
+///   vector.shuffle on v1, v2, [0, 1, 16, 17,
+///                              0+4, 1+4, 16+4, 17+4,
+///                              0+8, 1+8, 16+8, 17+8,
+///                              0+12, 1+12, 16+12, 17+12].
+static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32;
+  return b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
+/// example, if it is targeting 512 bit vector, returns
+///   vector.shuffle, v1, v2, [2, 3, 18, 19,
+///                            2+4, 3+4, 18+4, 19+4,
+///                            2+8, 3+8, 18+8, 19+8,
+///                            2+12, 3+12, 18+12, 19+12].
+static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32;
+  return b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
+                                     numBits));
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
+/// example, if it is targeting 512 bit vector, returns
+///   vector.shuffle, v1, v2, [0, 16, 1, 17,
+///                            0+4, 16+4, 1+4, 17+4,
+///                            0+8, 16+8, 1+8, 17+8,
+///                            0+12, 16+12, 1+12, 17+12].
+static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32;
+  auto shuffle = b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
+  return shuffle;
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
+/// example, if it is targeting 512 bit vector, returns
+///   vector.shuffle, v1, v2, [2, 18, 3, 19,
+///                            2+4, 18+4, 3+4, 19+4,
+///                            2+8, 18+8, 3+8, 19+8,
+///                            2+12, 18+12, 3+12, 19+12].
+static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32;
+  return b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
+                                     numBits));
+}
+
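
A small standalone sanity check of the per-lane base patterns used by the four helpers above (plain C++, not part of the patch): each lo/hi pair reads every element of a 128-bit lane of v1 and v2 exactly once, which is why a pair of unpack shuffles loses no data.

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Per-128-bit-lane base patterns; indices 0-3 are lane 0 of v1, 16-19 are
  // lane 0 of v2.
  std::vector<int> loPs = {0, 16, 1, 17}; // 32-bit interleave, low half
  std::vector<int> hiPs = {2, 18, 3, 19}; // 32-bit interleave, high half
  std::vector<int> loPd = {0, 1, 16, 17}; // 64-bit interleave, low half
  std::vector<int> hiPd = {2, 3, 18, 19}; // 64-bit interleave, high half
  // Each lo/hi pair covers the lane's eight source elements exactly once.
  auto check = [](std::vector<int> lo, const std::vector<int> &hi) {
    lo.insert(lo.end(), hi.begin(), hi.end());
    std::sort(lo.begin(), lo.end());
    assert((lo == std::vector<int>{0, 1, 2, 3, 16, 17, 18, 19}));
  };
  check(loPs, hiPs);
  check(loPd, hiPd);
  return 0;
}
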
+/// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
+/// elements) selected by `mask` from `v1` and `v2`. I.e.,
+///
+///   DEFINE SELECT4(src, control) {
+///     CASE(control[1:0]) OF
+///       0: tmp[127:0] := src[127:0]
+///       1: tmp[127:0] := src[255:128]
+///       2: tmp[127:0] := src[383:256]
+///       3: tmp[127:0] := src[511:384]
+///     ESAC
+///     RETURN tmp[127:0]
+///   }
+///   dst[127:0]   := SELECT4(v1[511:0], mask[1:0])
+///   dst[255:128] := SELECT4(v1[511:0], mask[3:2])
+///   dst[383:256] := SELECT4(v2[511:0], mask[5:4])
+///   dst[511:384] := SELECT4(v2[511:0], mask[7:6])
+static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                                  uint8_t mask) {
+  assert(v1.getType().cast<VectorType>().getShape()[0] == 16 &&
+         "expected a vector with length=16");
+  SmallVector<int64_t> shuffleMask;
+  auto appendToMask = [&](int64_t base, uint8_t control) {
+    switch (control) {
+    case 0:
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
+                                                        base + 2, base + 3});
+      break;
+    case 1:
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
+                                                        base + 6, base + 7});
+      break;
+    case 2:
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
+                                                        base + 10, base + 11});
+      break;
+    case 3:
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
+                                                        base + 14, base + 15});
+      break;
+    default:
+      llvm_unreachable("control > 3 : overflow");
+    }
+  };
+  uint8_t b01 = mask & 0x3;
+  uint8_t b23 = (mask >> 2) & 0x3;
+  uint8_t b45 = (mask >> 4) & 0x3;
+  uint8_t b67 = (mask >> 6) & 0x3;
+  appendToMask(0, b01);
+  appendToMask(0, b23);
+  appendToMask(16, b45);
+  appendToMask(16, b67);
+  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
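
To illustrate the mask decoding above, here is a standalone sketch (plain C++; the helper name lanePerm is local to the sketch) that reproduces the permutation for the two mask bytes used later, 0x88 and 0xdd: 0x88 keeps 128-bit lanes 0 and 2 of each source, 0xdd keeps lanes 1 and 3.

#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<long> lanePerm(uint8_t mask) {
  std::vector<long> m;
  // Destination lanes 0-1 read v1 (base 0), lanes 2-3 read v2 (base 16);
  // each 2-bit field of `mask` picks one 4 x i32 lane of the source.
  for (int k = 0; k < 4; ++k)
    for (int e = 0; e < 4; ++e)
      m.push_back((k < 2 ? 0 : 16) + 4 * ((mask >> (2 * k)) & 0x3) + e);
  return m;
}

int main() {
  // 0x88 = 0b10'00'10'00 -> lanes {0, 2} of v1, then lanes {0, 2} of v2.
  assert(lanePerm(0x88) == (std::vector<long>{0, 1, 2, 3, 8, 9, 10, 11,
                                              16, 17, 18, 19, 24, 25, 26, 27}));
  // 0xdd = 0b11'01'11'01 -> lanes {1, 3} of v1, then lanes {1, 3} of v2.
  assert(lanePerm(0xdd) == (std::vector<long>{4, 5, 6, 7, 12, 13, 14, 15,
                                              20, 21, 22, 23, 28, 29, 30, 31}));
  return 0;
}
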
+/// Lowers the value to a vector.shuffle op. The `source` is expected to be a
+/// 1-D vector and have `m`x`n` elements.
+static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
+  SmallVector<int64_t> mask;
+  mask.reserve(m * n);
+  for (int64_t j = 0; j < n; ++j)
+    for (int64_t i = 0; i < m; ++i)
+      mask.push_back(i * n + j);
+  return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
+}
+
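
A quick standalone check of the 1-D mask construction above (plain C++, hypothetical 2x3 shape): reading index i*n+j column by column over a row-major flattening yields the flattened transpose.

#include <cassert>
#include <vector>

int main() {
  const int m = 2, n = 3;
  std::vector<int> mask;
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      mask.push_back(i * n + j);
  // Rows [a b c; d e f] flatten to [a b c d e f]; the mask gathers
  // [a d b e c f], i.e. the flattened 3x2 transpose.
  assert(mask == (std::vector<int>{0, 3, 1, 4, 2, 5}));
  return 0;
}
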
+/// Lowers the value to a sequence of vector.shuffle ops. The `source` is
+/// expected to be a 16x16 vector.
+static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
+                                     int n) {
+  ImplicitLocOpBuilder b(source.getLoc(), builder);
+  SmallVector<Value> vs;
+  for (int64_t i = 0; i < m; ++i)
+    vs.push_back(b.create<vector::ExtractOp>(source, i));
+
+  // Interleave 32-bit lanes using
+  //   8x _mm512_unpacklo_epi32
+  //   8x _mm512_unpackhi_epi32
+  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
+  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
+  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
+  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
+  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
+  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
+  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
+  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
+  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
+  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
+  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
+  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
+  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
+  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
+  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
+  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
+
+  // Interleave 64-bit lanes using
+  //   8x _mm512_unpacklo_epi64
+  //   8x _mm512_unpackhi_epi64
+  Value r0 = createUnpackLoPd(b, t0, t2, 512);
+  Value r1 = createUnpackHiPd(b, t0, t2, 512);
+  Value r2 = createUnpackLoPd(b, t1, t3, 512);
+  Value r3 = createUnpackHiPd(b, t1, t3, 512);
+  Value r4 = createUnpackLoPd(b, t4, t6, 512);
+  Value r5 = createUnpackHiPd(b, t4, t6, 512);
+  Value r6 = createUnpackLoPd(b, t5, t7, 512);
+  Value r7 = createUnpackHiPd(b, t5, t7, 512);
+  Value r8 = createUnpackLoPd(b, t8, ta, 512);
+  Value r9 = createUnpackHiPd(b, t8, ta, 512);
+  Value ra = createUnpackLoPd(b, t9, tb, 512);
+  Value rb = createUnpackHiPd(b, t9, tb, 512);
+  Value rc = createUnpackLoPd(b, tc, te, 512);
+  Value rd = createUnpackHiPd(b, tc, te, 512);
+  Value re = createUnpackLoPd(b, td, tf, 512);
+  Value rf = createUnpackHiPd(b, td, tf, 512);
+
+  // Permute 128-bit lanes using
+  //   16x _mm512_shuffle_i32x4
+  t0 = create4x128BitSuffle(b, r0, r4, 0x88);
+  t1 = create4x128BitSuffle(b, r1, r5, 0x88);
+  t2 = create4x128BitSuffle(b, r2, r6, 0x88);
+  t3 = create4x128BitSuffle(b, r3, r7, 0x88);
+  t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
+  t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
+  t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
+  t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
+  t8 = create4x128BitSuffle(b, r8, rc, 0x88);
+  t9 = create4x128BitSuffle(b, r9, rd, 0x88);
+  ta = create4x128BitSuffle(b, ra, re, 0x88);
+  tb = create4x128BitSuffle(b, rb, rf, 0x88);
+  tc = create4x128BitSuffle(b, r8, rc, 0xdd);
+  td = create4x128BitSuffle(b, r9, rd, 0xdd);
+  te = create4x128BitSuffle(b, ra, re, 0xdd);
+  tf = create4x128BitSuffle(b, rb, rf, 0xdd);
+
+  // Permute 256-bit lanes using again
+  //   16x _mm512_shuffle_i32x4
+  vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
+  vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
+  vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
+  vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
+  vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
+  vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
+  vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
+  vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
+  vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
+  vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
+  vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
+  vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
+  vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
+  vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
+  vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
+  vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
+
+  auto reshInputType = VectorType::get(
+      {m, n}, source.getType().cast<VectorType>().getElementType());
+  Value res =
+      b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
+  for (int64_t i = 0; i < m; ++i)
+    res = b.create<vector::InsertOp>(vs[i], res, i);
+  return res;
+}
+
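
The sequence above mirrors the well-known AVX-512 16x16 32-bit transpose idiom referenced in its comments (_mm512_unpacklo/hi_epi32, _mm512_unpacklo/hi_epi64, _mm512_shuffle_i32x4). The standalone sketch below (plain C++, no MLIR dependency, all names local to the sketch) replays the same shuffle sequence on index vectors to show that the composition is a 16x16 transpose; it is illustrative, not part of the patch.

#include <array>
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <vector>

using Vec = std::array<int, 16>;

// Mirrors vector.shuffle: indexes the 32-element concatenation of a and b.
static Vec shuffle(const Vec &a, const Vec &b, const std::vector<int> &m) {
  Vec r{};
  for (int i = 0; i < 16; ++i)
    r[i] = m[i] < 16 ? a[m[i]] : b[m[i] - 16];
  return r;
}

// Same expansion as getUnpackShufflePermFor128Lane for a 512-bit vector.
static std::vector<int> unpack(std::initializer_list<int> vals) {
  std::vector<int> m;
  for (int i = 0; i < 16; i += 4)
    for (int v : vals)
      m.push_back(v + i);
  return m;
}

static Vec loPs(Vec a, Vec b) { return shuffle(a, b, unpack({0, 16, 1, 17})); }
static Vec hiPs(Vec a, Vec b) { return shuffle(a, b, unpack({2, 18, 3, 19})); }
static Vec loPd(Vec a, Vec b) { return shuffle(a, b, unpack({0, 1, 16, 17})); }
static Vec hiPd(Vec a, Vec b) { return shuffle(a, b, unpack({2, 3, 18, 19})); }

// 128-bit lane select, mirroring create4x128BitSuffle: each 2-bit field of
// `mask` picks one 4-element lane; low halves come from a, high halves from b.
static Vec sh128(Vec a, Vec b, uint8_t mask) {
  std::vector<int> m;
  for (int k = 0; k < 4; ++k)
    for (int e = 0; e < 4; ++e)
      m.push_back((k < 2 ? 0 : 16) + 4 * ((mask >> (2 * k)) & 0x3) + e);
  return shuffle(a, b, m);
}

int main() {
  std::array<Vec, 16> vs, t, r;
  for (int i = 0; i < 16; ++i)
    for (int j = 0; j < 16; ++j)
      vs[i][j] = i * 16 + j; // element ID = row-major position

  // Stage 1: 32-bit interleave (unpacklo/hi_epi32 pattern).
  for (int i = 0; i < 16; i += 2) {
    t[i] = loPs(vs[i], vs[i + 1]);
    t[i + 1] = hiPs(vs[i], vs[i + 1]);
  }
  // Stage 2: 64-bit interleave (unpacklo/hi_epi64 pattern).
  for (int g = 0; g < 16; g += 4) {
    r[g] = loPd(t[g], t[g + 2]);
    r[g + 1] = hiPd(t[g], t[g + 2]);
    r[g + 2] = loPd(t[g + 1], t[g + 3]);
    r[g + 3] = hiPd(t[g + 1], t[g + 3]);
  }
  // Stage 3: 128-bit lane permute (shuffle_i32x4 pattern).
  for (int g = 0; g < 16; g += 8)
    for (int j = 0; j < 4; ++j) {
      t[g + j] = sh128(r[g + j], r[g + j + 4], 0x88);
      t[g + j + 4] = sh128(r[g + j], r[g + j + 4], 0xdd);
    }
  // Stage 4: 256-bit lane permute, again via shuffle_i32x4.
  for (int j = 0; j < 8; ++j) {
    vs[j] = sh128(t[j], t[j + 8], 0x88);
    vs[j + 8] = sh128(t[j], t[j + 8], 0xdd);
  }

  for (int i = 0; i < 16; ++i)
    for (int j = 0; j < 16; ++j)
      assert(vs[i][j] == j * 16 + i); // result is the transpose
  return 0;
}
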
 namespace {
 /// Progressive lowering of TransposeOp.
 /// One:
@@ -84,8 +331,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     for (auto attr : op.getTransp())
       transp.push_back(attr.cast<IntegerAttr>().getInt());
 
-    if (vectorTransformOptions.vectorTransposeLowering ==
-            vector::VectorTransposeLowering::Shuffle &&
+    if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
@@ -145,10 +391,13 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransformsOptions vectorTransformOptions;
 };
 
-/// Rewrite a 2-D vector.transpose as a sequence of:
+/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
+/// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
 ///   vector.shuffle
 ///   vector.shape_cast 1D -> 2D
+/// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
+/// ops on 16xf32 vectors.
 class TransposeOp2DToShuffleLowering
     : public OpRewritePattern<vector::TransposeOp> {
 public:
@@ -174,24 +423,28 @@ class TransposeOp2DToShuffleLowering
     if (transp[0] != 1 && transp[1] != 0)
       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
 
-    if (vectorTransformOptions.vectorTransposeLowering !=
-        VectorTransposeLowering::Shuffle)
-      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
-
+    Value res;
     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
-    Value casted = rewriter.create<vector::ShapeCastOp>(
-        loc, VectorType::get({m * n}, srcType.getElementType()),
-        op.getVector());
-    SmallVector<int64_t> mask;
-    mask.reserve(m * n);
-    for (int64_t j = 0; j < n; ++j)
-      for (int64_t i = 0; i < m; ++i)
-        mask.push_back(i * n + j);
-
-    Value shuffled =
-        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
+    switch (vectorTransformOptions.vectorTransposeLowering) {
+    case VectorTransposeLowering::Shuffle1D: {
+      Value casted = rewriter.create<vector::ShapeCastOp>(
+          loc, VectorType::get({m * n}, srcType.getElementType()),
+          op.getVector());
+      res = transposeToShuffle1D(rewriter, casted, m, n);
+      break;
+    }
+    case VectorTransposeLowering::Shuffle16x16:
+      if (m != 16 || n != 16)
+        return failure();
+      res = transposeToShuffle16x16(rewriter, op.getVector(), m, n);
+      break;
+    case VectorTransposeLowering::EltWise:
+    case VectorTransposeLowering::Flat:
+      return failure();
+    }
+
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        op, op.getResultVectorType(), shuffled);
+        op, op.getResultVectorType(), res);
 
     return success();
   }