Skip to content

Initialize statement codecs immediately after Prepare #248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False,
types, intro_stmt = await self.__execute(
self._intro_query, (list(ready),), 0, timeout)
self._protocol.get_settings().register_data_types(types)
if not intro_stmt.name and not statement.name:
# The introspection query has used an anonymous statement,
# which has blown away the anonymous statement we've prepared
# for the query, so we need to re-prepare it.
statement = await self._protocol.prepare(
stmt_name, query, timeout)
# The introspection query has used an anonymous statement,
# which has blown away the anonymous statement we've prepared
# for the query, so we need to re-prepare it.
need_reprepare = not intro_stmt.name and not statement.name
else:
need_reprepare = False

# Now that types have been resolved, populate the codec pipeline
# for the statement.
statement._init_codecs()

if need_reprepare:
await self._protocol.prepare(
stmt_name, query, timeout, state=statement)

if use_cache:
self._stmt_cache.put(query, statement)
Expand Down
1 change: 1 addition & 0 deletions asyncpg/protocol/prepared_stmt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cdef class PreparedStatementState:
tuple rows_codecs

cdef _encode_bind_msg(self, args)
cpdef _init_codecs(self)
cdef _ensure_rows_decoder(self)
cdef _ensure_args_encoder(self)
cdef _set_row_desc(self, object desc)
Expand Down
7 changes: 4 additions & 3 deletions asyncpg/protocol/prepared_stmt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ cdef class PreparedStatementState:
else:
return True

cpdef _init_codecs(self):
self._ensure_args_encoder()
self._ensure_rows_decoder()

def attach(self):
self.refs += 1

Expand All @@ -101,9 +105,6 @@ cdef class PreparedStatementState:
raise exceptions.InterfaceError(
'the number of query arguments cannot exceed 32767')

self._ensure_args_encoder()
self._ensure_rows_decoder()

writer = WriteBuffer.new()

num_args_passed = len(args)
Expand Down
7 changes: 5 additions & 2 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ cdef class BaseProtocol(CoreProtocol):
self.is_reading = False
self.transport.pause_reading()

async def prepare(self, stmt_name, query, timeout):
async def prepare(self, stmt_name, query, timeout,
PreparedStatementState state=None):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
Expand All @@ -160,7 +161,9 @@ cdef class BaseProtocol(CoreProtocol):
try:
self._prepare(stmt_name, query) # network op
self.last_query = query
self.statement = PreparedStatementState(stmt_name, query, self)
if state is None:
state = PreparedStatementState(stmt_name, query, self)
self.statement = state
except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
Expand Down
27 changes: 27 additions & 0 deletions tests/test_introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import json

from asyncpg import _testbase as tb
from asyncpg import connection as apg_con

Expand Down Expand Up @@ -98,3 +100,28 @@ async def test_introspection_no_stmt_cache_03(self):
"SELECT $1::int[], '{foo}'".format(foo='a' * 10000), [1, 2])

self.assertEqual(apg_con._uid, old_uid + 1)

async def test_introspection_sticks_for_ps(self):
# Test that the introspected codec pipeline for a prepared
# statement is not affected by a subsequent codec cache bust.

ps = await self.con._prepare('SELECT $1::json[]', use_cache=True)

try:
# Setting a custom codec blows the codec cache for derived types.
await self.con.set_type_codec(
'json', encoder=lambda v: v, decoder=json.loads,
schema='pg_catalog', format='text'
)

# The originally prepared statement should still be OK and
# use the previously selected codec.
self.assertEqual(await ps.fetchval(['{"foo": 1}']), ['{"foo": 1}'])

# The new query uses the custom codec.
v = await self.con.fetchval('SELECT $1::json[]', ['{"foo": 1}'])
self.assertEqual(v, [{'foo': 1}])

finally:
await self.con.reset_type_codec(
'json', schema='pg_catalog')