Skip to content

Commit 3892f57

Browse files
committed
Support timeouts in Connection.close() and Pool.release()
Connection.close() and Pool.release() each gained the new timeout parameter. The pool.acquire() context manager now applies the passed timeout to __aexit__() as well. Connection.close() is now actually graceful. Instead of simply dropping the connection, it attempts to cancel the running query (if any), asks the server to terminate the connection and waits for the connection to terminate. To test all this properly, implement a TCP proxy, which emulates sudden connectivity loss (i.e. packets not reaching the server). Closes: #220
1 parent c04576d commit 3892f57

File tree

8 files changed

+421
-33
lines changed

8 files changed

+421
-33
lines changed

asyncpg/_testbase.py renamed to asyncpg/_testbase/__init__.py

+71-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from asyncpg import connection as pg_connection
2222
from asyncpg import pool as pg_pool
2323

24+
from . import fuzzer
25+
2426

2527
@contextlib.contextmanager
2628
def silence_asyncio_long_exec_warning():
@@ -36,6 +38,14 @@ def flt(log_record):
3638
logger.removeFilter(flt)
3739

3840

41+
def with_timeout(timeout):
42+
def wrap(func):
43+
func.__timeout__ = timeout
44+
return func
45+
46+
return wrap
47+
48+
3949
class TestCaseMeta(type(unittest.TestCase)):
4050

4151
@staticmethod
@@ -64,7 +74,11 @@ def __new__(mcls, name, bases, ns):
6474
for methname, meth in mcls._iter_methods(bases, ns):
6575
@functools.wraps(meth)
6676
def wrapper(self, *args, __meth__=meth, **kwargs):
67-
self.loop.run_until_complete(__meth__(self, *args, **kwargs))
77+
coro = __meth__(self, *args, **kwargs)
78+
timeout = getattr(meth, '__timeout__', 10.0)
79+
if timeout:
80+
coro = asyncio.wait_for(coro, timeout)
81+
self.loop.run_until_complete(coro)
6882
ns[methname] = wrapper
6983

7084
return super().__new__(mcls, name, bases, ns)
@@ -153,7 +167,8 @@ def _start_default_cluster(server_settings={}):
153167

154168

155169
def _shutdown_cluster(cluster):
156-
cluster.stop()
170+
if cluster.get_status() == 'running':
171+
cluster.stop()
157172
cluster.destroy()
158173

159174

@@ -193,17 +208,70 @@ def setUpClass(cls):
193208
super().setUpClass()
194209
cls.setup_cluster()
195210

211+
@classmethod
212+
def get_connection_spec(cls):
213+
return cls.cluster.get_connection_spec()
214+
196215
def create_pool(self, pool_class=pg_pool.Pool, **kwargs):
197-
conn_spec = self.cluster.get_connection_spec()
216+
conn_spec = self.get_connection_spec()
198217
conn_spec.update(kwargs)
199218
return create_pool(loop=self.loop, pool_class=pool_class, **conn_spec)
200219

220+
def connect(self, **kwargs):
221+
conn_spec = self.get_connection_spec()
222+
conn_spec.update(kwargs)
223+
return pg_connection.connect(**conn_spec, loop=self.loop)
224+
201225
@classmethod
202226
def start_cluster(cls, ClusterCls, *,
203227
cluster_kwargs={}, server_settings={}):
204228
return _start_cluster(ClusterCls, cluster_kwargs, server_settings)
205229

206230

231+
class ProxiedClusterTestCase(ClusterTestCase):
232+
@classmethod
233+
def get_server_settings(cls):
234+
settings = dict(super().get_server_settings())
235+
settings['listen_addresses'] = '127.0.0.1'
236+
return settings
237+
238+
@classmethod
239+
def get_proxy_settings(cls):
240+
return {'fuzzing-mode': None}
241+
242+
@classmethod
243+
def setUpClass(cls):
244+
super().setUpClass()
245+
conn_spec = cls.cluster.get_connection_spec()
246+
host = conn_spec.get('host')
247+
if not host:
248+
host = '127.0.0.1'
249+
elif host.startswith('/'):
250+
host = '127.0.0.1'
251+
cls.proxy = fuzzer.TCPFuzzingProxy(
252+
backend_host=host,
253+
backend_port=conn_spec['port'],
254+
loop=cls.loop
255+
)
256+
cls.loop.run_until_complete(cls.proxy.start())
257+
258+
@classmethod
259+
def tearDownClass(cls):
260+
cls.loop.run_until_complete(cls.proxy.stop())
261+
super().tearDownClass()
262+
263+
@classmethod
264+
def get_connection_spec(cls):
265+
conn_spec = cls.cluster.get_connection_spec()
266+
conn_spec['host'] = cls.proxy.listening_addr
267+
conn_spec['port'] = cls.proxy.listening_port
268+
return conn_spec
269+
270+
def tearDown(self):
271+
self.proxy.reset()
272+
super().tearDown()
273+
274+
207275
def with_connection_options(**options):
208276
if not options:
209277
raise ValueError('no connection options were specified')

asyncpg/_testbase/fuzzer.py

+193
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Copyright (C) 2016-present the asyncpg authors and contributors
2+
# <see AUTHORS file>
3+
#
4+
# This module is part of asyncpg and is released under
5+
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
6+
7+
8+
import asyncio
9+
import socket
10+
import typing
11+
12+
from asyncpg import cluster
13+
14+
15+
class StopServer(Exception):
16+
pass
17+
18+
19+
class TCPFuzzingProxy:
20+
def __init__(self, *, listening_addr: str='127.0.0.1',
21+
listening_port: typing.Optional[int]=None,
22+
backend_host: str, backend_port: int,
23+
settings: typing.Optional[dict]=None,
24+
loop: typing.Optional[asyncio.AbstractEventLoop]) -> None:
25+
self.listening_addr = listening_addr
26+
self.listening_port = listening_port
27+
self.backend_host = backend_host
28+
self.backend_port = backend_port
29+
self.settings = settings or {}
30+
self.loop = loop or asyncio.get_event_loop()
31+
self.connections = {}
32+
self.connectivity = asyncio.Event(loop=self.loop)
33+
self.connectivity.set()
34+
self.connectivity_loss = asyncio.Event(loop=self.loop)
35+
self.stop_event = asyncio.Event(loop=self.loop)
36+
self.sock = None
37+
38+
async def _wait(self, work):
39+
done, _ = await asyncio.wait([work, self.stop_event.wait()],
40+
return_when=asyncio.FIRST_COMPLETED,
41+
loop=self.loop)
42+
if self.stop_event.is_set():
43+
raise StopServer()
44+
else:
45+
return list(done)[0].result()
46+
47+
async def start(self):
48+
if self.listening_port is None:
49+
self.listening_port = cluster.find_available_port()
50+
51+
self.sock = socket.socket()
52+
self.sock.bind((self.listening_addr, self.listening_port))
53+
self.sock.listen(50)
54+
self.sock.setblocking(False)
55+
self.loop.create_task(self.listen())
56+
57+
async def stop(self):
58+
self.stop_event.set()
59+
for conn, conn_task in self.connections.items():
60+
conn_task.cancel()
61+
conn.close()
62+
self.sock.close()
63+
await asyncio.sleep(0.2, loop=self.loop)
64+
65+
async def listen(self):
66+
while True:
67+
try:
68+
client_sock, _ = await self._wait(
69+
self.loop.sock_accept(self.sock))
70+
71+
backend_sock = socket.socket()
72+
backend_sock.setblocking(False)
73+
74+
await self._wait(self.loop.sock_connect(
75+
backend_sock, (self.backend_host, self.backend_port)))
76+
except StopServer:
77+
break
78+
79+
conn = Connection(
80+
client_sock, backend_sock,
81+
connectivity=self.connectivity,
82+
connectivity_loss=self.connectivity_loss, loop=self.loop)
83+
conn_task = self.loop.create_task(conn.handle())
84+
self.connections[conn] = conn_task
85+
86+
def trigger_connectivity_loss(self):
87+
self.connectivity.clear()
88+
self.connectivity_loss.set()
89+
90+
def restore_connectivity(self):
91+
self.connectivity.set()
92+
self.connectivity_loss.clear()
93+
94+
def reset(self):
95+
self.restore_connectivity()
96+
97+
98+
class Connection:
99+
def __init__(self, client_sock, backend_sock,
100+
connectivity, connectivity_loss, loop):
101+
self.client_sock = client_sock
102+
self.backend_sock = backend_sock
103+
self.loop = loop
104+
self.connectivity = connectivity
105+
self.connectivity_loss = connectivity_loss
106+
self.proxy_to_backend_task = None
107+
self.proxy_from_backend_task = None
108+
109+
def close(self):
110+
if self.proxy_to_backend_task is not None:
111+
self.proxy_to_backend_task.cancel()
112+
113+
if self.proxy_from_backend_task is not None:
114+
self.proxy_from_backend_task.cancel()
115+
116+
async def handle(self):
117+
self.proxy_to_backend_task = self.loop.create_task(
118+
self.proxy_to_backend())
119+
120+
self.proxy_from_backend_task = self.loop.create_task(
121+
self.proxy_from_backend())
122+
123+
try:
124+
await asyncio.gather(
125+
self.proxy_to_backend_task, self.proxy_from_backend_task,
126+
loop=self.loop)
127+
finally:
128+
self.client_sock.close()
129+
self.backend_sock.close()
130+
131+
async def _read(self, sock, n):
132+
done, _ = await asyncio.wait([
133+
self.loop.sock_recv(sock, n), self.connectivity_loss.wait()],
134+
return_when=asyncio.FIRST_COMPLETED,
135+
loop=self.loop)
136+
if self.connectivity_loss.is_set():
137+
return None
138+
else:
139+
return list(done)[0].result()
140+
141+
async def _write(self, sock, data):
142+
done, _ = await asyncio.wait([
143+
self.loop.sock_sendall(sock, data), self.connectivity_loss.wait()],
144+
return_when=asyncio.FIRST_COMPLETED,
145+
loop=self.loop)
146+
if self.connectivity_loss.is_set():
147+
return
148+
else:
149+
return list(done)[0].result()
150+
151+
async def proxy_to_backend(self):
152+
try:
153+
buf = None
154+
155+
while True:
156+
await self.connectivity.wait()
157+
if buf is not None:
158+
data = buf
159+
buf = None
160+
else:
161+
data = await self._read(self.client_sock, 4096)
162+
if data == b'':
163+
self.close()
164+
break
165+
if self.connectivity_loss.is_set():
166+
if data:
167+
buf = data
168+
continue
169+
await self._write(self.backend_sock, data)
170+
finally:
171+
self.client_sock.shutdown()
172+
173+
async def proxy_from_backend(self):
174+
try:
175+
buf = None
176+
177+
while True:
178+
await self.connectivity.wait()
179+
if buf is not None:
180+
data = buf
181+
buf = None
182+
else:
183+
data = await self._read(self.backend_sock, 4096)
184+
if data == b'':
185+
self.close()
186+
break
187+
if self.connectivity_loss.is_set():
188+
if data:
189+
buf = data
190+
continue
191+
await self._write(self.client_sock, data)
192+
finally:
193+
self.backend_sock.shutdown()

asyncpg/connection.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -958,15 +958,22 @@ def is_closed(self):
958958
"""
959959
return not self._protocol.is_connected() or self._aborted
960960

961-
async def close(self):
962-
"""Close the connection gracefully."""
961+
async def close(self, *, timeout=None):
962+
"""Close the connection gracefully.
963+
964+
:param float timeout:
965+
Optional timeout value in seconds.
966+
967+
.. versionchanged:: 0.14.0
968+
Added the *timeout* parameter.
969+
"""
963970
if self.is_closed():
964971
return
965972
self._mark_stmts_as_closed()
966973
self._listeners.clear()
967974
self._log_listeners.clear()
968975
self._aborted = True
969-
await self._protocol.close()
976+
await self._protocol.close(timeout)
970977

971978
def terminate(self):
972979
"""Terminate the connection without waiting for pending data."""
@@ -976,13 +983,13 @@ def terminate(self):
976983
self._aborted = True
977984
self._protocol.abort()
978985

979-
async def reset(self):
986+
async def reset(self, *, timeout=None):
980987
self._check_open()
981988
self._listeners.clear()
982989
self._log_listeners.clear()
983990
reset_query = self._get_reset_query()
984991
if reset_query:
985-
await self.execute(reset_query)
992+
await self.execute(reset_query, timeout=timeout)
986993

987994
def _check_open(self):
988995
if self.is_closed():

0 commit comments

Comments
 (0)