@@ -23,7 +23,7 @@ pub fn expand_derive_arbitrary(item: proc_macro::TokenStream) -> proc_macro::Tok
23
23
let item_name = & derive_item. ident ;
24
24
25
25
// Add a bound `T: Arbitrary` to every type parameter T.
26
- let generics = add_trait_bound ( derive_item. generics ) ;
26
+ let generics = add_trait_bound_arbitrary ( derive_item. generics ) ;
27
27
// Generate an expression to sum up the heap size of each field.
28
28
let ( impl_generics, ty_generics, where_clause) = generics. split_for_impl ( ) ;
29
29
@@ -40,7 +40,7 @@ pub fn expand_derive_arbitrary(item: proc_macro::TokenStream) -> proc_macro::Tok
40
40
}
41
41
42
42
/// Add a bound `T: Arbitrary` to every type parameter T.
43
- fn add_trait_bound ( mut generics : Generics ) -> Generics {
43
+ fn add_trait_bound_arbitrary ( mut generics : Generics ) -> Generics {
44
44
generics. params . iter_mut ( ) . for_each ( |param| {
45
45
if let GenericParam :: Type ( type_param) = param {
46
46
type_param. bounds . push ( parse_quote ! ( kani:: Arbitrary ) ) ;
@@ -165,3 +165,93 @@ fn fn_any_enum(ident: &Ident, data: &DataEnum) -> TokenStream {
165
165
}
166
166
}
167
167
}
168
+
169
+ pub fn expand_derive_invariant ( item : proc_macro:: TokenStream ) -> proc_macro:: TokenStream {
170
+ let derive_item = parse_macro_input ! ( item as DeriveInput ) ;
171
+ let item_name = & derive_item. ident ;
172
+
173
+ // Add a bound `T: Invariant` to every type parameter T.
174
+ let generics = add_trait_bound_invariant ( derive_item. generics ) ;
175
+ // Generate an expression to sum up the heap size of each field.
176
+ let ( impl_generics, ty_generics, where_clause) = generics. split_for_impl ( ) ;
177
+
178
+ let body = is_safe_body ( & item_name, & derive_item. data ) ;
179
+ let expanded = quote ! {
180
+ // The generated implementation.
181
+ impl #impl_generics kani:: Invariant for #item_name #ty_generics #where_clause {
182
+ fn is_safe( & self ) -> bool {
183
+ #body
184
+ }
185
+ }
186
+ } ;
187
+ proc_macro:: TokenStream :: from ( expanded)
188
+ }
189
+
190
+ /// Add a bound `T: Invariant` to every type parameter T.
191
+ fn add_trait_bound_invariant ( mut generics : Generics ) -> Generics {
192
+ generics. params . iter_mut ( ) . for_each ( |param| {
193
+ if let GenericParam :: Type ( type_param) = param {
194
+ type_param. bounds . push ( parse_quote ! ( kani:: Invariant ) ) ;
195
+ }
196
+ } ) ;
197
+ generics
198
+ }
199
+
200
+ fn is_safe_body ( ident : & Ident , data : & Data ) -> TokenStream {
201
+ match data {
202
+ Data :: Struct ( struct_data) => struct_safe_conjunction ( ident, & struct_data. fields ) ,
203
+ Data :: Enum ( _) => {
204
+ abort ! ( Span :: call_site( ) , "Cannot derive `Invariant` for `{}` enum" , ident;
205
+ note = ident. span( ) =>
206
+ "`#[derive(Invariant)]` cannot be used for enums such as `{}`" , ident
207
+ )
208
+ }
209
+ Data :: Union ( _) => {
210
+ abort ! ( Span :: call_site( ) , "Cannot derive `Invariant` for `{}` union" , ident;
211
+ note = ident. span( ) =>
212
+ "`#[derive(Invariant)]` cannot be used for unions such as `{}`" , ident
213
+ )
214
+ }
215
+ }
216
+ }
217
+
218
+ /// Generates an expression that is the conjunction of `is_safe` calls for each field in the struct.
219
+ fn struct_safe_conjunction ( _ident : & Ident , fields : & Fields ) -> TokenStream {
220
+ match fields {
221
+ // Expands to the expression
222
+ // `true && self.field1.is_safe() && self.field2.is_safe() && ..`
223
+ Fields :: Named ( ref fields) => {
224
+ let safe_calls = fields. named . iter ( ) . map ( |field| {
225
+ let name = & field. ident ;
226
+ quote_spanned ! { field. span( ) =>
227
+ self . #name. is_safe( )
228
+ }
229
+ } ) ;
230
+ // An initial value is required for empty structs
231
+ safe_calls. fold ( quote ! { true } , |acc, call| {
232
+ quote ! { #acc && #call }
233
+ } )
234
+ }
235
+ Fields :: Unnamed ( ref fields) => {
236
+ // Expands to the expression
237
+ // `true && self.0.is_safe() && self.1.is_safe() && ..`
238
+ let safe_calls = fields. unnamed . iter ( ) . enumerate ( ) . map ( |( i, field) | {
239
+ let idx = syn:: Index :: from ( i) ;
240
+ quote_spanned ! { field. span( ) =>
241
+ self . #idx. is_safe( )
242
+ }
243
+ } ) ;
244
+ // An initial value is required for empty structs
245
+ safe_calls. fold ( quote ! { true } , |acc, call| {
246
+ quote ! { #acc && #call }
247
+ } )
248
+ }
249
+ // Expands to the expression
250
+ // `true`
251
+ Fields :: Unit => {
252
+ quote ! {
253
+ true
254
+ }
255
+ }
256
+ }
257
+ }
0 commit comments