Skip to content

Commit 9a0c7f8

Browse files
committed
Fixed enum + union uninit order with #[layout(bound)]
1 parent 22a866a commit 9a0c7f8

File tree

4 files changed

+130
-19
lines changed

4 files changed

+130
-19
lines changed

const-type-layout-derive/src/lib.rs

+103-7
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ pub fn derive_type_layout(input: TokenStream) -> TokenStream {
3636
let Attributes {
3737
reprs,
3838
extra_bounds,
39-
} = parse_attributes(&input.attrs, &mut type_params);
39+
ground,
40+
} = parse_attributes(&input.attrs, &mut type_params, &input.data);
4041

4142
let layout = layout_of_type(&ty_name, &ty_generics, &input.data, &reprs);
42-
let uninit = uninit_for_type(&ty_name, &input.data);
43+
let uninit = uninit_for_type(&ty_name, &input.data, &ground);
4344

4445
let inner_types = extract_inner_types(&input.data);
4546

@@ -95,15 +96,49 @@ pub fn derive_type_layout(input: TokenStream) -> TokenStream {
9596
struct Attributes {
9697
reprs: String,
9798
extra_bounds: Vec<syn::WherePredicate>,
99+
ground: Vec<syn::Ident>,
98100
}
99101

100102
#[allow(clippy::too_many_lines)]
101-
fn parse_attributes(attrs: &[syn::Attribute], type_params: &mut Vec<&syn::Ident>) -> Attributes {
103+
fn parse_attributes(
104+
attrs: &[syn::Attribute],
105+
type_params: &mut Vec<&syn::Ident>,
106+
data: &syn::Data,
107+
) -> Attributes {
102108
// Could parse based on https://github.com/rust-lang/rust/blob/d13e8dd41d44a73664943169d5b7fe39b22c449f/compiler/rustc_attr/src/builtin.rs#L772-L781 instead
103109
let mut reprs = Vec::new();
104110

105111
let mut extra_bounds: Vec<syn::WherePredicate> = Vec::new();
106112

113+
let mut ground = match data {
114+
syn::Data::Struct(_) => Vec::new(),
115+
syn::Data::Enum(syn::DataEnum { variants, .. }) => {
116+
let mut ground = Vec::with_capacity(variants.len());
117+
118+
for variant in variants {
119+
if matches!(variant.fields, syn::Fields::Unit) {
120+
ground.push(variant.ident.clone());
121+
}
122+
}
123+
124+
for variant in variants {
125+
if !matches!(variant.fields, syn::Fields::Unit) {
126+
ground.push(variant.ident.clone());
127+
}
128+
}
129+
130+
ground
131+
},
132+
syn::Data::Union(syn::DataUnion {
133+
fields: syn::FieldsNamed { named: fields, .. },
134+
..
135+
}) => fields
136+
.iter()
137+
.map(|field| field.ident.clone().unwrap())
138+
.collect(),
139+
};
140+
let mut groundier = Vec::with_capacity(ground.len());
141+
107142
for attr in attrs {
108143
if attr.path.is_ident("repr") {
109144
if let Ok(syn::Meta::List(syn::MetaList { nested, .. })) = attr.parse_meta() {
@@ -161,10 +196,63 @@ fn parse_attributes(attrs: &[syn::Attribute], type_params: &mut Vec<&syn::Ident>
161196
err
162197
),
163198
}
199+
} else if path.is_ident("ground") {
200+
match syn::parse_str(&s.value()) {
201+
Ok(g) => match data {
202+
syn::Data::Struct(_) => emit_error!(
203+
path.span(),
204+
"[const-type-layout]: Invalid #[layout(ground)] \
205+
attribute: structs do not have a ground layout."
206+
),
207+
syn::Data::Union(_) | syn::Data::Enum(_) => {
208+
let g: syn::Ident = g;
209+
210+
if let Some(i) = ground.iter().position(|e| e == &g) {
211+
let g = ground.remove(i);
212+
groundier.push(g);
213+
} else if groundier.contains(&g) {
214+
emit_error!(
215+
path.span(),
216+
"[const-type-layout]: Duplicate #[layout(ground = \
217+
\"{}\")] attribute.",
218+
g
219+
);
220+
} else {
221+
emit_error!(
222+
path.span(),
223+
"[const-type-layout]: Invalid #[layout(ground)] \
224+
attribute: \"{}\" is not a {} in this {}.",
225+
g,
226+
match data {
227+
syn::Data::Enum(_) => "variant",
228+
syn::Data::Struct(_) | syn::Data::Union(_) =>
229+
"field",
230+
},
231+
match data {
232+
syn::Data::Enum(_) => "enum",
233+
syn::Data::Struct(_) | syn::Data::Union(_) =>
234+
"union",
235+
},
236+
);
237+
}
238+
},
239+
},
240+
Err(err) => emit_error!(
241+
s.span(),
242+
"[const-type-layout]: Invalid #[layout(ground = \"{}\")] \
243+
attribute: {}.",
244+
match data {
245+
syn::Data::Enum(_) => "variant",
246+
syn::Data::Struct(_) | syn::Data::Union(_) => "field",
247+
},
248+
err
249+
),
250+
}
164251
} else {
165252
emit_error!(
166253
path.span(),
167-
"[const-type-layout]: Unknown attribute, use `bound` or `free`."
254+
"[const-type-layout]: Unknown attribute, use `bound`, `free`, or \
255+
`ground`."
168256
);
169257
}
170258
} else {
@@ -193,9 +281,12 @@ fn parse_attributes(attrs: &[syn::Attribute], type_params: &mut Vec<&syn::Ident>
193281
.intersperse(String::from(","))
194282
.collect::<String>();
195283

284+
groundier.extend(ground);
285+
196286
Attributes {
197287
reprs,
198288
extra_bounds,
289+
ground: groundier,
199290
}
200291
}
201292

@@ -587,7 +678,11 @@ fn quote_discriminant(
587678
}
588679

589680
#[allow(clippy::too_many_lines)]
590-
fn uninit_for_type(ty_name: &syn::Ident, data: &syn::Data) -> proc_macro2::TokenStream {
681+
fn uninit_for_type(
682+
ty_name: &syn::Ident,
683+
data: &syn::Data,
684+
ground: &[syn::Ident],
685+
) -> proc_macro2::TokenStream {
591686
match data {
592687
syn::Data::Struct(data) => {
593688
// Structs are uninhabited if any of their fields in uninhabited
@@ -656,7 +751,7 @@ fn uninit_for_type(ty_name: &syn::Ident, data: &syn::Data) -> proc_macro2::Token
656751
// (2) tuple and struct variants are uninhabited
657752
// if any of their fields are uninhabited
658753

659-
let variant_initialisers = variants.iter().map(|syn::Variant {
754+
let variant_initialisers = ground.iter().filter_map(|g| variants.iter().find(|v| &v.ident == g)).map(|syn::Variant {
660755
ident: variant_name,
661756
fields: variant_fields,
662757
..
@@ -732,8 +827,9 @@ fn uninit_for_type(ty_name: &syn::Ident, data: &syn::Data) -> proc_macro2::Token
732827
}) => {
733828
// Unions are uninhabited if all fields are uninhabited
734829

735-
let (field_names, field_initialisers) = fields
830+
let (field_names, field_initialisers) = ground
736831
.iter()
832+
.filter_map(|g| fields.iter().find(|f| f.ident.as_ref() == Some(g)))
737833
.map(
738834
|syn::Field {
739835
ident: field_name,

src/impls/alloc/boxed.rs

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ unsafe impl<T: ~const TypeLayout> const TypeLayout for alloc::boxed::Box<T> {
1515
};
1616

1717
unsafe fn uninit() -> MaybeUninhabited<core::mem::MaybeUninit<Self>> {
18-
// TODO: Handle infinite recursion case
1918
if let MaybeUninhabited::Uninhabited = <T as TypeLayout>::uninit() {
2019
return MaybeUninhabited::Uninhabited;
2120
}

src/impls/core/ptr.rs

+9-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
2-
impls::leak_uninit_ptr, MaybeUninhabited, Mutability, TypeGraph, TypeLayout, TypeLayoutGraph,
3-
TypeLayoutInfo, TypeStructure,
2+
MaybeUninhabited, Mutability, TypeGraph, TypeLayout, TypeLayoutGraph, TypeLayoutInfo,
3+
TypeStructure,
44
};
55

66
unsafe impl<T: ~const TypeLayout> const TypeLayout for *const T {
@@ -15,7 +15,9 @@ unsafe impl<T: ~const TypeLayout> const TypeLayout for *const T {
1515
};
1616

1717
unsafe fn uninit() -> MaybeUninhabited<core::mem::MaybeUninit<Self>> {
18-
MaybeUninhabited::Inhabited(core::mem::MaybeUninit::new(leak_uninit_ptr()))
18+
MaybeUninhabited::Inhabited(core::mem::MaybeUninit::new(
19+
core::ptr::NonNull::dangling().as_ptr(),
20+
))
1921
}
2022
}
2123

@@ -39,7 +41,9 @@ unsafe impl<T: ~const TypeLayout> const TypeLayout for *mut T {
3941
};
4042

4143
unsafe fn uninit() -> MaybeUninhabited<core::mem::MaybeUninit<Self>> {
42-
MaybeUninhabited::Inhabited(core::mem::MaybeUninit::new(leak_uninit_ptr()))
44+
MaybeUninhabited::Inhabited(core::mem::MaybeUninit::new(
45+
core::ptr::NonNull::dangling().as_ptr(),
46+
))
4347
}
4448
}
4549

@@ -63,9 +67,7 @@ unsafe impl<T: ~const TypeLayout> const TypeLayout for core::ptr::NonNull<T> {
6367
};
6468

6569
unsafe fn uninit() -> MaybeUninhabited<core::mem::MaybeUninit<Self>> {
66-
MaybeUninhabited::Inhabited(core::mem::MaybeUninit::new(
67-
core::ptr::NonNull::new_unchecked(leak_uninit_ptr()),
68-
))
70+
MaybeUninhabited::Inhabited(core::mem::MaybeUninit::new(core::ptr::NonNull::dangling()))
6971
}
7072
}
7173

try-crate/src/main.rs

+18-4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ union SingleUnion {
4747
a: u8,
4848
}
4949

50+
#[derive(TypeLayout)]
51+
#[layout(ground = "b")]
52+
union RecursiveRef<'a> {
53+
a: &'a RecursiveRef<'a>,
54+
b: (),
55+
}
56+
57+
#[derive(TypeLayout)]
58+
union RecursivePtr {
59+
a: *const RecursivePtr,
60+
}
61+
5062
#[allow(clippy::empty_enum)]
5163
#[derive(TypeLayout)]
5264
enum Never {}
@@ -89,16 +101,16 @@ enum List<T> {
89101
Tail,
90102
}
91103

92-
// TODO: allow an arbitrary variant order
93104
#[derive(TypeLayout)]
105+
#[layout(ground = "Leaf")]
94106
enum Tree<T> {
95-
Leaf {
96-
item: T,
97-
},
98107
Node {
99108
left: Box<Tree<T>>,
100109
right: Box<Tree<T>>,
101110
},
111+
Leaf {
112+
item: T,
113+
},
102114
}
103115

104116
#[repr(transparent)]
@@ -135,6 +147,8 @@ fn main() {
135147

136148
println!("{:#?}", Bar::TYPE_GRAPH);
137149
println!("{:#?}", SingleUnion::TYPE_GRAPH);
150+
println!("{:#?}", RecursiveRef::<'static>::TYPE_GRAPH);
151+
println!("{:#?}", RecursivePtr::TYPE_GRAPH);
138152

139153
println!("{:#?}", Never::TYPE_GRAPH);
140154
println!("{:#?}", Single::TYPE_GRAPH);

0 commit comments

Comments
 (0)