@@ -645,6 +645,11 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
645
645
});
646
646
if (!yieldOperand)
647
647
return failure ();
648
+
649
+ // Notify the rewriter that the warp op is changing (see the comment on
650
+ // the WarpOpTransferRead pattern).
651
+ rewriter.startRootUpdate (warpOp);
652
+
648
653
Operation *elementWise = yieldOperand->get ().getDefiningOp ();
649
654
unsigned operandIndex = yieldOperand->getOperandNumber ();
650
655
Value distributedVal = warpOp.getResult (operandIndex);
@@ -683,6 +688,7 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
683
688
{newWarpOp.getResult (operandIndex).getType ()});
684
689
rewriter.replaceAllUsesWith (newWarpOp.getResult (operandIndex),
685
690
newOp->getResult (0 ));
691
+ rewriter.finalizeRootUpdate (warpOp);
686
692
return success ();
687
693
}
688
694
};
@@ -713,6 +719,9 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
713
719
auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue ());
714
720
if (!dense)
715
721
return failure ();
722
+ // Notify the rewriter that the warp op is changing (see the comment on
723
+ // the WarpOpTransferRead pattern).
724
+ rewriter.startRootUpdate (warpOp);
716
725
unsigned operandIndex = yieldOperand->getOperandNumber ();
717
726
Attribute scalarAttr = dense.getSplatValue <Attribute>();
718
727
auto newAttr = DenseElementsAttr::get (
@@ -721,6 +730,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
721
730
rewriter.setInsertionPointAfter (warpOp);
722
731
Value distConstant = rewriter.create <arith::ConstantOp>(loc, newAttr);
723
732
rewriter.replaceAllUsesWith (warpOp.getResult (operandIndex), distConstant);
733
+ rewriter.finalizeRootUpdate (warpOp);
724
734
return success ();
725
735
}
726
736
};
@@ -823,7 +833,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
823
833
OpBuilder::InsertionGuard g (rewriter);
824
834
WarpExecuteOnLane0Op newWarpOp = warpOp;
825
835
Value newMask = read .getMask ();
836
+ bool hasMask = false ;
826
837
if (read .getMask ()) {
838
+ hasMask = true ;
827
839
// TODO: Distribution of masked reads with non-trivial permutation maps
828
840
// requires the distribution of the mask to elementwise match the
829
841
// distribution of the permuted written vector. Currently the details
@@ -840,6 +852,16 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
840
852
newRetIndices);
841
853
newMask = newWarpOp.getResult (newRetIndices[0 ]);
842
854
distributedVal = newWarpOp.getResult (operandIndex);
855
+ } else {
856
+ // This pattern does not actually change the warp op directly. Instead it
857
+ // just rewrites a new transfer read (when not masked) outside of the warp
858
+ // op and replaces the correponding result. There are then follow up
859
+ // patterns to erase now dead results of the warp op. This erasure allows
860
+ // propagation to continue, but this pattern on its own never actually
861
+ // tells the pattern rewriter that the warp op "changed." Notify the
862
+ // rewriter here that the warp op is changing. Similar situations are
863
+ // noted in following patterns.
864
+ rewriter.startRootUpdate (warpOp);
843
865
}
844
866
845
867
rewriter.setInsertionPointAfter (newWarpOp);
@@ -849,9 +871,12 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
849
871
SmallVector<Value> delinearizedIds;
850
872
if (!delinearizeLaneId (rewriter, read .getLoc (), sequentialType.getShape (),
851
873
distributedType.getShape (), newWarpOp.getWarpSize (),
852
- newWarpOp.getLaneid (), delinearizedIds))
874
+ newWarpOp.getLaneid (), delinearizedIds)) {
875
+ if (!hasMask)
876
+ rewriter.cancelRootUpdate (warpOp);
853
877
return rewriter.notifyMatchFailure (
854
878
read , " cannot delinearize lane ID for distribution" );
879
+ }
855
880
assert (!delinearizedIds.empty () || map.getNumResults () == 0 );
856
881
857
882
for (auto it : llvm::zip_equal (indexMap.getResults (), map.getResults ())) {
@@ -890,10 +915,15 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
890
915
if (!llvm::all_of (newRead->getOperands (), [&](Value value) {
891
916
return (newRead.getMask () && value == newRead.getMask ()) ||
892
917
newWarpOp.isDefinedOutsideOfRegion (value);
893
- }))
918
+ })) {
919
+ if (!hasMask)
920
+ rewriter.cancelRootUpdate (warpOp);
894
921
return failure ();
922
+ }
895
923
896
924
rewriter.replaceAllUsesWith (distributedVal, newRead);
925
+ if (!hasMask)
926
+ rewriter.finalizeRootUpdate (warpOp);
897
927
return success ();
898
928
}
899
929
};
@@ -996,7 +1026,11 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
996
1026
}
997
1027
if (!valForwarded)
998
1028
return failure ();
1029
+ // Notify the rewriter that the warp op is changing (see the comment on
1030
+ // the WarpOpTransferRead pattern).
1031
+ rewriter.startRootUpdate (warpOp);
999
1032
rewriter.replaceAllUsesWith (warpOp.getResult (resultIndex), valForwarded);
1033
+ rewriter.finalizeRootUpdate (warpOp);
1000
1034
return success ();
1001
1035
}
1002
1036
};
@@ -1024,6 +1058,9 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
1024
1058
if (vector::isBroadcastableTo (broadcastSrcType, destVecType) !=
1025
1059
vector::BroadcastableToResult::Success)
1026
1060
return failure ();
1061
+ // Notify the rewriter that the warp op is changing (see the comment on
1062
+ // the WarpOpTransferRead pattern).
1063
+ rewriter.startRootUpdate (warpOp);
1027
1064
SmallVector<size_t > newRetIndices;
1028
1065
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1029
1066
rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
@@ -1032,6 +1069,7 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
1032
1069
loc, destVecType, newWarpOp->getResult (newRetIndices[0 ]));
1033
1070
rewriter.replaceAllUsesWith (newWarpOp->getResult (operandNumber),
1034
1071
broadcasted);
1072
+ rewriter.finalizeRootUpdate (warpOp);
1035
1073
return success ();
1036
1074
}
1037
1075
};
@@ -1046,6 +1084,7 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
1046
1084
warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
1047
1085
if (!operand)
1048
1086
return failure ();
1087
+
1049
1088
auto oldCastOp = operand->get ().getDefiningOp <vector::ShapeCastOp>();
1050
1089
1051
1090
unsigned int operandNumber = operand->getOperandNumber ();
@@ -1133,6 +1172,10 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
1133
1172
mask, " cannot delinearize lane ID for distribution" );
1134
1173
assert (!delinearizedIds.empty ());
1135
1174
1175
+ // Notify the rewriter that the warp op is changing (see the comment on
1176
+ // the WarpOpTransferRead pattern).
1177
+ rewriter.startRootUpdate (warpOp);
1178
+
1136
1179
AffineExpr s0, s1;
1137
1180
bindSymbols (rewriter.getContext (), s0, s1);
1138
1181
SmallVector<Value> newOperands;
@@ -1151,6 +1194,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
1151
1194
auto newMask =
1152
1195
rewriter.create <vector::CreateMaskOp>(loc, distType, newOperands);
1153
1196
rewriter.replaceAllUsesWith (warpOp.getResult (operandIndex), newMask);
1197
+ rewriter.finalizeRootUpdate (warpOp);
1154
1198
return success ();
1155
1199
}
1156
1200
};
0 commit comments