Skip to content

Commit 705618e

Browse files
committed
fix repr breakages
1 parent 0d1b06a commit 705618e

File tree

5 files changed

+202
-25
lines changed

5 files changed

+202
-25
lines changed

src/definitions.rs

+90-14
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
use std::{
77
collections::hash_map::Entry,
88
fmt::Debug,
9-
sync::{Arc, OnceLock},
9+
sync::{
10+
atomic::{AtomicBool, Ordering},
11+
Arc, OnceLock,
12+
},
1013
};
1114

1215
use pyo3::{prelude::*, PyTraverseError, PyVisit};
@@ -35,14 +38,19 @@ impl<T> Definitions<T> {
3538
}
3639

3740
/// Internal type which contains a definition to be filled
38-
pub struct Definition<T>(Arc<OnceLock<T>>);
41+
pub struct Definition<T>(Arc<DefinitionInner<T>>);
3942

4043
impl<T> Definition<T> {
4144
pub fn get(&self) -> Option<&T> {
42-
self.0.get()
45+
self.0.value.get()
4346
}
4447
}
4548

49+
struct DefinitionInner<T> {
50+
value: OnceLock<T>,
51+
name: LazyName,
52+
}
53+
4654
/// Reference to a definition.
4755
pub struct DefinitionRef<T> {
4856
name: Arc<String>,
@@ -64,12 +72,15 @@ impl<T> DefinitionRef<T> {
6472
Arc::as_ptr(&self.value.0) as usize
6573
}
6674

67-
pub fn name(&self) -> &str {
68-
&self.name
75+
pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str {
76+
match self.value.0.value.get() {
77+
Some(value) => self.value.0.name.get_or_init(|| init(value)),
78+
None => "...",
79+
}
6980
}
7081

7182
pub fn get(&self) -> Option<&T> {
72-
self.value.0.get()
83+
self.value.0.value.get()
7384
}
7485
}
7586

@@ -83,7 +94,17 @@ impl<T: Debug> Debug for DefinitionRef<T> {
8394

8495
impl<T: Debug> Debug for Definitions<T> {
8596
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86-
self.0.fmt(f)
97+
// Formatted as a list for backwards compatibility; in principle
98+
// this could be formatted as a map. Maybe change in a future
99+
// minor release of pydantic.
100+
write![f, "["]?;
101+
let mut first = true;
102+
for def in self.0.values() {
103+
write![f, "{sep}{def:?}", sep = if first { "" } else { ", " }]?;
104+
first = false;
105+
}
106+
write![f, "]"]?;
107+
Ok(())
87108
}
88109
}
89110

@@ -95,7 +116,7 @@ impl<T> Clone for Definition<T> {
95116

96117
impl<T: Debug> Debug for Definition<T> {
97118
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98-
match self.0.get() {
119+
match self.0.value.get() {
99120
Some(value) => value.fmt(f),
100121
None => "...".fmt(f),
101122
}
@@ -104,7 +125,7 @@ impl<T: Debug> Debug for Definition<T> {
104125

105126
impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
106127
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
107-
if let Some(value) = self.value.0.get() {
128+
if let Some(value) = self.value.0.value.get() {
108129
value.py_gc_traverse(visit)?;
109130
}
110131
Ok(())
@@ -114,7 +135,7 @@ impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
114135
impl<T: PyGcTraverse> PyGcTraverse for Definitions<T> {
115136
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
116137
for value in self.0.values() {
117-
if let Some(value) = value.0.get() {
138+
if let Some(value) = value.0.value.get() {
118139
value.py_gc_traverse(visit)?;
119140
}
120141
}
@@ -142,7 +163,10 @@ impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
142163
let name = Arc::new(reference.to_string());
143164
let value = match self.definitions.0.entry(name.clone()) {
144165
Entry::Occupied(entry) => entry.into_mut(),
145-
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(OnceLock::new()))),
166+
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner {
167+
value: OnceLock::new(),
168+
name: LazyName::new(),
169+
}))),
146170
};
147171
DefinitionRef {
148172
name,
@@ -156,23 +180,75 @@ impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
156180
let value = match self.definitions.0.entry(name.clone()) {
157181
Entry::Occupied(entry) => {
158182
let definition = entry.into_mut();
159-
match definition.0.set(value) {
183+
match definition.0.value.set(value) {
160184
Ok(()) => definition.clone(),
161185
Err(_) => return py_schema_err!("Duplicate ref: `{}`", name),
162186
}
163187
}
164-
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(OnceLock::from(value)))).clone(),
188+
Entry::Vacant(entry) => entry
189+
.insert(Definition(Arc::new(DefinitionInner {
190+
value: OnceLock::from(value),
191+
name: LazyName::new(),
192+
})))
193+
.clone(),
165194
};
166195
Ok(DefinitionRef { name, value })
167196
}
168197

169198
/// Consume this builder and return the finished `Definitions`, returning an error if any referenced definition was never filled
170199
pub fn finish(self) -> PyResult<Definitions<T>> {
171200
for (reference, def) in &self.definitions.0 {
172-
if def.0.get().is_none() {
201+
if def.0.value.get().is_none() {
173202
return py_schema_err!("Definitions error: definition `{}` was never filled", reference);
174203
}
175204
}
176205
Ok(self.definitions)
177206
}
178207
}
208+
209+
struct LazyName {
210+
initialized: OnceLock<String>,
211+
in_recursion: AtomicBool,
212+
}
213+
214+
impl LazyName {
215+
fn new() -> Self {
216+
Self {
217+
initialized: OnceLock::new(),
218+
in_recursion: AtomicBool::new(false),
219+
}
220+
}
221+
222+
/// Gets the name, computing it with `init` on first use; returns the placeholder `"..."` if initialization is already in progress (e.g. in a recursion loop)
223+
fn get_or_init(&self, init: impl FnOnce() -> String) -> &str {
224+
if let Some(s) = self.initialized.get() {
225+
return s.as_str();
226+
}
227+
228+
if self
229+
.in_recursion
230+
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
231+
.is_err()
232+
{
233+
return "...";
234+
}
235+
let result = self.initialized.get_or_init(init).as_str();
236+
self.in_recursion.store(false, Ordering::SeqCst);
237+
result
238+
}
239+
}
240+
241+
impl Debug for LazyName {
242+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243+
self.initialized.get().map_or("...", String::as_str).fmt(f)
244+
}
245+
}
246+
247+
impl Clone for LazyName {
248+
fn clone(&self) -> Self {
249+
Self {
250+
initialized: OnceLock::new(),
251+
in_recursion: AtomicBool::new(false),
252+
}
253+
}
254+
}

src/validators/definitions.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ impl BuildValidator for DefinitionRefValidator {
6464
let schema_ref = schema.get_as_req(intern!(schema.py(), "schema_ref"))?;
6565

6666
let definition = definitions.get_definition(schema_ref);
67-
68-
Ok(Self { definition }.into())
67+
Ok(Self::new(definition).into())
6968
}
7069
}
7170

@@ -131,8 +130,12 @@ impl Validator for DefinitionRefValidator {
131130

132131
let id = self as *const _ as usize;
133132
// We have to unwrap here because this function cannot return an error; this should be fine as long as the definition has been filled by this point.
134-
let validator = self.definition.get().unwrap();
135-
if RECURSION_SET.with(|set| set.borrow_mut().get_or_insert_with(HashSet::new).insert(id)) {
133+
let validator: &CombinedValidator = self.definition.get().unwrap();
134+
if RECURSION_SET.with(
135+
|set: &RefCell<Option<std::collections::HashSet<usize, ahash::RandomState>>>| {
136+
set.borrow_mut().get_or_insert_with(HashSet::new).insert(id)
137+
},
138+
) {
136139
let different_strict_behavior = validator.different_strict_behavior(ultra_strict);
137140
RECURSION_SET.with(|set| set.borrow_mut().get_or_insert_with(HashSet::new).remove(&id));
138141
different_strict_behavior
@@ -142,10 +145,9 @@ impl Validator for DefinitionRefValidator {
142145
}
143146

144147
fn get_name(&self) -> &str {
145-
self.definition.get().map_or("...", |validator| validator.get_name())
148+
self.definition.get_or_init_name(|v| v.get_name().into())
146149
}
147150

148-
/// don't need to call complete on the inner validator here, complete_validators takes care of that.
149151
fn complete(&self) -> PyResult<()> {
150152
Ok(())
151153
}

src/validators/list.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ impl Validator for ListValidator {
156156
fn complete(&self) -> PyResult<()> {
157157
if let Some(v) = &self.item_validator {
158158
v.complete()?;
159-
let inner_name = v.get_name();
160-
let _ = self.name.set(format!("{}[{inner_name}]", Self::EXPECTED_TYPE));
159+
let _ = self.name.set(format!("list[{}]", v.get_name()));
160+
} else {
161+
let _ = self.name.set("list[any]".into());
161162
}
162163
Ok(())
163164
}

src/validators/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ impl SchemaValidator {
116116
let mut definitions_builder = DefinitionsBuilder::new();
117117

118118
let validator = build_validator(schema, config, &mut definitions_builder)?;
119-
validator.complete()?;
120119
let definitions = definitions_builder.finish()?;
120+
validator.complete()?;
121121
for val in definitions.values() {
122122
val.get().unwrap().complete()?;
123123
}
@@ -387,8 +387,8 @@ impl<'py> SelfValidator<'py> {
387387
Ok(v) => v,
388388
Err(err) => return py_schema_err!("Error building self-schema:\n {}", err),
389389
};
390-
validator.complete()?;
391390
let definitions = definitions_builder.finish()?;
391+
validator.complete()?;
392392
for val in definitions.values() {
393393
val.get().unwrap().complete()?;
394394
}

tests/validators/test_definitions_recursive.py

+99-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
import platform
23
from dataclasses import dataclass
34
from typing import List, Optional
@@ -243,7 +244,7 @@ class Branch:
243244

244245

245246
def test_invalid_schema():
246-
with pytest.raises(SchemaError, match='Definitions error: attempted to use `Branch` before it was filled'):
247+
with pytest.raises(SchemaError, match='Definitions error: definition `Branch` was never filled'):
247248
SchemaValidator(
248249
{
249250
'type': 'list',
@@ -987,3 +988,100 @@ def test_cyclic_data_threeway() -> None:
987988
'input': cyclic_data,
988989
}
989990
]
991+
992+
993+
def test_complex_recursive_type() -> None:
994+
schema = core_schema.definitions_schema(
995+
core_schema.definition_reference_schema('JsonType'),
996+
[
997+
core_schema.nullable_schema(
998+
core_schema.union_schema(
999+
[
1000+
core_schema.list_schema(core_schema.definition_reference_schema('JsonType')),
1001+
core_schema.dict_schema(
1002+
core_schema.str_schema(), core_schema.definition_reference_schema('JsonType')
1003+
),
1004+
core_schema.str_schema(),
1005+
core_schema.int_schema(),
1006+
core_schema.float_schema(),
1007+
core_schema.bool_schema(),
1008+
]
1009+
),
1010+
ref='JsonType',
1011+
)
1012+
],
1013+
)
1014+
1015+
validator = SchemaValidator(schema)
1016+
1017+
with pytest.raises(ValidationError) as exc_info:
1018+
validator.validate_python({'a': datetime.date(year=1992, month=12, day=11)})
1019+
1020+
assert exc_info.value.errors(include_url=False) == [
1021+
{
1022+
'type': 'list_type',
1023+
'loc': ('list[nullable[union[list[...],dict[str,...],str,int,float,bool]]]',),
1024+
'msg': 'Input should be a valid list',
1025+
'input': {'a': datetime.date(1992, 12, 11)},
1026+
},
1027+
{
1028+
'type': 'list_type',
1029+
'loc': ('dict[str,...]', 'a', 'list[nullable[union[list[...],dict[str,...],str,int,float,bool]]]'),
1030+
'msg': 'Input should be a valid list',
1031+
'input': datetime.date(1992, 12, 11),
1032+
},
1033+
{
1034+
'type': 'dict_type',
1035+
'loc': ('dict[str,...]', 'a', 'dict[str,...]'),
1036+
'msg': 'Input should be a valid dictionary',
1037+
'input': datetime.date(1992, 12, 11),
1038+
},
1039+
{
1040+
'type': 'string_type',
1041+
'loc': ('dict[str,...]', 'a', 'str'),
1042+
'msg': 'Input should be a valid string',
1043+
'input': datetime.date(1992, 12, 11),
1044+
},
1045+
{
1046+
'type': 'int_type',
1047+
'loc': ('dict[str,...]', 'a', 'int'),
1048+
'msg': 'Input should be a valid integer',
1049+
'input': datetime.date(1992, 12, 11),
1050+
},
1051+
{
1052+
'type': 'float_type',
1053+
'loc': ('dict[str,...]', 'a', 'float'),
1054+
'msg': 'Input should be a valid number',
1055+
'input': datetime.date(1992, 12, 11),
1056+
},
1057+
{
1058+
'type': 'bool_type',
1059+
'loc': ('dict[str,...]', 'a', 'bool'),
1060+
'msg': 'Input should be a valid boolean',
1061+
'input': datetime.date(1992, 12, 11),
1062+
},
1063+
{
1064+
'type': 'string_type',
1065+
'loc': ('str',),
1066+
'msg': 'Input should be a valid string',
1067+
'input': {'a': datetime.date(1992, 12, 11)},
1068+
},
1069+
{
1070+
'type': 'int_type',
1071+
'loc': ('int',),
1072+
'msg': 'Input should be a valid integer',
1073+
'input': {'a': datetime.date(1992, 12, 11)},
1074+
},
1075+
{
1076+
'type': 'float_type',
1077+
'loc': ('float',),
1078+
'msg': 'Input should be a valid number',
1079+
'input': {'a': datetime.date(1992, 12, 11)},
1080+
},
1081+
{
1082+
'type': 'bool_type',
1083+
'loc': ('bool',),
1084+
'msg': 'Input should be a valid boolean',
1085+
'input': {'a': datetime.date(1992, 12, 11)},
1086+
},
1087+
]

0 commit comments

Comments
 (0)