Skip to content

Commit 7bc4d25

Browse files
committed
Implement FusedIterator for gen block
1 parent 03994e4 commit 7bc4d25

File tree

13 files changed

+154
-1
lines changed

13 files changed

+154
-1
lines changed

compiler/rustc_hir/src/lang_items.rs

+1
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ language_item_table! {
214214
FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None;
215215

216216
Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0);
217+
FusedIterator, sym::fused_iterator, fused_iterator_trait, Target::Trait, GenericRequirement::Exact(0);
217218
Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0);
218219
AsyncIterator, sym::async_iterator, async_iterator_trait, Target::Trait, GenericRequirement::Exact(0);
219220

compiler/rustc_middle/src/traits/select.rs

+4
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ pub enum SelectionCandidate<'tcx> {
156156
/// generated for a `gen` construct.
157157
IteratorCandidate,
158158

159+
/// Implementation of an `FusedIterator` trait by one of the coroutine types
160+
/// generated for a `gen` construct.
161+
FusedIteratorCandidate,
162+
159163
/// Implementation of an `AsyncIterator` trait by one of the coroutine types
160164
/// generated for a `async gen` construct.
161165
AsyncIteratorCandidate,

compiler/rustc_middle/src/ty/instance.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,9 @@ impl<'tcx> Instance<'tcx> {
624624
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)
625625
);
626626
hir::LangItem::FuturePoll
627-
} else if Some(trait_id) == lang_items.iterator_trait() {
627+
} else if Some(trait_id) == lang_items.iterator_trait()
628+
|| Some(trait_id) == lang_items.fused_iterator_trait()
629+
{
628630
assert_matches!(
629631
coroutine_kind,
630632
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)

compiler/rustc_span/src/symbol.rs

+2
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ symbols! {
207207
FromResidual,
208208
FsOpenOptions,
209209
FsPermissions,
210+
FusedIterator,
210211
Future,
211212
FutureOutput,
212213
GlobalAlloc,
@@ -885,6 +886,7 @@ symbols! {
885886
fsub_algebraic,
886887
fsub_fast,
887888
fundamental,
889+
fused_iterator,
888890
future,
889891
future_trait,
890892
gdb_script_file,

compiler/rustc_trait_selection/src/solve/assembly/mod.rs

+10
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,14 @@ pub(super) trait GoalKind<'tcx>:
215215
goal: Goal<'tcx, Self>,
216216
) -> QueryResult<'tcx>;
217217

218+
/// A coroutine (that comes from a `gen` desugaring) is known to implement
219+
/// `FusedIterator<Item = O>`, where `O` is given by the generator's yield type
220+
/// that was computed during type-checking.
221+
fn consider_builtin_fused_iterator_candidate(
222+
ecx: &mut EvalCtxt<'_, 'tcx>,
223+
goal: Goal<'tcx, Self>,
224+
) -> QueryResult<'tcx>;
225+
218226
fn consider_builtin_async_iterator_candidate(
219227
ecx: &mut EvalCtxt<'_, 'tcx>,
220228
goal: Goal<'tcx, Self>,
@@ -497,6 +505,8 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
497505
G::consider_builtin_future_candidate(self, goal)
498506
} else if lang_items.iterator_trait() == Some(trait_def_id) {
499507
G::consider_builtin_iterator_candidate(self, goal)
508+
} else if lang_items.fused_iterator_trait() == Some(trait_def_id) {
509+
G::consider_builtin_fused_iterator_candidate(self, goal)
500510
} else if lang_items.async_iterator_trait() == Some(trait_def_id) {
501511
G::consider_builtin_async_iterator_candidate(self, goal)
502512
} else if lang_items.coroutine_trait() == Some(trait_def_id) {

compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs

+35
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,41 @@ impl<'tcx> assembly::GoalKind<'tcx> for NormalizesTo<'tcx> {
647647
)
648648
}
649649

650+
fn consider_builtin_fused_iterator_candidate(
651+
ecx: &mut EvalCtxt<'_, 'tcx>,
652+
goal: Goal<'tcx, Self>,
653+
) -> QueryResult<'tcx> {
654+
let self_ty = goal.predicate.self_ty();
655+
let ty::Coroutine(def_id, args) = *self_ty.kind() else {
656+
return Err(NoSolution);
657+
};
658+
659+
// Coroutines are not Iterators unless they come from `gen` desugaring
660+
let tcx = ecx.tcx();
661+
if !tcx.coroutine_is_gen(def_id) {
662+
return Err(NoSolution);
663+
}
664+
665+
let Some(iterator_trait) = tcx.lang_items().iterator_trait() else {
666+
return Err(NoSolution);
667+
};
668+
669+
let term = args.as_coroutine().yield_ty().into();
670+
671+
Self::consider_implied_clause(
672+
ecx,
673+
goal,
674+
ty::ProjectionPredicate {
675+
projection_ty: ty::AliasTy::new(ecx.tcx(), iterator_trait, [self_ty]),
676+
term,
677+
}
678+
.to_predicate(tcx),
679+
// Technically, we need to check that the iterator type is Sized,
680+
// but that's already proven by the generator being WF.
681+
[],
682+
)
683+
}
684+
650685
fn consider_builtin_async_iterator_candidate(
651686
ecx: &mut EvalCtxt<'_, 'tcx>,
652687
goal: Goal<'tcx, Self>,

compiler/rustc_trait_selection/src/solve/trait_goals.rs

+7
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,13 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {
456456
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
457457
}
458458

459+
fn consider_builtin_fused_iterator_candidate(
460+
ecx: &mut EvalCtxt<'_, 'tcx>,
461+
goal: Goal<'tcx, Self>,
462+
) -> QueryResult<'tcx> {
463+
Self::consider_builtin_iterator_candidate(ecx, goal)
464+
}
465+
459466
fn consider_builtin_async_iterator_candidate(
460467
ecx: &mut EvalCtxt<'_, 'tcx>,
461468
goal: Goal<'tcx, Self>,

compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs

+19
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
118118
self.assemble_future_candidates(obligation, &mut candidates);
119119
} else if lang_items.iterator_trait() == Some(def_id) {
120120
self.assemble_iterator_candidates(obligation, &mut candidates);
121+
} else if lang_items.fused_iterator_trait() == Some(def_id) {
122+
self.assemble_fused_iterator_candidates(obligation, &mut candidates);
121123
} else if lang_items.async_iterator_trait() == Some(def_id) {
122124
self.assemble_async_iterator_candidates(obligation, &mut candidates);
123125
} else if lang_items.async_fn_kind_helper() == Some(def_id) {
@@ -313,6 +315,23 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
313315
}
314316
}
315317

318+
fn assemble_fused_iterator_candidates(
319+
&mut self,
320+
obligation: &PolyTraitObligation<'tcx>,
321+
candidates: &mut SelectionCandidateSet<'tcx>,
322+
) {
323+
let self_ty = obligation.self_ty().skip_binder();
324+
if let ty::Coroutine(did, ..) = self_ty.kind() {
325+
// gen constructs get lowered to a special kind of coroutine that
326+
// should directly `impl FusedIterator`.
327+
if self.tcx().coroutine_is_gen(*did) {
328+
debug!(?self_ty, ?obligation, "assemble_fused_iterator_candidates",);
329+
330+
candidates.vec.push(FusedIteratorCandidate);
331+
}
332+
}
333+
}
334+
316335
fn assemble_async_iterator_candidates(
317336
&mut self,
318337
obligation: &PolyTraitObligation<'tcx>,

compiler/rustc_trait_selection/src/traits/select/confirmation.rs

+34
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
107107
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator)
108108
}
109109

110+
FusedIteratorCandidate => {
111+
let vtable_iterator = self.confirm_fused_iterator_candidate(obligation)?;
112+
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator)
113+
}
114+
110115
AsyncIteratorCandidate => {
111116
let vtable_iterator = self.confirm_async_iterator_candidate(obligation)?;
112117
ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator)
@@ -838,6 +843,35 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
838843
Ok(nested)
839844
}
840845

846+
fn confirm_fused_iterator_candidate(
847+
&mut self,
848+
obligation: &PolyTraitObligation<'tcx>,
849+
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
850+
// Okay to skip binder because the args on coroutine types never
851+
// touch bound regions, they just capture the in-scope
852+
// type/region parameters.
853+
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
854+
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
855+
bug!("closure candidate for non-closure {:?}", obligation);
856+
};
857+
858+
debug!(?obligation, ?coroutine_def_id, ?args, "confirm_fused_iterator_candidate");
859+
860+
let gen_sig = args.as_coroutine().sig();
861+
862+
let (trait_ref, _) = super::util::fused_iterator_trait_ref_and_outputs(
863+
self.tcx(),
864+
obligation.predicate.def_id(),
865+
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
866+
gen_sig,
867+
);
868+
869+
let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
870+
debug!(?trait_ref, ?nested, "fused iterator candidate obligations");
871+
872+
Ok(nested)
873+
}
874+
841875
fn confirm_async_iterator_candidate(
842876
&mut self,
843877
obligation: &PolyTraitObligation<'tcx>,

compiler/rustc_trait_selection/src/traits/select/mod.rs

+6
Original file line numberDiff line numberDiff line change
@@ -1855,6 +1855,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
18551855
| CoroutineCandidate
18561856
| FutureCandidate
18571857
| IteratorCandidate
1858+
| FusedIteratorCandidate
18581859
| AsyncIteratorCandidate
18591860
| FnPointerCandidate { .. }
18601861
| BuiltinObjectCandidate
@@ -1887,6 +1888,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
18871888
| CoroutineCandidate
18881889
| FutureCandidate
18891890
| IteratorCandidate
1891+
| FusedIteratorCandidate
18901892
| AsyncIteratorCandidate
18911893
| FnPointerCandidate { .. }
18921894
| BuiltinObjectCandidate
@@ -1925,6 +1927,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
19251927
| CoroutineCandidate
19261928
| FutureCandidate
19271929
| IteratorCandidate
1930+
| FusedIteratorCandidate
19281931
| AsyncIteratorCandidate
19291932
| FnPointerCandidate { .. }
19301933
| BuiltinObjectCandidate
@@ -1943,6 +1946,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
19431946
| CoroutineCandidate
19441947
| FutureCandidate
19451948
| IteratorCandidate
1949+
| FusedIteratorCandidate
19461950
| AsyncIteratorCandidate
19471951
| FnPointerCandidate { .. }
19481952
| BuiltinObjectCandidate
@@ -2053,6 +2057,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
20532057
| CoroutineCandidate
20542058
| FutureCandidate
20552059
| IteratorCandidate
2060+
| FusedIteratorCandidate
20562061
| AsyncIteratorCandidate
20572062
| FnPointerCandidate { .. }
20582063
| BuiltinObjectCandidate
@@ -2067,6 +2072,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
20672072
| CoroutineCandidate
20682073
| FutureCandidate
20692074
| IteratorCandidate
2075+
| FusedIteratorCandidate
20702076
| AsyncIteratorCandidate
20712077
| FnPointerCandidate { .. }
20722078
| BuiltinObjectCandidate

compiler/rustc_trait_selection/src/traits/util.rs

+11
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,17 @@ pub fn iterator_trait_ref_and_outputs<'tcx>(
323323
(trait_ref, sig.yield_ty)
324324
}
325325

326+
pub fn fused_iterator_trait_ref_and_outputs<'tcx>(
327+
tcx: TyCtxt<'tcx>,
328+
fused_iterator_def_id: DefId,
329+
self_ty: Ty<'tcx>,
330+
sig: ty::GenSig<'tcx>,
331+
) -> (ty::TraitRef<'tcx>, Ty<'tcx>) {
332+
assert!(!self_ty.has_escaping_bound_vars());
333+
let trait_ref = ty::TraitRef::new(tcx, fused_iterator_def_id, [self_ty]);
334+
(trait_ref, sig.yield_ty)
335+
}
336+
326337
pub fn async_iterator_trait_ref_and_outputs<'tcx>(
327338
tcx: TyCtxt<'tcx>,
328339
async_iterator_def_id: DefId,

library/core/src/iter/traits/marker.rs

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pub unsafe trait TrustedFused {}
2828
#[rustc_unsafe_specialization_marker]
2929
// FIXME: this should be a #[marker] and have another blanket impl for T: TrustedFused
3030
// but that ICEs iter::Fuse specializations.
31+
#[cfg_attr(not(bootstrap), lang = "fused_iterator")]
3132
pub trait FusedIterator: Iterator {}
3233

3334
#[stable(feature = "fused", since = "1.26.0")]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//@ revisions: next old
2+
//@compile-flags: --edition 2024 -Zunstable-options
3+
//@[next] compile-flags: -Znext-solver
4+
//@ check-pass
5+
#![feature(gen_blocks)]
6+
7+
use std::iter::FusedIterator;
8+
9+
fn foo() -> impl FusedIterator {
10+
gen { yield 42 }
11+
}
12+
13+
fn bar() -> impl FusedIterator<Item = u16> {
14+
gen { yield 42 }
15+
}
16+
17+
fn baz() -> impl FusedIterator + Iterator<Item = i64> {
18+
gen { yield 42 }
19+
}
20+
21+
fn main() {}

0 commit comments

Comments
 (0)