Skip to content

Commit 8d163e5

Browse files
committed
[mlir][Vector] Add 16x16 strategy to vector.transpose lowering.
It adds a `shuffle_16x16` strategy to LowerVectorTranspose and renames `shuffle` to `shuffle_1d`. The idea is similar to the 8x8 cases in x86Vector::avx2. The general algorithm is: ``` interleave 32-bit lanes using 8x _mm512_unpacklo_epi32 8x _mm512_unpackhi_epi32 interleave 64-bit lanes using 8x _mm512_unpacklo_epi64 8x _mm512_unpackhi_epi64 permute 128-bit lanes using 16x _mm512_shuffle_i32x4 permute 256-bit lanes using again 16x _mm512_shuffle_i32x4 ``` After the first stage, they got transposed to ``` 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 32 48 33 49 ... 34 50 35 51 ... 64 80 65 81 ... ... ``` After the second stage, they got transposed to ``` 0 16 32 48 ... 1 17 33 49 ... 2 18 34 50 ... 3 19 35 51 ... 64 80 96 112 ... 65 81 97 113 ... 66 82 98 114 ... 67 83 99 115 ... ... ``` After the third stage, they got transposed to ``` 0 16 32 48 8 24 40 56 64 80 96 112 ... 1 17 33 49 ... 2 18 34 50 ... 3 19 35 51 ... 4 20 36 52 ... 5 21 37 53 ... 6 22 38 54 ... 7 23 39 55 ... 128 144 160 176 ... 129 145 161 177 ... ... ``` After the last stage, they got transposed to ``` 0 16 32 48 64 80 96 112 ... 240 1 17 33 49 65 81 97 113 ... 241 2 18 34 50 66 82 98 114 ... 242 ... 15 31 47 63 79 95 111 127 ... 255 ``` Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D148685
1 parent 1746c78 commit 8d163e5

File tree

6 files changed

+398
-26
lines changed

6 files changed

+398
-26
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td

+7-4
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,17 @@ def VectorTransposeLowering_Elementwise:
1818
// intrinsics.
1919
def VectorTransposeLowering_FlatTranspose:
2020
I32EnumAttrCase<"Flat", 1, "flat_transpose">;
21-
// Lower 2-D transpose to `vector.shuffle`.
22-
def VectorTransposeLowering_Shuffle:
23-
I32EnumAttrCase<"Shuffle", 2, "shuffle">;
21+
// Lower 2-D transpose to `vector.shuffle` on 1-D vector.
22+
def VectorTransposeLowering_Shuffle1D:
23+
I32EnumAttrCase<"Shuffle1D", 2, "shuffle_1d">;
24+
// Lower 2-D transpose to `vector.shuffle` on 16x16 vector.
25+
def VectorTransposeLowering_Shuffle16x16:
26+
I32EnumAttrCase<"Shuffle16x16", 3, "shuffle_16x16">;
2427
def VectorTransposeLoweringAttr : I32EnumAttr<
2528
"VectorTransposeLowering",
2629
"control the lowering of `vector.transpose` operations.",
2730
[VectorTransposeLowering_Elementwise, VectorTransposeLowering_FlatTranspose,
28-
VectorTransposeLowering_Shuffle]> {
31+
VectorTransposeLowering_Shuffle1D, VectorTransposeLowering_Shuffle16x16]> {
2932
let cppNamespace = "::mlir::vector";
3033
}
3134

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

+272-19
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,253 @@ static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
5252
result.append(transpose.begin(), transpose.begin() + numTransposedDims);
5353
}
5454

55+
/// Returns true if the lowering option is a vector shuffle based approach.
56+
static bool isShuffleLike(VectorTransposeLowering lowering) {
57+
return lowering == VectorTransposeLowering::Shuffle1D ||
58+
lowering == VectorTransposeLowering::Shuffle16x16;
59+
}
60+
61+
/// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
62+
/// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
63+
/// create the mask for `numBits` bits vector. The `numBits` have to be a
64+
/// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
65+
/// 512, there should be 16 elements in the final result. It constructs the
66+
/// below mask to get the unpack elements.
67+
/// [0, 1, 16, 17,
68+
/// 0+4, 1+4, 16+4, 17+4,
69+
/// 0+8, 1+8, 16+8, 17+8,
70+
/// 0+12, 1+12, 16+12, 17+12]
71+
static SmallVector<int64_t>
72+
getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
73+
assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
74+
int numElem = numBits / 32;
75+
SmallVector<int64_t> res;
76+
for (int i = 0; i < numElem; i += 4)
77+
for (int64_t v : vals)
78+
res.push_back(v + i);
79+
return res;
80+
}
81+
82+
/// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
83+
/// example, if it is targeting 512 bit vector, returns
84+
/// vector.shuffle on v1, v2, [0, 1, 16, 17,
85+
/// 0+4, 1+4, 16+4, 17+4,
86+
/// 0+8, 1+8, 16+8, 17+8,
87+
/// 0+12, 1+12, 16+12, 17+12].
88+
static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
89+
int numBits) {
90+
int numElem = numBits / 32;
91+
return b.create<vector::ShuffleOp>(
92+
v1, v2,
93+
getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
94+
}
95+
96+
/// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
97+
/// example, if it is targeting 512 bit vector, returns
98+
/// vector.shuffle, v1, v2, [2, 3, 18, 19,
99+
/// 2+4, 3+4, 18+4, 19+4,
100+
/// 2+8, 3+8, 18+8, 19+8,
101+
/// 2+12, 3+12, 18+12, 19+12].
102+
static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
103+
int numBits) {
104+
int numElem = numBits / 32;
105+
return b.create<vector::ShuffleOp>(
106+
v1, v2,
107+
getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
108+
numBits));
109+
}
110+
111+
/// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
112+
/// example, if it is targeting 512 bit vector, returns
113+
/// vector.shuffle, v1, v2, [0, 16, 1, 17,
114+
/// 0+4, 16+4, 1+4, 17+4,
115+
/// 0+8, 16+8, 1+8, 17+8,
116+
/// 0+12, 16+12, 1+12, 17+12].
117+
static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
118+
int numBits) {
119+
int numElem = numBits / 32;
120+
auto shuffle = b.create<vector::ShuffleOp>(
121+
v1, v2,
122+
getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
123+
return shuffle;
124+
}
125+
126+
/// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
127+
/// example, if it is targeting 512 bit vector, returns
128+
/// vector.shuffle, v1, v2, [2, 18, 3, 19,
129+
/// 2+4, 18+4, 3+4, 19+4,
130+
/// 2+8, 18+8, 3+8, 19+8,
131+
/// 2+12, 18+12, 3+12, 19+12].
132+
static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
133+
int numBits) {
134+
int numElem = numBits / 32;
135+
return b.create<vector::ShuffleOp>(
136+
v1, v2,
137+
getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
138+
numBits));
139+
}
140+
141+
/// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
142+
/// elements) selected by `mask` from `v1` and `v2`. I.e.,
143+
///
144+
/// DEFINE SELECT4(src, control) {
145+
/// CASE(control[1:0]) OF
146+
/// 0: tmp[127:0] := src[127:0]
147+
/// 1: tmp[127:0] := src[255:128]
148+
/// 2: tmp[127:0] := src[383:256]
149+
/// 3: tmp[127:0] := src[511:384]
150+
/// ESAC
151+
/// RETURN tmp[127:0]
152+
/// }
153+
/// dst[127:0] := SELECT4(v1[511:0], mask[1:0])
154+
/// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
155+
/// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
156+
/// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
157+
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
158+
uint8_t mask) {
159+
assert(v1.getType().cast<VectorType>().getShape()[0] == 16 &&
160+
"expected a vector with length=16");
161+
SmallVector<int64_t> shuffleMask;
162+
auto appendToMask = [&](int64_t base, uint8_t control) {
163+
switch (control) {
164+
case 0:
165+
llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
166+
base + 2, base + 3});
167+
break;
168+
case 1:
169+
llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
170+
base + 6, base + 7});
171+
break;
172+
case 2:
173+
llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
174+
base + 10, base + 11});
175+
break;
176+
case 3:
177+
llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
178+
base + 14, base + 15});
179+
break;
180+
default:
181+
llvm_unreachable("control > 3 : overflow");
182+
}
183+
};
184+
uint8_t b01 = mask & 0x3;
185+
uint8_t b23 = (mask >> 2) & 0x3;
186+
uint8_t b45 = (mask >> 4) & 0x3;
187+
uint8_t b67 = (mask >> 6) & 0x3;
188+
appendToMask(0, b01);
189+
appendToMask(0, b23);
190+
appendToMask(16, b45);
191+
appendToMask(16, b67);
192+
return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
193+
}
194+
195+
/// Lowers the value to a vector.shuffle op. The `source` is expected to be a
196+
/// 1-D vector and have `m`x`n` elements.
197+
static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
198+
SmallVector<int64_t> mask;
199+
mask.reserve(m * n);
200+
for (int64_t j = 0; j < n; ++j)
201+
for (int64_t i = 0; i < m; ++i)
202+
mask.push_back(i * n + j);
203+
return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
204+
}
205+
206+
/// Lowers the value to a sequence of vector.shuffle ops. The `source` is
207+
/// expected to be a 16x16 vector.
208+
static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
209+
int n) {
210+
ImplicitLocOpBuilder b(source.getLoc(), builder);
211+
SmallVector<Value> vs;
212+
for (int64_t i = 0; i < m; ++i)
213+
vs.push_back(b.create<vector::ExtractOp>(source, i));
214+
215+
// Interleave 32-bit lanes using
216+
// 8x _mm512_unpacklo_epi32
217+
// 8x _mm512_unpackhi_epi32
218+
Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
219+
Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
220+
Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
221+
Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
222+
Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
223+
Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
224+
Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
225+
Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
226+
Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
227+
Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
228+
Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
229+
Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
230+
Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
231+
Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
232+
Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
233+
Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
234+
235+
// Interleave 64-bit lanes using
236+
// 8x _mm512_unpacklo_epi64
237+
// 8x _mm512_unpackhi_epi64
238+
Value r0 = createUnpackLoPd(b, t0, t2, 512);
239+
Value r1 = createUnpackHiPd(b, t0, t2, 512);
240+
Value r2 = createUnpackLoPd(b, t1, t3, 512);
241+
Value r3 = createUnpackHiPd(b, t1, t3, 512);
242+
Value r4 = createUnpackLoPd(b, t4, t6, 512);
243+
Value r5 = createUnpackHiPd(b, t4, t6, 512);
244+
Value r6 = createUnpackLoPd(b, t5, t7, 512);
245+
Value r7 = createUnpackHiPd(b, t5, t7, 512);
246+
Value r8 = createUnpackLoPd(b, t8, ta, 512);
247+
Value r9 = createUnpackHiPd(b, t8, ta, 512);
248+
Value ra = createUnpackLoPd(b, t9, tb, 512);
249+
Value rb = createUnpackHiPd(b, t9, tb, 512);
250+
Value rc = createUnpackLoPd(b, tc, te, 512);
251+
Value rd = createUnpackHiPd(b, tc, te, 512);
252+
Value re = createUnpackLoPd(b, td, tf, 512);
253+
Value rf = createUnpackHiPd(b, td, tf, 512);
254+
255+
// Permute 128-bit lanes using
256+
// 16x _mm512_shuffle_i32x4
257+
t0 = create4x128BitSuffle(b, r0, r4, 0x88);
258+
t1 = create4x128BitSuffle(b, r1, r5, 0x88);
259+
t2 = create4x128BitSuffle(b, r2, r6, 0x88);
260+
t3 = create4x128BitSuffle(b, r3, r7, 0x88);
261+
t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
262+
t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
263+
t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
264+
t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
265+
t8 = create4x128BitSuffle(b, r8, rc, 0x88);
266+
t9 = create4x128BitSuffle(b, r9, rd, 0x88);
267+
ta = create4x128BitSuffle(b, ra, re, 0x88);
268+
tb = create4x128BitSuffle(b, rb, rf, 0x88);
269+
tc = create4x128BitSuffle(b, r8, rc, 0xdd);
270+
td = create4x128BitSuffle(b, r9, rd, 0xdd);
271+
te = create4x128BitSuffle(b, ra, re, 0xdd);
272+
tf = create4x128BitSuffle(b, rb, rf, 0xdd);
273+
274+
// Permute 256-bit lanes using again
275+
// 16x _mm512_shuffle_i32x4
276+
vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
277+
vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
278+
vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
279+
vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
280+
vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
281+
vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
282+
vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
283+
vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
284+
vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
285+
vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
286+
vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
287+
vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
288+
vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
289+
vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
290+
vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
291+
vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
292+
293+
auto reshInputType = VectorType::get(
294+
{m, n}, source.getType().cast<VectorType>().getElementType());
295+
Value res =
296+
b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
297+
for (int64_t i = 0; i < m; ++i)
298+
res = b.create<vector::InsertOp>(vs[i], res, i);
299+
return res;
300+
}
301+
55302
namespace {
56303
/// Progressive lowering of TransposeOp.
57304
/// One:
@@ -84,8 +331,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
84331
for (auto attr : op.getTransp())
85332
transp.push_back(attr.cast<IntegerAttr>().getInt());
86333

87-
if (vectorTransformOptions.vectorTransposeLowering ==
88-
vector::VectorTransposeLowering::Shuffle &&
334+
if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
89335
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
90336
return rewriter.notifyMatchFailure(
91337
op, "Options specifies lowering to shuffle");
@@ -145,10 +391,13 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
145391
vector::VectorTransformsOptions vectorTransformOptions;
146392
};
147393

148-
/// Rewrite a 2-D vector.transpose as a sequence of:
394+
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
395+
/// If the strategy is Shuffle1D, it will be lowered to:
149396
/// vector.shape_cast 2D -> 1D
150397
/// vector.shuffle
151398
/// vector.shape_cast 1D -> 2D
399+
/// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
400+
/// ops on 16xf32 vectors.
152401
class TransposeOp2DToShuffleLowering
153402
: public OpRewritePattern<vector::TransposeOp> {
154403
public:
@@ -174,24 +423,28 @@ class TransposeOp2DToShuffleLowering
174423
if (transp[0] != 1 && transp[1] != 0)
175424
return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
176425

177-
if (vectorTransformOptions.vectorTransposeLowering !=
178-
VectorTransposeLowering::Shuffle)
179-
return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
180-
426+
Value res;
181427
int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
182-
Value casted = rewriter.create<vector::ShapeCastOp>(
183-
loc, VectorType::get({m * n}, srcType.getElementType()),
184-
op.getVector());
185-
SmallVector<int64_t> mask;
186-
mask.reserve(m * n);
187-
for (int64_t j = 0; j < n; ++j)
188-
for (int64_t i = 0; i < m; ++i)
189-
mask.push_back(i * n + j);
190-
191-
Value shuffled =
192-
rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
428+
switch (vectorTransformOptions.vectorTransposeLowering) {
429+
case VectorTransposeLowering::Shuffle1D: {
430+
Value casted = rewriter.create<vector::ShapeCastOp>(
431+
loc, VectorType::get({m * n}, srcType.getElementType()),
432+
op.getVector());
433+
res = transposeToShuffle1D(rewriter, casted, m, n);
434+
break;
435+
}
436+
case VectorTransposeLowering::Shuffle16x16:
437+
if (m != 16 || n != 16)
438+
return failure();
439+
res = transposeToShuffle16x16(rewriter, op.getVector(), m, n);
440+
break;
441+
case VectorTransposeLowering::EltWise:
442+
case VectorTransposeLowering::Flat:
443+
return failure();
444+
}
445+
193446
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
194-
op, op.getResultVectorType(), shuffled);
447+
op, op.getResultVectorType(), res);
195448

196449
return success();
197450
}

mlir/test/Dialect/LLVM/transform-e2e.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,6 @@ transform.sequence failures(propagate) {
5454
: (!pdl.operation) -> !pdl.operation
5555

5656
%func_8 = transform.vector.lower_transpose %func_7
57-
lowering_strategy = "shuffle"
57+
lowering_strategy = "shuffle_1d"
5858
: (!pdl.operation) -> !pdl.operation
5959
}

mlir/test/Dialect/Vector/transform-vector.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,6 @@ transform.sequence failures(propagate) {
5757
: (!pdl.operation) -> !pdl.operation
5858

5959
%func_8 = transform.vector.lower_transpose %func_7
60-
lowering_strategy = "shuffle"
60+
lowering_strategy = "shuffle_1d"
6161
: (!pdl.operation) -> !pdl.operation
6262
}

0 commit comments

Comments
 (0)