Skip to content

Commit 482a186

Browse files
committed
Fix type codec cache races
The current global type codec cache works poorly in a pooled environment. Because the cache is global, races between type introspection and cache busting are a frequent occurrence. Additionally, busting the codec cache in _all_ connections merely because one of them has reconfigured a codec seems wrong. The fix is simple: every connection now has its own codec cache. The downside is that there will be more introspection queries on fresh connections, but given that most connections in the field are pooled, the robustness gains are more important. Fixes: #278
1 parent 7a0585a commit 482a186

File tree

7 files changed

+120
-37
lines changed

7 files changed

+120
-37
lines changed

asyncpg/connection.py

+41-10
Original file line numberDiff line numberDiff line change
@@ -296,17 +296,34 @@ async def _get_statement(self, query, timeout, *, named: bool=False,
296296
stmt_name = ''
297297

298298
statement = await self._protocol.prepare(stmt_name, query, timeout)
299-
ready = statement._init_types()
300-
if ready is not True:
301-
types, intro_stmt = await self.__execute(
302-
self._intro_query, (list(ready),), 0, timeout)
303-
self._protocol.get_settings().register_data_types(types)
299+
need_reprepare = False
300+
types_with_missing_codecs = statement._init_types()
301+
tries = 0
302+
while types_with_missing_codecs:
303+
settings = self._protocol.get_settings()
304+
305+
# Introspect newly seen types and populate the
306+
# codec cache.
307+
types, intro_stmt = await self._introspect_types(
308+
types_with_missing_codecs, timeout)
309+
310+
settings.register_data_types(types)
311+
304312
# The introspection query has used an anonymous statement,
305313
# which has blown away the anonymous statement we've prepared
306314
# for the query, so we need to re-prepare it.
307315
need_reprepare = not intro_stmt.name and not statement.name
308-
else:
309-
need_reprepare = False
316+
types_with_missing_codecs = statement._init_types()
317+
tries += 1
318+
if tries > 5:
319+
# In the vast majority of cases there will be only
320+
# one iteration. In rare cases, there might be a race
321+
# with reload_schema_state(), which would cause a
322+
# second try. More than five is clearly a bug.
323+
raise exceptions.InternalClientError(
324+
'could not resolve query result and/or argument types '
325+
'in {} attempts'.format(tries)
326+
)
310327

311328
# Now that types have been resolved, populate the codec pipeline
312329
# for the statement.
@@ -326,6 +343,10 @@ async def _get_statement(self, query, timeout, *, named: bool=False,
326343

327344
return statement
328345

346+
async def _introspect_types(self, typeoids, timeout):
347+
return await self.__execute(
348+
self._intro_query, (list(typeoids),), 0, timeout)
349+
329350
def cursor(self, query, *args, prefetch=None, timeout=None):
330351
"""Return a *cursor factory* for the specified query.
331352
@@ -1271,6 +1292,18 @@ def _drop_global_statement_cache(self):
12711292
else:
12721293
self._drop_local_statement_cache()
12731294

1295+
def _drop_local_type_cache(self):
1296+
self._protocol.get_settings().clear_type_cache()
1297+
1298+
def _drop_global_type_cache(self):
1299+
if self._proxy is not None:
1300+
# This connection is a member of a pool, so we delegate
1301+
# the cache drop to the pool.
1302+
pool = self._proxy._holder._pool
1303+
pool._drop_type_cache()
1304+
else:
1305+
self._drop_local_type_cache()
1306+
12741307
async def reload_schema_state(self):
12751308
"""Indicate that the database schema information must be reloaded.
12761309
@@ -1313,9 +1346,7 @@ async def reload_schema_state(self):
13131346
13141347
.. versionadded:: 0.14.0
13151348
"""
1316-
# It is enough to clear the type cache only once, not in each
1317-
# connection in the pool.
1318-
self._protocol.get_settings().clear_type_cache()
1349+
self._drop_global_type_cache()
13191350
self._drop_global_statement_cache()
13201351

13211352
async def _execute(self, query, args, limit, timeout, return_status=False):

asyncpg/pool.py

+6
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,12 @@ def _drop_statement_cache(self):
614614
if ch._con is not None:
615615
ch._con._drop_local_statement_cache()
616616

617+
def _drop_type_cache(self):
618+
# Drop type codec cache for all connections in the pool.
619+
for ch in self._holders:
620+
if ch._con is not None:
621+
ch._con._drop_local_type_cache()
622+
617623
def __await__(self):
618624
return self._async__init__().__await__()
619625

asyncpg/protocol/codecs/base.pxd

+2-2
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ cdef class Codec:
167167

168168
cdef class DataCodecConfig:
169169
cdef:
170-
dict _type_codecs_cache
171-
dict _local_type_codecs
170+
dict _derived_type_codecs
171+
dict _custom_type_codecs
172172

173173
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format)
174174
cdef inline Codec get_local_codec(self, uint32_t oid)

asyncpg/protocol/codecs/base.pyx

+15-17
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ from asyncpg.exceptions import OutdatedSchemaCacheError
1010

1111
cdef void* binary_codec_map[(MAXSUPPORTEDOID + 1) * 2]
1212
cdef void* text_codec_map[(MAXSUPPORTEDOID + 1) * 2]
13-
cdef dict TYPE_CODECS_CACHE = {}
1413
cdef dict EXTRA_CODECS = {}
1514

1615

@@ -391,12 +390,11 @@ cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
391390

392391
cdef class DataCodecConfig:
393392
def __init__(self, cache_key):
394-
try:
395-
self._type_codecs_cache = TYPE_CODECS_CACHE[cache_key]
396-
except KeyError:
397-
self._type_codecs_cache = TYPE_CODECS_CACHE[cache_key] = {}
398-
399-
self._local_type_codecs = {}
393+
# Codec instance cache for derived types:
394+
# composites, arrays, ranges, domains and their combinations.
395+
self._derived_type_codecs = {}
396+
# Codec instances set up by the user for the connection.
397+
self._custom_type_codecs = {}
400398

401399
def add_types(self, types):
402400
cdef:
@@ -451,7 +449,7 @@ cdef class DataCodecConfig:
451449

452450
elem_delim = <Py_UCS4>ti['elemdelim'][0]
453451

454-
self._type_codecs_cache[oid, elem_format] = \
452+
self._derived_type_codecs[oid, elem_format] = \
455453
Codec.new_array_codec(
456454
oid, name, schema, elem_codec, elem_delim)
457455

@@ -483,7 +481,7 @@ cdef class DataCodecConfig:
483481
if has_text_elements:
484482
format = PG_FORMAT_TEXT
485483

486-
self._type_codecs_cache[oid, format] = \
484+
self._derived_type_codecs[oid, format] = \
487485
Codec.new_composite_codec(
488486
oid, name, schema, format, comp_elem_codecs,
489487
comp_type_attrs, element_names)
@@ -502,7 +500,7 @@ cdef class DataCodecConfig:
502500
elem_codec = self.declare_fallback_codec(
503501
base_type, name, schema)
504502

505-
self._type_codecs_cache[oid, format] = elem_codec
503+
self._derived_type_codecs[oid, format] = elem_codec
506504

507505
elif ti['kind'] == b'r':
508506
# Range type
@@ -523,7 +521,7 @@ cdef class DataCodecConfig:
523521
elem_codec = self.declare_fallback_codec(
524522
range_subtype_oid, name, schema)
525523

526-
self._type_codecs_cache[oid, elem_format] = \
524+
self._derived_type_codecs[oid, elem_format] = \
527525
Codec.new_range_codec(oid, name, schema, elem_codec)
528526

529527
elif ti['kind'] == b'e':
@@ -554,13 +552,13 @@ cdef class DataCodecConfig:
554552
# Clear all previous overrides (this also clears type cache).
555553
self.remove_python_codec(typeoid, typename, typeschema)
556554

557-
self._local_type_codecs[typeoid] = \
555+
self._custom_type_codecs[typeoid] = \
558556
Codec.new_python_codec(oid, typename, typeschema, typekind,
559557
encoder, decoder, c_encoder, c_decoder,
560558
format, xformat)
561559

562560
def remove_python_codec(self, typeoid, typename, typeschema):
563-
self._local_type_codecs.pop(typeoid, None)
561+
self._custom_type_codecs.pop(typeoid, None)
564562
self.clear_type_cache()
565563

566564
def _set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
@@ -592,7 +590,7 @@ cdef class DataCodecConfig:
592590
codec.schema = typeschema
593591
codec.kind = typekind
594592

595-
self._local_type_codecs[typeoid] = codec
593+
self._custom_type_codecs[typeoid] = codec
596594
break
597595
else:
598596
raise ValueError('unknown alias target: {}'.format(alias_to))
@@ -604,7 +602,7 @@ cdef class DataCodecConfig:
604602
self.clear_type_cache()
605603

606604
def clear_type_cache(self):
607-
self._type_codecs_cache.clear()
605+
self._derived_type_codecs.clear()
608606

609607
def declare_fallback_codec(self, uint32_t oid, str name, str schema):
610608
cdef Codec codec
@@ -654,12 +652,12 @@ cdef class DataCodecConfig:
654652
return codec
655653
else:
656654
try:
657-
return self._type_codecs_cache[oid, format]
655+
return self._derived_type_codecs[oid, format]
658656
except KeyError:
659657
return None
660658

661659
cdef inline Codec get_local_codec(self, uint32_t oid):
662-
return self._local_type_codecs.get(oid)
660+
return self._custom_type_codecs.get(oid)
663661

664662

665663
cdef inline Codec get_core_codec(

asyncpg/protocol/prepared_stmt.pyx

+4-7
Original file line numberDiff line numberDiff line change
@@ -63,24 +63,21 @@ cdef class PreparedStatementState:
6363
def _init_types(self):
6464
cdef:
6565
Codec codec
66-
set result = set()
66+
set missing = set()
6767

6868
if self.parameters_desc:
6969
for p_oid in self.parameters_desc:
7070
codec = self.settings.get_data_codec(<uint32_t>p_oid)
7171
if codec is None or not codec.has_encoder():
72-
result.add(p_oid)
72+
missing.add(p_oid)
7373

7474
if self.row_desc:
7575
for rdesc in self.row_desc:
7676
codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
7777
if codec is None or not codec.has_decoder():
78-
result.add(rdesc[3])
78+
missing.add(rdesc[3])
7979

80-
if len(result):
81-
return result
82-
else:
83-
return True
80+
return missing
8481

8582
cpdef _init_codecs(self):
8683
self._ensure_args_encoder()

asyncpg/protocol/settings.pxd

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@ cdef class ConnectionSettings:
2525
cpdef inline clear_type_cache(self)
2626
cpdef inline set_builtin_type_codec(
2727
self, typeoid, typename, typeschema, typekind, alias_to)
28-
cpdef inline Codec get_data_codec(self, uint32_t oid, ServerDataFormat format=*)
28+
cpdef inline Codec get_data_codec(
29+
self, uint32_t oid, ServerDataFormat format=*)

tests/test_introspection.py

+50
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8+
import asyncio
89
import json
910

1011
from asyncpg import _testbase as tb
@@ -14,6 +15,16 @@
1415
MAX_RUNTIME = 0.1
1516

1617

18+
class SlowIntrospectionConnection(apg_con.Connection):
19+
"""Connection class to test introspection races."""
20+
introspect_count = 0
21+
22+
async def _introspect_types(self, *args, **kwargs):
23+
self.introspect_count += 1
24+
await asyncio.sleep(0.4, loop=self._loop)
25+
return await super()._introspect_types(*args, **kwargs)
26+
27+
1728
class TestIntrospection(tb.ConnectedTestCase):
1829
@classmethod
1930
def setUpClass(cls):
@@ -125,3 +136,42 @@ async def test_introspection_sticks_for_ps(self):
125136
finally:
126137
await self.con.reset_type_codec(
127138
'json', schema='pg_catalog')
139+
140+
async def test_introspection_retries_after_cache_bust(self):
141+
# Test that codec cache bust racing with the introspection
142+
# query would cause introspection to retry.
143+
slow_intro_conn = await self.connect(
144+
connection_class=SlowIntrospectionConnection)
145+
try:
146+
await self.con.execute('''
147+
CREATE DOMAIN intro_1_t AS int;
148+
CREATE DOMAIN intro_2_t AS int;
149+
''')
150+
151+
await slow_intro_conn.fetchval('''
152+
SELECT $1::intro_1_t
153+
''', 10)
154+
# slow_intro_conn cache is now populated with intro_1_t
155+
156+
async def wait_and_drop():
157+
await asyncio.sleep(0.1, loop=self.loop)
158+
await slow_intro_conn.reload_schema_state()
159+
160+
# Now, in parallel, run another query that
161+
# references both intro_1_t and intro_2_t.
162+
await asyncio.gather(
163+
slow_intro_conn.fetchval('''
164+
SELECT $1::intro_1_t, $2::intro_2_t
165+
''', 10, 20),
166+
wait_and_drop()
167+
)
168+
169+
# Initial query + two tries for the second query.
170+
self.assertEqual(slow_intro_conn.introspect_count, 3)
171+
172+
finally:
173+
await self.con.execute('''
174+
DROP DOMAIN intro_1_t;
175+
DROP DOMAIN intro_2_t;
176+
''')
177+
await slow_intro_conn.close()

0 commit comments

Comments
 (0)