Skip to content

Commit aa2376a

Browse files
authored
[mlir][vector] Notify the rewriter when sinking out of warp ops (#71964)
A number of the warp distribution patterns rewrite a warp op in place by moving a contained op outside of it. This commit notifies the rewriter that the warp op is changing in those cases.
1 parent 2e912a2 commit aa2376a

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

+46-2
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,11 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
645645
});
646646
if (!yieldOperand)
647647
return failure();
648+
649+
// Notify the rewriter that the warp op is changing (see the comment on
650+
// the WarpOpTransferRead pattern).
651+
rewriter.startRootUpdate(warpOp);
652+
648653
Operation *elementWise = yieldOperand->get().getDefiningOp();
649654
unsigned operandIndex = yieldOperand->getOperandNumber();
650655
Value distributedVal = warpOp.getResult(operandIndex);
@@ -683,6 +688,7 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
683688
{newWarpOp.getResult(operandIndex).getType()});
684689
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
685690
newOp->getResult(0));
691+
rewriter.finalizeRootUpdate(warpOp);
686692
return success();
687693
}
688694
};
@@ -713,6 +719,9 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
713719
auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
714720
if (!dense)
715721
return failure();
722+
// Notify the rewriter that the warp op is changing (see the comment on
723+
// the WarpOpTransferRead pattern).
724+
rewriter.startRootUpdate(warpOp);
716725
unsigned operandIndex = yieldOperand->getOperandNumber();
717726
Attribute scalarAttr = dense.getSplatValue<Attribute>();
718727
auto newAttr = DenseElementsAttr::get(
@@ -721,6 +730,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
721730
rewriter.setInsertionPointAfter(warpOp);
722731
Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
723732
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
733+
rewriter.finalizeRootUpdate(warpOp);
724734
return success();
725735
}
726736
};
@@ -823,7 +833,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
823833
OpBuilder::InsertionGuard g(rewriter);
824834
WarpExecuteOnLane0Op newWarpOp = warpOp;
825835
Value newMask = read.getMask();
836+
bool hasMask = false;
826837
if (read.getMask()) {
838+
hasMask = true;
827839
// TODO: Distribution of masked reads with non-trivial permutation maps
828840
// requires the distribution of the mask to elementwise match the
829841
// distribution of the permuted written vector. Currently the details
@@ -840,6 +852,16 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
840852
newRetIndices);
841853
newMask = newWarpOp.getResult(newRetIndices[0]);
842854
distributedVal = newWarpOp.getResult(operandIndex);
855+
} else {
856+
// This pattern does not actually change the warp op directly. Instead it
857+
// just rewrites a new transfer read (when not masked) outside of the warp
858+
// op and replaces the corresponding result. There are then follow-up
859+
// patterns to erase now dead results of the warp op. This erasure allows
860+
// propagation to continue, but this pattern on its own never actually
861+
// tells the pattern rewriter that the warp op "changed." Notify the
862+
// rewriter here that the warp op is changing. Similar situations are
863+
// noted in following patterns.
864+
rewriter.startRootUpdate(warpOp);
843865
}
844866

845867
rewriter.setInsertionPointAfter(newWarpOp);
@@ -849,9 +871,12 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
849871
SmallVector<Value> delinearizedIds;
850872
if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
851873
distributedType.getShape(), newWarpOp.getWarpSize(),
852-
newWarpOp.getLaneid(), delinearizedIds))
874+
newWarpOp.getLaneid(), delinearizedIds)) {
875+
if (!hasMask)
876+
rewriter.cancelRootUpdate(warpOp);
853877
return rewriter.notifyMatchFailure(
854878
read, "cannot delinearize lane ID for distribution");
879+
}
855880
assert(!delinearizedIds.empty() || map.getNumResults() == 0);
856881

857882
for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
@@ -890,10 +915,15 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
890915
if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
891916
return (newRead.getMask() && value == newRead.getMask()) ||
892917
newWarpOp.isDefinedOutsideOfRegion(value);
893-
}))
918+
})) {
919+
if (!hasMask)
920+
rewriter.cancelRootUpdate(warpOp);
894921
return failure();
922+
}
895923

896924
rewriter.replaceAllUsesWith(distributedVal, newRead);
925+
if (!hasMask)
926+
rewriter.finalizeRootUpdate(warpOp);
897927
return success();
898928
}
899929
};
@@ -996,7 +1026,11 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
9961026
}
9971027
if (!valForwarded)
9981028
return failure();
1029+
// Notify the rewriter that the warp op is changing (see the comment on
1030+
// the WarpOpTransferRead pattern).
1031+
rewriter.startRootUpdate(warpOp);
9991032
rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
1033+
rewriter.finalizeRootUpdate(warpOp);
10001034
return success();
10011035
}
10021036
};
@@ -1024,6 +1058,9 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
10241058
if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
10251059
vector::BroadcastableToResult::Success)
10261060
return failure();
1061+
// Notify the rewriter that the warp op is changing (see the comment on
1062+
// the WarpOpTransferRead pattern).
1063+
rewriter.startRootUpdate(warpOp);
10271064
SmallVector<size_t> newRetIndices;
10281065
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
10291066
rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
@@ -1032,6 +1069,7 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
10321069
loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
10331070
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
10341071
broadcasted);
1072+
rewriter.finalizeRootUpdate(warpOp);
10351073
return success();
10361074
}
10371075
};
@@ -1046,6 +1084,7 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
10461084
warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
10471085
if (!operand)
10481086
return failure();
1087+
10491088
auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
10501089

10511090
unsigned int operandNumber = operand->getOperandNumber();
@@ -1133,6 +1172,10 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
11331172
mask, "cannot delinearize lane ID for distribution");
11341173
assert(!delinearizedIds.empty());
11351174

1175+
// Notify the rewriter that the warp op is changing (see the comment on
1176+
// the WarpOpTransferRead pattern).
1177+
rewriter.startRootUpdate(warpOp);
1178+
11361179
AffineExpr s0, s1;
11371180
bindSymbols(rewriter.getContext(), s0, s1);
11381181
SmallVector<Value> newOperands;
@@ -1151,6 +1194,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
11511194
auto newMask =
11521195
rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
11531196
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1197+
rewriter.finalizeRootUpdate(warpOp);
11541198
return success();
11551199
}
11561200
};

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

+20
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,26 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
13511351

13521352
// -----
13531353

1354+
func.func @warp_propagate_masked_transfer_read_shared_mask(%laneid: index, %src: memref<4096x4096xf32>, %index: index, %index2: index, %mask_ub: index) -> (vector<2xf32>, vector<2xf32>) {
1355+
%f0 = arith.constant 0.000000e+00 : f32
1356+
%c0 = arith.constant 0 : index
1357+
%r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>, vector<2xf32>) {
1358+
%mask = vector.create_mask %mask_ub: vector<128xi1>
1359+
%0 = vector.transfer_read %src[%c0, %index], %f0, %mask {in_bounds = [true]} : memref<4096x4096xf32>, vector<128xf32>
1360+
%1 = vector.transfer_read %src[%c0, %index2], %f0, %mask {in_bounds = [true]} : memref<4096x4096xf32>, vector<128xf32>
1361+
vector.yield %0, %1 : vector<128xf32>, vector<128xf32>
1362+
}
1363+
return %r#0, %r#1 : vector<2xf32>, vector<2xf32>
1364+
}
1365+
1366+
// CHECK-PROP-LABEL: func.func @warp_propagate_masked_transfer_read_shared_mask
1367+
// CHECK-PROP: vector.create_mask %{{.*}} : vector<2xi1>
1368+
// CHECK-PROP: vector.transfer_read %{{.*}} : memref<4096x4096xf32>, vector<2xf32>
1369+
// CHECK-PROP: vector.create_mask %{{.*}} : vector<2xi1>
1370+
// CHECK-PROP: vector.transfer_read %{{.*}} : memref<4096x4096xf32>, vector<2xf32>
1371+
1372+
// -----
1373+
13541374
func.func @warp_propagate_unconnected_read_write(%laneid: index, %buffer: memref<128xf32>, %f1: f32) -> (vector<2xf32>, vector<4xf32>) {
13551375
%f0 = arith.constant 0.000000e+00 : f32
13561376
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)