Skip to content

Commit 5c7cd63

Browse files
authored
Add a #[derive(Invariant)] macro (rust-lang#3250)
This PR adds a `#[derive(Invariant)]` macro for structs which allows users to automatically derive the `Invariant` implementations for any struct. The derived implementation determines the invariant for the struct as the conjunction of invariants of its fields. In other words, the invariant is derived as `true && self.field1.is_safe() && self.field2.is_safe() && ..`. For example, for the struct ```rs #[derive(kani::Invariant)] struct Point<X, Y> { x: X, y: Y, } ``` we derive the `Invariant` implementation as ```rs impl<X: kani::Invariant, Y: kani::Invariant> kani::Invariant for Point<X, Y> { fn is_safe(&self) -> bool { true && self.x.is_safe() && self.y.is_safe() } } ``` Related rust-lang#3095
1 parent 7dad847 commit 5c7cd63

File tree

12 files changed

+244
-2
lines changed

12 files changed

+244
-2
lines changed

library/kani_macros/src/derive.rs

+92-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub fn expand_derive_arbitrary(item: proc_macro::TokenStream) -> proc_macro::Tok
2323
let item_name = &derive_item.ident;
2424

2525
// Add a bound `T: Arbitrary` to every type parameter T.
26-
let generics = add_trait_bound(derive_item.generics);
26+
let generics = add_trait_bound_arbitrary(derive_item.generics);
2727
// Generate an expression to sum up the heap size of each field.
2828
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
2929

@@ -40,7 +40,7 @@ pub fn expand_derive_arbitrary(item: proc_macro::TokenStream) -> proc_macro::Tok
4040
}
4141

4242
/// Add a bound `T: Arbitrary` to every type parameter T.
43-
fn add_trait_bound(mut generics: Generics) -> Generics {
43+
fn add_trait_bound_arbitrary(mut generics: Generics) -> Generics {
4444
generics.params.iter_mut().for_each(|param| {
4545
if let GenericParam::Type(type_param) = param {
4646
type_param.bounds.push(parse_quote!(kani::Arbitrary));
@@ -165,3 +165,93 @@ fn fn_any_enum(ident: &Ident, data: &DataEnum) -> TokenStream {
165165
}
166166
}
167167
}
168+
169+
pub fn expand_derive_invariant(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
170+
let derive_item = parse_macro_input!(item as DeriveInput);
171+
let item_name = &derive_item.ident;
172+
173+
// Add a bound `T: Invariant` to every type parameter T.
174+
let generics = add_trait_bound_invariant(derive_item.generics);
175+
// Generate an expression to sum up the heap size of each field.
176+
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
177+
178+
let body = is_safe_body(&item_name, &derive_item.data);
179+
let expanded = quote! {
180+
// The generated implementation.
181+
impl #impl_generics kani::Invariant for #item_name #ty_generics #where_clause {
182+
fn is_safe(&self) -> bool {
183+
#body
184+
}
185+
}
186+
};
187+
proc_macro::TokenStream::from(expanded)
188+
}
189+
190+
/// Add a bound `T: Invariant` to every type parameter T.
191+
fn add_trait_bound_invariant(mut generics: Generics) -> Generics {
192+
generics.params.iter_mut().for_each(|param| {
193+
if let GenericParam::Type(type_param) = param {
194+
type_param.bounds.push(parse_quote!(kani::Invariant));
195+
}
196+
});
197+
generics
198+
}
199+
200+
fn is_safe_body(ident: &Ident, data: &Data) -> TokenStream {
201+
match data {
202+
Data::Struct(struct_data) => struct_safe_conjunction(ident, &struct_data.fields),
203+
Data::Enum(_) => {
204+
abort!(Span::call_site(), "Cannot derive `Invariant` for `{}` enum", ident;
205+
note = ident.span() =>
206+
"`#[derive(Invariant)]` cannot be used for enums such as `{}`", ident
207+
)
208+
}
209+
Data::Union(_) => {
210+
abort!(Span::call_site(), "Cannot derive `Invariant` for `{}` union", ident;
211+
note = ident.span() =>
212+
"`#[derive(Invariant)]` cannot be used for unions such as `{}`", ident
213+
)
214+
}
215+
}
216+
}
217+
218+
/// Generates an expression that is the conjunction of `is_safe` calls for each field in the struct.
219+
fn struct_safe_conjunction(_ident: &Ident, fields: &Fields) -> TokenStream {
220+
match fields {
221+
// Expands to the expression
222+
// `true && self.field1.is_safe() && self.field2.is_safe() && ..`
223+
Fields::Named(ref fields) => {
224+
let safe_calls = fields.named.iter().map(|field| {
225+
let name = &field.ident;
226+
quote_spanned! {field.span()=>
227+
self.#name.is_safe()
228+
}
229+
});
230+
// An initial value is required for empty structs
231+
safe_calls.fold(quote! { true }, |acc, call| {
232+
quote! { #acc && #call }
233+
})
234+
}
235+
Fields::Unnamed(ref fields) => {
236+
// Expands to the expression
237+
// `true && self.0.is_safe() && self.1.is_safe() && ..`
238+
let safe_calls = fields.unnamed.iter().enumerate().map(|(i, field)| {
239+
let idx = syn::Index::from(i);
240+
quote_spanned! {field.span()=>
241+
self.#idx.is_safe()
242+
}
243+
});
244+
// An initial value is required for empty structs
245+
safe_calls.fold(quote! { true }, |acc, call| {
246+
quote! { #acc && #call }
247+
})
248+
}
249+
// Expands to the expression
250+
// `true`
251+
Fields::Unit => {
252+
quote! {
253+
true
254+
}
255+
}
256+
}
257+
}

library/kani_macros/src/lib.rs

+7
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,13 @@ pub fn derive_arbitrary(item: TokenStream) -> TokenStream {
107107
derive::expand_derive_arbitrary(item)
108108
}
109109

110+
/// Allow users to auto generate Invariant implementations by using `#[derive(Invariant)]` macro.
111+
#[proc_macro_error]
112+
#[proc_macro_derive(Invariant)]
113+
pub fn derive_invariant(item: TokenStream) -> TokenStream {
114+
derive::expand_derive_invariant(item)
115+
}
116+
110117
/// Add a precondition to this function.
111118
///
112119
/// This is part of the function contract API, for more general information see
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
4+
//! Check that Kani can automatically derive `Invariant` for empty structs.
5+
6+
extern crate kani;
7+
use kani::Invariant;
8+
9+
#[derive(kani::Arbitrary)]
10+
#[derive(kani::Invariant)]
11+
struct Void;
12+
13+
#[derive(kani::Arbitrary)]
14+
#[derive(kani::Invariant)]
15+
struct Void2(());
16+
17+
#[derive(kani::Arbitrary)]
18+
#[derive(kani::Invariant)]
19+
struct VoidOfVoid(Void, Void2);
20+
21+
#[kani::proof]
22+
fn check_empty_struct_invariant_1() {
23+
let void1: Void = kani::any();
24+
assert!(void1.is_safe());
25+
}
26+
27+
#[kani::proof]
28+
fn check_empty_struct_invariant_2() {
29+
let void2: Void2 = kani::any();
30+
assert!(void2.is_safe());
31+
}
32+
33+
#[kani::proof]
34+
fn check_empty_struct_invariant_3() {
35+
let void3: VoidOfVoid = kani::any();
36+
assert!(void3.is_safe());
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
- Status: SUCCESS\
2+
- Description: "assertion failed: void1.is_safe()"
3+
4+
- Status: SUCCESS\
5+
- Description: "assertion failed: void2.is_safe()"
6+
7+
- Status: SUCCESS\
8+
- Description: "assertion failed: void3.is_safe()"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- Status: SUCCESS\
2+
- Description: "assertion failed: point.is_safe()"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
4+
//! Check that Kani can automatically derive `Invariant` for structs with generics.
5+
6+
extern crate kani;
7+
use kani::Invariant;
8+
9+
#[derive(kani::Arbitrary)]
10+
#[derive(kani::Invariant)]
11+
struct Point<X, Y> {
12+
x: X,
13+
y: Y,
14+
}
15+
16+
#[kani::proof]
17+
fn check_generic_struct_invariant() {
18+
let point: Point<i32, i8> = kani::any();
19+
assert!(point.is_safe());
20+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- Status: FAILURE\
2+
- Description: "assertion failed: wrapper.is_safe()"
3+
4+
Verification failed for - check_invariant_fail
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
4+
//! Check that a verification failure is triggered when the derived `Invariant`
5+
//! method is checked but not satisfied.
6+
7+
extern crate kani;
8+
use kani::Invariant;
9+
// Note: This represents an incorrect usage of `Arbitrary` and `Invariant`.
10+
//
11+
// The `Arbitrary` implementation should respect the type invariant,
12+
// but Kani does not enforce this in any way at the moment.
13+
// <https://github.com/model-checking/kani/issues/3265>
14+
#[derive(kani::Arbitrary)]
15+
struct NotNegative(i32);
16+
17+
impl kani::Invariant for NotNegative {
18+
fn is_safe(&self) -> bool {
19+
self.0 >= 0
20+
}
21+
}
22+
23+
#[derive(kani::Arbitrary)]
24+
#[derive(kani::Invariant)]
25+
struct NotNegativeWrapper {
26+
x: NotNegative,
27+
}
28+
29+
#[kani::proof]
30+
fn check_invariant_fail() {
31+
let wrapper: NotNegativeWrapper = kani::any();
32+
assert!(wrapper.is_safe());
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- Status: SUCCESS\
2+
- Description: "assertion failed: point.is_safe()"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
4+
//! Check that Kani can automatically derive `Invariant` for structs with named fields.
5+
6+
extern crate kani;
7+
use kani::Invariant;
8+
9+
#[derive(kani::Arbitrary)]
10+
#[derive(kani::Invariant)]
11+
struct Point {
12+
x: i32,
13+
y: i32,
14+
}
15+
16+
#[kani::proof]
17+
fn check_generic_struct_invariant() {
18+
let point: Point = kani::any();
19+
assert!(point.is_safe());
20+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- Status: SUCCESS\
2+
- Description: "assertion failed: point.is_safe()"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
4+
//! Check that Kani can automatically derive `Invariant` for structs with unnamed fields.
5+
6+
extern crate kani;
7+
use kani::Invariant;
8+
9+
#[derive(kani::Arbitrary)]
10+
#[derive(kani::Invariant)]
11+
struct Point(i32, i32);
12+
13+
#[kani::proof]
14+
fn check_generic_struct_invariant() {
15+
let point: Point = kani::any();
16+
assert!(point.is_safe());
17+
}

0 commit comments

Comments
 (0)