Skip to content

Commit f05d4af

Browse files
committed
mir-opt: Merge all branch BBs into a single copy statement for enum
1 parent 4fccb1d commit f05d4af

14 files changed

+563
-78
lines changed

compiler/rustc_mir_transform/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ declare_passes! {
157157
mod lower_intrinsics : LowerIntrinsics;
158158
mod lower_slice_len : LowerSliceLenCalls;
159159
mod match_branches : MatchBranchSimplification;
160+
mod merge_branches: MergeBranchSimplification;
160161
mod mentioned_items : MentionedItems;
161162
mod multiple_return_terminators : MultipleReturnTerminators;
162163
mod nrvo : RenameReturnPlace;
@@ -707,6 +708,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
707708
&dead_store_elimination::DeadStoreElimination::Initial,
708709
&gvn::GVN,
709710
&simplify::SimplifyLocals::AfterGVN,
711+
&merge_branches::MergeBranchSimplification,
710712
&dataflow_const_prop::DataflowConstProp,
711713
&single_use_consts::SingleUseConsts,
712714
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
//! This pass attempts to merge all branches to eliminate switch terminator.
2+
//! Ideally, we could combine it with `MatchBranchSimplification`, as these two passes
3+
//! match and merge statements with different patterns. Given the compile time and
4+
//! code complexity, we have not merged them into a more general pass for now.
5+
use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
6+
use rustc_index::bit_set::DenseBitSet;
7+
use rustc_middle::mir::*;
8+
use rustc_middle::ty::util::Discr;
9+
use rustc_middle::ty::{self, TyCtxt};
10+
use rustc_mir_dataflow::impls::borrowed_locals;
11+
12+
use crate::dead_store_elimination::DeadStoreAnalysis;
13+
use crate::patch::MirPatch;
14+
15+
/// MIR pass that looks for a `switchInt` on an enum discriminant whose target blocks
/// all just rebuild the matched value into the same destination, and replaces the
/// switch with a single `dest = copy src` statement followed by the terminator that
/// all target blocks share.
pub(super) struct MergeBranchSimplification;
16+
17+
impl<'tcx> crate::MirPass<'tcx> for MergeBranchSimplification {
18+
// This pass reports itself as not required.
fn is_required(&self) -> bool {
19+
false
20+
}
21+
22+
// Only enabled when the session's MIR opt level is at least 2.
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
23+
sess.mir_opt_level() >= 2
24+
}
25+
26+
// Scans the basic blocks for the first switch that can be merged, rewrites it,
// removes the blocks that became dead, and then stops.
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
27+
let typing_env = body.typing_env(tcx);
28+
let borrowed_locals = borrowed_locals(body);
29+
let mut dead_store_analysis = DeadStoreAnalysis::new(tcx, body, &borrowed_locals);
30+
31+
for switch_bb_idx in body.basic_blocks.indices() {
32+
let bbs = &*body.basic_blocks;
33+
// Only blocks that end in a switch terminator are candidates.
let Some((switch_discr, targets)) = bbs[switch_bb_idx].terminator().kind.as_switch()
34+
else {
35+
continue;
36+
};
37+
// Check that every target block ends in the same terminator kind as the first
// target, and if not, then don't optimize this block.
38+
let mut targets_iter = targets.iter();
39+
let first_terminator_kind = &bbs[targets_iter.next().unwrap().1].terminator().kind;
40+
if targets_iter.any(|(_, other_target)| {
41+
first_terminator_kind != &bbs[other_target].terminator().kind
42+
}) {
43+
continue;
44+
}
45+
// We require that the possible target blocks all be distinct.
46+
if !targets.is_distinct() {
47+
continue;
48+
}
49+
// The `otherwise` target must be an empty unreachable block, so only the
// explicitly listed switch values have to be considered below.
if !bbs[targets.otherwise()].is_empty_unreachable() {
50+
continue;
51+
}
52+
// Check if the copy source matches the following pattern:
53+
// _2 = discriminant(*_1); // "*_1" is the expected copy source.
54+
// switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
55+
let Some(&Statement {
56+
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(src_place))),
57+
..
58+
}) = bbs[switch_bb_idx].statements.last()
59+
else {
60+
continue;
61+
};
62+
// The switch must operate on exactly the place the discriminant was stored in.
if switch_discr.place() != Some(discr_place) {
63+
continue;
64+
}
65+
let src_ty = src_place.ty(body.local_decls(), tcx);
66+
if let Some(dest_place) = can_simplify_to_copy(
67+
tcx,
68+
typing_env,
69+
body,
70+
targets,
71+
src_place,
72+
src_ty,
73+
&mut dead_store_analysis,
74+
) {
75+
// `statements.len()` is the position just past the last statement of the
// switch block, i.e. where its terminator sits.
let statement_index = bbs[switch_bb_idx].statements.len();
76+
let parent_end = Location { block: switch_bb_idx, statement_index };
77+
let mut patch = MirPatch::new(body);
78+
// Add `dest = copy src` at the end of the switch block and replace the switch
// with the terminator shared by all of its targets.
patch.add_assign(parent_end, dest_place, Rvalue::Use(Operand::Copy(src_place)));
79+
patch.patch_terminator(switch_bb_idx, first_terminator_kind.clone());
80+
patch.apply(body);
81+
super::simplify::remove_dead_blocks(body);
82+
// After modifying the MIR, the result of `MaybeTransitiveLiveLocals` may become
83+
// invalid, so to keep things simple we rewrite at most one switch per run.
84+
break;
85+
}
86+
}
87+
}
88+
}
89+
90+
/// GVN simplifies
91+
/// ```ignore (syntax-highlighting-only)
92+
/// match a {
93+
/// Foo::A(x) => Foo::A(*x),
94+
/// Foo::B => Foo::B
95+
/// }
96+
/// ```
97+
/// to
98+
/// ```ignore (syntax-highlighting-only)
99+
/// match a {
100+
/// Foo::A(_x) => a, // copy a
101+
/// Foo::B => Foo::B
102+
/// }
103+
/// ```
104+
/// This function answers whether all switch targets can be replaced by one copy
105+
/// statement, by returning the copy destination.
///
/// `targets` are the targets of the switch on the discriminant of `src_place`, and
/// `src_ty` is the type of `src_place`. Returns `None` as soon as any target block
/// has no suitable assignment (see `find_copy_assign`), when the targets assign to
/// different places, or when the destination type differs from `src_ty`.
106+
fn can_simplify_to_copy<'tcx>(
107+
tcx: TyCtxt<'tcx>,
108+
typing_env: ty::TypingEnv<'tcx>,
109+
body: &Body<'tcx>,
110+
targets: &SwitchTargets,
111+
src_place: Place<'tcx>,
112+
src_ty: PlaceTy<'tcx>,
113+
dead_store_analysis: &mut DeadStoreAnalysis<'tcx, '_, '_>,
114+
) -> Option<Place<'tcx>> {
115+
let mut targets_iter = targets.iter();
116+
// The first target fixes the destination that every other target must agree with.
let (first_index, first_target) = targets_iter.next()?;
117+
let dest_place = find_copy_assign(
118+
tcx,
119+
typing_env,
120+
body,
121+
first_index,
122+
first_target,
123+
src_place,
124+
src_ty,
125+
dead_store_analysis,
126+
)?;
127+
let dest_ty = dest_place.ty(body.local_decls(), tcx);
128+
if dest_ty.ty != src_ty.ty {
129+
return None;
130+
}
131+
// Every remaining target must assign to the very same destination place.
for (other_index, other_target) in targets_iter {
132+
if dest_place
133+
!= find_copy_assign(
134+
tcx,
135+
typing_env,
136+
body,
137+
other_index,
138+
other_target,
139+
src_place,
140+
src_ty,
141+
dead_store_analysis,
142+
)?
143+
{
144+
return None;
145+
}
146+
}
147+
Some(dest_place)
148+
}
149+
150+
// Find the single assignment statement in `target_block` that is equivalent to copying
151+
// `src_place`. All other statements must be dead stores or paired storage markers.
//
// `index` is the switch value that leads to `target_block`. Returns the destination of
// that assignment, or `None` if the block does not have this shape.
152+
fn find_copy_assign<'tcx>(
153+
tcx: TyCtxt<'tcx>,
154+
typing_env: ty::TypingEnv<'tcx>,
155+
body: &Body<'tcx>,
156+
index: u128,
157+
target_block: BasicBlock,
158+
src_place: Place<'tcx>,
159+
src_ty: PlaceTy<'tcx>,
160+
dead_store_analysis: &mut DeadStoreAnalysis<'tcx, '_, '_>,
161+
) -> Option<Place<'tcx>> {
162+
let statements = &body.basic_blocks[target_block].statements;
163+
// A block without statements has no assignment to match.
if statements.is_empty() {
164+
return None;
165+
}
166+
// With exactly one statement, that statement is the candidate; otherwise search for it.
let assign_stmt = if statements.len() == 1 {
167+
0
168+
} else {
169+
// Starts with every statement index set; indices are removed as statements are
// accounted for (dead stores and the candidate assignment).
let mut lived_stmts: DenseBitSet<usize> = DenseBitSet::new_filled(statements.len());
170+
let mut expected_assign_stmt = None;
171+
// Statements are visited from last to first.
for (statement_index, statement) in statements.iter().enumerate().rev() {
172+
let loc = Location { block: target_block, statement_index };
173+
if dead_store_analysis.is_dead_store(loc, &statement.kind) {
174+
lived_stmts.remove(statement_index);
175+
} else if matches!(
176+
statement.kind,
177+
StatementKind::StorageLive(_) | StatementKind::StorageDead(_)
178+
) {
// Storage markers stay in `lived_stmts`; their pairing is checked below.
179+
} else if matches!(statement.kind, StatementKind::Assign(_))
180+
&& expected_assign_stmt.is_none()
181+
{
182+
// There is only one assign statement that cannot be ignored
183+
// that can be used as an expected copy statement.
184+
expected_assign_stmt = Some(statement_index);
185+
lived_stmts.remove(statement_index);
186+
} else {
// A second non-dead assignment, or any other kind of statement.
187+
return None;
188+
}
189+
}
190+
let expected_assign = expected_assign_stmt?;
191+
if !lived_stmts.is_empty() {
192+
// We can ignore the paired StorageLive and StorageDead.
// NOTE(review): the guards below rely on `insert`/`remove` reporting whether the
// set changed - confirm against `DenseBitSet`.
193+
let mut storage_live_locals: DenseBitSet<Local> =
194+
DenseBitSet::new_empty(body.local_decls.len());
195+
for stmt_index in lived_stmts.iter() {
196+
let statement = &statements[stmt_index];
197+
match &statement.kind {
198+
StatementKind::StorageLive(local) if storage_live_locals.insert(*local) => {}
199+
StatementKind::StorageDead(local) if storage_live_locals.remove(*local) => {}
200+
_ => return None,
201+
}
202+
}
203+
// Any local still in the set had a StorageLive without a matching StorageDead.
if !storage_live_locals.is_empty() {
204+
return None;
205+
}
206+
}
207+
expected_assign
208+
};
209+
let &(dest_place, ref rvalue) = statements[assign_stmt].kind.as_assign()?;
210+
// The destination must have the same type as the copy source, and that type must be
// an ADT.
let dest_ty = dest_place.ty(body.local_decls(), tcx);
211+
if dest_ty.ty != src_ty.ty {
212+
return None;
213+
}
214+
let ty::Adt(def, _) = dest_ty.ty.kind() else {
215+
return None;
216+
};
217+
match rvalue {
218+
// Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
// The constant's variant must have no fields and its discriminant value must equal
// the switch value `index` that leads to this block.
219+
Rvalue::Use(Operand::Constant(box constant))
220+
if let Const::Val(const_, ty) = constant.const_ =>
221+
{
222+
let (ecx, op) =
223+
mk_eval_cx_for_const_val(tcx.at(constant.span), typing_env, const_, ty)?;
224+
let variant = ecx.read_discriminant(&op).discard_err()?;
225+
if !def.variants()[variant].fields.is_empty() {
226+
return None;
227+
}
228+
let Discr { val, .. } = ty.discriminant_for_variant(tcx, variant)?;
229+
if val != index {
230+
return None;
231+
}
232+
}
233+
// The statement already is a copy of the source place.
Rvalue::Use(Operand::Copy(place)) if *place == src_place => {}
234+
// Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
// The aggregate must have no fields and its variant's discriminant value must equal
// `index`.
235+
Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
236+
if fields.is_empty()
237+
&& let Some(Discr { val, .. }) =
238+
src_ty.ty.discriminant_for_variant(tcx, *variant_index)
239+
&& val == index => {}
240+
_ => return None,
241+
}
242+
Some(dest_place)
243+
}

tests/codegen/match-optimizes-away.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
//@ compile-flags: -Copt-level=3 -Zmerge-functions=disabled
1+
//@ compile-flags: -Copt-level=3 -Cno-prepopulate-passes
2+
23
#![crate_type = "lib"]
34

45
pub enum Three {

tests/codegen/try_question_mark_nop.rs

+6-22
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,9 @@ use std::ptr::NonNull;
1616
#[no_mangle]
1717
pub fn option_nop_match_32(x: Option<u32>) -> Option<u32> {
1818
// CHECK: start:
19-
// CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i32 %0 to i1
20-
21-
// NINETEEN-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i32 %0, i32 0
22-
// NINETEEN-NEXT: [[REG2:%.*]] = insertvalue { i32, i32 } poison, i32 [[SELECT]], 0
23-
// NINETEEN-NEXT: [[REG3:%.*]] = insertvalue { i32, i32 } [[REG2]], i32 %1, 1
24-
25-
// TWENTY-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i32 %1, i32 undef
26-
// TWENTY-NEXT: [[REG2:%.*]] = insertvalue { i32, i32 } poison, i32 %0, 0
27-
// TWENTY-NEXT: [[REG3:%.*]] = insertvalue { i32, i32 } [[REG2]], i32 [[SELECT]], 1
28-
29-
// CHECK-NEXT: ret { i32, i32 } [[REG3]]
19+
// CHECK-NEXT: [[REG0:%.*]] = insertvalue { i32, i32 } poison, i32 %x.0, 0
20+
// CHECK-NEXT: [[REG1:%.*]] = insertvalue { i32, i32 } [[REG0]], i32 %x.1, 1
21+
// CHECK-NEXT: ret { i32, i32 } [[REG1]]
3022
match x {
3123
Some(x) => Some(x),
3224
None => None,
@@ -95,17 +87,9 @@ pub fn control_flow_nop_traits_32(x: ControlFlow<i32, u32>) -> ControlFlow<i32,
9587
#[no_mangle]
9688
pub fn option_nop_match_64(x: Option<u64>) -> Option<u64> {
9789
// CHECK: start:
98-
// CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i64 %0 to i1
99-
100-
// NINETEEN-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i64 %0, i64 0
101-
// NINETEEN-NEXT: [[REG2:%.*]] = insertvalue { i64, i64 } poison, i64 [[SELECT]], 0
102-
// NINETEEN-NEXT: [[REG3:%.*]] = insertvalue { i64, i64 } [[REG2]], i64 %1, 1
103-
104-
// TWENTY-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i64 %1, i64 undef
105-
// TWENTY-NEXT: [[REG2:%.*]] = insertvalue { i64, i64 } poison, i64 %0, 0
106-
// TWENTY-NEXT: [[REG3:%.*]] = insertvalue { i64, i64 } [[REG2]], i64 [[SELECT]], 1
107-
108-
// CHECK-NEXT: ret { i64, i64 } [[REG3]]
90+
// CHECK-NEXT: [[REG0:%.*]] = insertvalue { i64, i64 } poison, i64 %x.0, 0
91+
// CHECK-NEXT: [[REG1:%.*]] = insertvalue { i64, i64 } [[REG0]], i64 %x.1, 1
92+
// CHECK-NEXT: ret { i64, i64 } [[REG1]]
10993
match x {
11094
Some(x) => Some(x),
11195
None => None,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
- // MIR for `no_fields` before MergeBranchSimplification
2+
+ // MIR for `no_fields` after MergeBranchSimplification
3+
4+
fn no_fields(_1: NoFields) -> NoFields {
5+
debug a => _1;
6+
let mut _0: NoFields;
7+
let mut _2: isize;
8+
9+
bb0: {
10+
_2 = discriminant(_1);
11+
- switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
12+
+ _0 = copy _1;
13+
+ goto -> bb1;
14+
}
15+
16+
bb1: {
17+
- unreachable;
18+
- }
19+
-
20+
- bb2: {
21+
- _0 = NoFields::B;
22+
- goto -> bb4;
23+
- }
24+
-
25+
- bb3: {
26+
- _0 = NoFields::A;
27+
- goto -> bb4;
28+
- }
29+
-
30+
- bb4: {
31+
return;
32+
}
33+
}
34+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
- // MIR for `no_fields_failed` before MergeBranchSimplification
2+
+ // MIR for `no_fields_failed` after MergeBranchSimplification
3+
4+
fn no_fields_failed(_1: NoFields) -> NoFields {
5+
debug a => _1;
6+
let mut _0: NoFields;
7+
let mut _2: isize;
8+
9+
bb0: {
10+
_2 = discriminant(_1);
11+
switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
12+
}
13+
14+
bb1: {
15+
unreachable;
16+
}
17+
18+
bb2: {
19+
_0 = NoFields::A;
20+
goto -> bb4;
21+
}
22+
23+
bb3: {
24+
_0 = NoFields::B;
25+
goto -> bb4;
26+
}
27+
28+
bb4: {
29+
return;
30+
}
31+
}
32+

0 commit comments

Comments
 (0)