Skip to content

Commit a8fb1e3

Browse files
authored
Replace definitions Vec with OnceLock slots (#992)
1 parent 1a966d5 commit a8fb1e3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+793
-690
lines changed

src/definitions.rs

+204-64
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,20 @@
33
/// Unlike json schema we let you put definitions inline, not just in a single '#/$defs/' block or similar.
44
/// We use DefinitionsBuilder to collect the references / definitions into a single vector
55
/// and then get a definition from a reference using an integer id (just for performance of not using a HashMap)
6-
use std::collections::hash_map::Entry;
6+
use std::{
7+
collections::hash_map::Entry,
8+
fmt::Debug,
9+
sync::{
10+
atomic::{AtomicBool, Ordering},
11+
Arc, OnceLock,
12+
},
13+
};
714

8-
use pyo3::prelude::*;
15+
use pyo3::{prelude::*, PyTraverseError, PyVisit};
916

1017
use ahash::AHashMap;
1118

12-
use crate::build_tools::py_schema_err;
13-
14-
// An integer id for the reference
15-
pub type ReferenceId = usize;
19+
use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};
1620

1721
/// Definitions are validators and serializers that are
1822
/// shared by reference.
@@ -24,91 +28,227 @@ pub type ReferenceId = usize;
2428
/// They get indexed by a ReferenceId, which are integer identifiers
2529
/// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer}
2630
/// gets build.
27-
pub type Definitions<T> = [T];
31+
#[derive(Clone)]
32+
pub struct Definitions<T>(AHashMap<Arc<String>, Definition<T>>);
2833

29-
#[derive(Clone, Debug)]
30-
struct Definition<T> {
31-
pub id: ReferenceId,
32-
pub value: Option<T>,
34+
impl<T> Definitions<T> {
35+
pub fn values(&self) -> impl Iterator<Item = &Definition<T>> {
36+
self.0.values()
37+
}
38+
}
39+
40+
/// Internal type which contains a definition to be filled
41+
pub struct Definition<T>(Arc<DefinitionInner<T>>);
42+
43+
impl<T> Definition<T> {
44+
pub fn get(&self) -> Option<&T> {
45+
self.0.value.get()
46+
}
47+
}
48+
49+
struct DefinitionInner<T> {
50+
value: OnceLock<T>,
51+
name: LazyName,
52+
}
53+
54+
/// Reference to a definition.
55+
pub struct DefinitionRef<T> {
56+
name: Arc<String>,
57+
value: Definition<T>,
58+
}
59+
60+
// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone)
61+
impl<T> Clone for DefinitionRef<T> {
62+
fn clone(&self) -> Self {
63+
Self {
64+
name: self.name.clone(),
65+
value: self.value.clone(),
66+
}
67+
}
68+
}
69+
70+
impl<T> DefinitionRef<T> {
71+
pub fn id(&self) -> usize {
72+
Arc::as_ptr(&self.value.0) as usize
73+
}
74+
75+
pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str {
76+
match self.value.0.value.get() {
77+
Some(value) => self.value.0.name.get_or_init(|| init(value)),
78+
None => "...",
79+
}
80+
}
81+
82+
pub fn get(&self) -> Option<&T> {
83+
self.value.0.value.get()
84+
}
85+
}
86+
87+
impl<T: Debug> Debug for DefinitionRef<T> {
88+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89+
// To avoid possible infinite recursion from recursive definitions,
90+
// a DefinitionRef just displays debug as its name
91+
self.name.fmt(f)
92+
}
93+
}
94+
95+
impl<T: Debug> Debug for Definitions<T> {
96+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97+
// Formatted as a list for backwards compatibility; in principle
98+
// this could be formatted as a map. Maybe change in a future
99+
// minor release of pydantic.
100+
write![f, "["]?;
101+
let mut first = true;
102+
for def in self.0.values() {
103+
write![f, "{sep}{def:?}", sep = if first { "" } else { ", " }]?;
104+
first = false;
105+
}
106+
write![f, "]"]?;
107+
Ok(())
108+
}
109+
}
110+
111+
impl<T> Clone for Definition<T> {
112+
fn clone(&self) -> Self {
113+
Self(self.0.clone())
114+
}
115+
}
116+
117+
impl<T: Debug> Debug for Definition<T> {
118+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119+
match self.0.value.get() {
120+
Some(value) => value.fmt(f),
121+
None => "...".fmt(f),
122+
}
123+
}
124+
}
125+
126+
impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
127+
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
128+
if let Some(value) = self.value.0.value.get() {
129+
value.py_gc_traverse(visit)?;
130+
}
131+
Ok(())
132+
}
133+
}
134+
135+
impl<T: PyGcTraverse> PyGcTraverse for Definitions<T> {
136+
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
137+
for value in self.0.values() {
138+
if let Some(value) = value.0.value.get() {
139+
value.py_gc_traverse(visit)?;
140+
}
141+
}
142+
Ok(())
143+
}
33144
}
34145

35146
#[derive(Clone, Debug)]
36147
pub struct DefinitionsBuilder<T> {
37-
definitions: AHashMap<String, Definition<T>>,
148+
definitions: Definitions<T>,
38149
}
39150

40-
impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
151+
impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
41152
pub fn new() -> Self {
42153
Self {
43-
definitions: AHashMap::new(),
154+
definitions: Definitions(AHashMap::new()),
44155
}
45156
}
46157

47158
/// Get a ReferenceId for the given reference string.
48-
// This ReferenceId can later be used to retrieve a definition
49-
pub fn get_reference_id(&mut self, reference: &str) -> ReferenceId {
50-
let next_id = self.definitions.len();
159+
pub fn get_definition(&mut self, reference: &str) -> DefinitionRef<T> {
51160
// We either need a String copy or two hashmap lookups
52161
// Neither is better than the other
53162
// We opted for the easier outward facing API
54-
match self.definitions.entry(reference.to_string()) {
55-
Entry::Occupied(entry) => entry.get().id,
56-
Entry::Vacant(entry) => {
57-
entry.insert(Definition {
58-
id: next_id,
59-
value: None,
60-
});
61-
next_id
62-
}
163+
let name = Arc::new(reference.to_string());
164+
let value = match self.definitions.0.entry(name.clone()) {
165+
Entry::Occupied(entry) => entry.into_mut(),
166+
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner {
167+
value: OnceLock::new(),
168+
name: LazyName::new(),
169+
}))),
170+
};
171+
DefinitionRef {
172+
name,
173+
value: value.clone(),
63174
}
64175
}
65176

66177
/// Add a definition, returning the ReferenceId that maps to it
67-
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<ReferenceId> {
68-
let next_id = self.definitions.len();
69-
match self.definitions.entry(reference.clone()) {
70-
Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) {
71-
Some(_) => py_schema_err!("Duplicate ref: `{}`", reference),
72-
None => Ok(entry.get().id),
73-
},
74-
Entry::Vacant(entry) => {
75-
entry.insert(Definition {
76-
id: next_id,
77-
value: Some(value),
78-
});
79-
Ok(next_id)
178+
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<DefinitionRef<T>> {
179+
let name = Arc::new(reference);
180+
let value = match self.definitions.0.entry(name.clone()) {
181+
Entry::Occupied(entry) => {
182+
let definition = entry.into_mut();
183+
match definition.0.value.set(value) {
184+
Ok(()) => definition.clone(),
185+
Err(_) => return py_schema_err!("Duplicate ref: `{}`", name),
186+
}
187+
}
188+
Entry::Vacant(entry) => entry
189+
.insert(Definition(Arc::new(DefinitionInner {
190+
value: OnceLock::from(value),
191+
name: LazyName::new(),
192+
})))
193+
.clone(),
194+
};
195+
Ok(DefinitionRef { name, value })
196+
}
197+
198+
/// Consume this Definitions into a vector of items, indexed by each items ReferenceId
199+
pub fn finish(self) -> PyResult<Definitions<T>> {
200+
for (reference, def) in &self.definitions.0 {
201+
if def.0.value.get().is_none() {
202+
return py_schema_err!("Definitions error: definition `{}` was never filled", reference);
80203
}
81204
}
205+
Ok(self.definitions)
82206
}
207+
}
83208

84-
/// Retrieve an item definition using a ReferenceId
85-
/// If the definition doesn't yet exist (as happens in recursive types) then we create it
86-
/// At the end (in finish()) we check that there are no undefined definitions
87-
pub fn get_definition(&self, reference_id: ReferenceId) -> PyResult<&T> {
88-
let (reference, def) = match self.definitions.iter().find(|(_, def)| def.id == reference_id) {
89-
Some(v) => v,
90-
None => return py_schema_err!("Definitions error: no definition for ReferenceId `{}`", reference_id),
91-
};
92-
match def.value.as_ref() {
93-
Some(v) => Ok(v),
94-
None => py_schema_err!(
95-
"Definitions error: attempted to use `{}` before it was filled",
96-
reference
97-
),
209+
struct LazyName {
210+
initialized: OnceLock<String>,
211+
in_recursion: AtomicBool,
212+
}
213+
214+
impl LazyName {
215+
fn new() -> Self {
216+
Self {
217+
initialized: OnceLock::new(),
218+
in_recursion: AtomicBool::new(false),
98219
}
99220
}
100221

101-
/// Consume this Definitions into a vector of items, indexed by each items ReferenceId
102-
pub fn finish(self) -> PyResult<Vec<T>> {
103-
// We need to create a vec of defs according to the order in their ids
104-
let mut defs: Vec<(usize, T)> = Vec::new();
105-
for (reference, def) in self.definitions {
106-
match def.value {
107-
None => return py_schema_err!("Definitions error: definition {} was never filled", reference),
108-
Some(v) => defs.push((def.id, v)),
109-
}
222+
/// Gets the validator name, returning the default in the case of recursion loops
223+
fn get_or_init(&self, init: impl FnOnce() -> String) -> &str {
224+
if let Some(s) = self.initialized.get() {
225+
return s.as_str();
226+
}
227+
228+
if self
229+
.in_recursion
230+
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
231+
.is_err()
232+
{
233+
return "...";
234+
}
235+
let result = self.initialized.get_or_init(init).as_str();
236+
self.in_recursion.store(false, Ordering::SeqCst);
237+
result
238+
}
239+
}
240+
241+
impl Debug for LazyName {
242+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243+
self.initialized.get().map_or("...", String::as_str).fmt(f)
244+
}
245+
}
246+
247+
impl Clone for LazyName {
248+
fn clone(&self) -> Self {
249+
Self {
250+
initialized: OnceLock::new(),
251+
in_recursion: AtomicBool::new(false),
110252
}
111-
defs.sort_by_key(|(id, _)| *id);
112-
Ok(defs.into_iter().map(|(_, v)| v).collect())
113253
}
114254
}

src/py_gc.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use ahash::AHashMap;
24
use enum_dispatch::enum_dispatch;
35
use pyo3::{AsPyPointer, Py, PyTraverseError, PyVisit};
@@ -35,6 +37,12 @@ impl<T: PyGcTraverse> PyGcTraverse for AHashMap<String, T> {
3537
}
3638
}
3739

40+
impl<T: PyGcTraverse> PyGcTraverse for Arc<T> {
41+
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
42+
T::py_gc_traverse(self, visit)
43+
}
44+
}
45+
3846
impl<T: PyGcTraverse> PyGcTraverse for Box<T> {
3947
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
4048
T::py_gc_traverse(self, visit)

src/serializers/extra.rs

-9
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ use serde::ser::Error;
1010
use super::config::SerializationConfig;
1111
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
1212
use super::ob_type::ObTypeLookup;
13-
use super::shared::CombinedSerializer;
14-
use crate::definitions::Definitions;
1513
use crate::recursion_guard::RecursionGuard;
1614

1715
/// this is ugly, would be much better if extra could be stored in `SerializationState`
@@ -48,7 +46,6 @@ impl SerializationState {
4846
Extra::new(
4947
py,
5048
mode,
51-
&[],
5249
by_alias,
5350
&self.warnings,
5451
false,
@@ -72,7 +69,6 @@ impl SerializationState {
7269
#[cfg_attr(debug_assertions, derive(Debug))]
7370
pub(crate) struct Extra<'a> {
7471
pub mode: &'a SerMode,
75-
pub definitions: &'a Definitions<CombinedSerializer>,
7672
pub ob_type_lookup: &'a ObTypeLookup,
7773
pub warnings: &'a CollectWarnings,
7874
pub by_alias: bool,
@@ -98,7 +94,6 @@ impl<'a> Extra<'a> {
9894
pub fn new(
9995
py: Python<'a>,
10096
mode: &'a SerMode,
101-
definitions: &'a Definitions<CombinedSerializer>,
10297
by_alias: bool,
10398
warnings: &'a CollectWarnings,
10499
exclude_unset: bool,
@@ -112,7 +107,6 @@ impl<'a> Extra<'a> {
112107
) -> Self {
113108
Self {
114109
mode,
115-
definitions,
116110
ob_type_lookup: ObTypeLookup::cached(py),
117111
warnings,
118112
by_alias,
@@ -156,7 +150,6 @@ impl SerCheck {
156150
#[cfg_attr(debug_assertions, derive(Debug))]
157151
pub(crate) struct ExtraOwned {
158152
mode: SerMode,
159-
definitions: Vec<CombinedSerializer>,
160153
warnings: CollectWarnings,
161154
by_alias: bool,
162155
exclude_unset: bool,
@@ -176,7 +169,6 @@ impl ExtraOwned {
176169
pub fn new(extra: &Extra) -> Self {
177170
Self {
178171
mode: extra.mode.clone(),
179-
definitions: extra.definitions.to_vec(),
180172
warnings: extra.warnings.clone(),
181173
by_alias: extra.by_alias,
182174
exclude_unset: extra.exclude_unset,
@@ -196,7 +188,6 @@ impl ExtraOwned {
196188
pub fn to_extra<'py>(&'py self, py: Python<'py>) -> Extra<'py> {
197189
Extra {
198190
mode: &self.mode,
199-
definitions: &self.definitions,
200191
ob_type_lookup: ObTypeLookup::cached(py),
201192
warnings: &self.warnings,
202193
by_alias: self.by_alias,

0 commit comments

Comments
 (0)