Skip to content

Commit 4bdb32a

Browse files
committed
Initialize statement codecs immediately after Prepare
Currently the statement codecs are populated just before the first Bind is issued. This is problematic as in the time since Prepare, the codec cache for derived types (arrays, composites etc.) may have been purged by an installation of a custom codec, or general schema state invalidation. Fix this by populating the codecs immediately after the statement data types have been resolved. Fixes: #241.
1 parent a3b4066 commit 4bdb32a

File tree

5 files changed

+51
-11
lines changed

5 files changed

+51
-11
lines changed

asyncpg/connection.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
289289
types, intro_stmt = await self.__execute(
290290
self._intro_query, (list(ready),), 0, timeout)
291291
self._protocol.get_settings().register_data_types(types)
292-
if not intro_stmt.name and not statement.name:
293-
# The introspection query has used an anonymous statement,
294-
# which has blown away the anonymous statement we've prepared
295-
# for the query, so we need to re-prepare it.
296-
statement = await self._protocol.prepare(
297-
stmt_name, query, timeout)
292+
# The introspection query has used an anonymous statement,
293+
# which has blown away the anonymous statement we've prepared
294+
# for the query, so we need to re-prepare it.
295+
need_reprepare = not intro_stmt.name and not statement.name
296+
else:
297+
need_reprepare = False
298+
299+
# Now that types have been resolved, populate the codec pipeline
300+
# for the statement.
301+
statement._init_codecs()
302+
303+
if need_reprepare:
304+
await self._protocol.prepare(
305+
stmt_name, query, timeout, state=statement)
298306

299307
if use_cache:
300308
self._stmt_cache.put(query, statement)

asyncpg/protocol/prepared_stmt.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ cdef class PreparedStatementState:
3030
tuple rows_codecs
3131

3232
cdef _encode_bind_msg(self, args)
33+
cpdef _init_codecs(self)
3334
cdef _ensure_rows_decoder(self)
3435
cdef _ensure_args_encoder(self)
3536
cdef _set_row_desc(self, object desc)

asyncpg/protocol/prepared_stmt.pyx

+4-3
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ cdef class PreparedStatementState:
8282
else:
8383
return True
8484

85+
cpdef _init_codecs(self):
86+
self._ensure_args_encoder()
87+
self._ensure_rows_decoder()
88+
8589
def attach(self):
8690
self.refs += 1
8791

@@ -101,9 +105,6 @@ cdef class PreparedStatementState:
101105
raise exceptions.InterfaceError(
102106
'the number of query arguments cannot exceed 32767')
103107

104-
self._ensure_args_encoder()
105-
self._ensure_rows_decoder()
106-
107108
writer = WriteBuffer.new()
108109

109110
num_args_passed = len(args)

asyncpg/protocol/protocol.pyx

+5-2
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ cdef class BaseProtocol(CoreProtocol):
146146
self.is_reading = False
147147
self.transport.pause_reading()
148148

149-
async def prepare(self, stmt_name, query, timeout):
149+
async def prepare(self, stmt_name, query, timeout,
150+
PreparedStatementState state=None):
150151
if self.cancel_waiter is not None:
151152
await self.cancel_waiter
152153
if self.cancel_sent_waiter is not None:
@@ -160,7 +161,9 @@ cdef class BaseProtocol(CoreProtocol):
160161
try:
161162
self._prepare(stmt_name, query) # network op
162163
self.last_query = query
163-
self.statement = PreparedStatementState(stmt_name, query, self)
164+
if state is None:
165+
state = PreparedStatementState(stmt_name, query, self)
166+
self.statement = state
164167
except Exception as ex:
165168
waiter.set_exception(ex)
166169
self._coreproto_error()

tests/test_introspection.py

+27
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8+
import json
9+
810
from asyncpg import _testbase as tb
911
from asyncpg import connection as apg_con
1012

@@ -98,3 +100,28 @@ async def test_introspection_no_stmt_cache_03(self):
98100
"SELECT $1::int[], '{foo}'".format(foo='a' * 10000), [1, 2])
99101

100102
self.assertEqual(apg_con._uid, old_uid + 1)
103+
104+
async def test_introspection_sticks_for_ps(self):
105+
# Test that the introspected codec pipeline for a prepared
106+
# statement is not affected by a subsequent codec cache bust.
107+
108+
ps = await self.con.prepare('SELECT $1::json[]')
109+
110+
try:
111+
# Setting a custom codec blows the codec cache for derived types.
112+
await self.con.set_type_codec(
113+
'json', encoder=lambda v: v, decoder=json.loads,
114+
schema='pg_catalog', format='text'
115+
)
116+
117+
# The originally prepared statement should still be OK and
118+
# use the previously selected codec.
119+
self.assertEqual(await ps.fetchval(['{"foo": 1}']), ['{"foo": 1}'])
120+
121+
# The new query uses the custom codec.
122+
v = await self.con.fetchval('SELECT $1::json[]', ['{"foo": 1}'])
123+
self.assertEqual(v, [{'foo': 1}])
124+
125+
finally:
126+
await self.con.reset_type_codec(
127+
'json', schema='pg_catalog')

0 commit comments

Comments
 (0)