11
11
import collections .abc
12
12
import functools
13
13
import itertools
14
+ import inspect
14
15
import os
15
16
import sys
16
17
import time
17
18
import traceback
19
+ import typing
18
20
import warnings
19
21
import weakref
20
22
@@ -133,27 +135,32 @@ async def add_listener(self, channel, callback):
133
135
:param str channel: Channel to listen on.
134
136
135
137
:param callable callback:
136
- A callable receiving the following arguments:
138
+ A callable or a coroutine function receiving the following
139
+ arguments:
137
140
**connection**: a Connection the callback is registered with;
138
141
**pid**: PID of the Postgres server that sent the notification;
139
142
**channel**: name of the channel the notification was sent to;
140
143
**payload**: the payload.
144
+
145
+ .. versionchanged:: 0.24.0
146
+ The ``callback`` argument may be a coroutine function.
141
147
"""
142
148
self ._check_open ()
143
149
if channel not in self ._listeners :
144
150
await self .fetch ('LISTEN {}' .format (utils ._quote_ident (channel )))
145
151
self ._listeners [channel ] = set ()
146
- self ._listeners [channel ].add (callback )
152
+ self ._listeners [channel ].add (_Callback . from_callable ( callback ) )
147
153
148
154
async def remove_listener (self , channel , callback ):
149
155
"""Remove a listening callback on the specified channel."""
150
156
if self .is_closed ():
151
157
return
152
158
if channel not in self ._listeners :
153
159
return
154
- if callback not in self ._listeners [channel ]:
160
+ cb = _Callback .from_callable (callback )
161
+ if cb not in self ._listeners [channel ]:
155
162
return
156
- self ._listeners [channel ].remove (callback )
163
+ self ._listeners [channel ].remove (cb )
157
164
if not self ._listeners [channel ]:
158
165
del self ._listeners [channel ]
159
166
await self .fetch ('UNLISTEN {}' .format (utils ._quote_ident (channel )))
@@ -166,44 +173,51 @@ def add_log_listener(self, callback):
166
173
DEBUG, INFO, or LOG.
167
174
168
175
:param callable callback:
169
- A callable receiving the following arguments:
176
+ A callable or a coroutine function receiving the following
177
+ arguments:
170
178
**connection**: a Connection the callback is registered with;
171
179
**message**: the `exceptions.PostgresLogMessage` message.
172
180
173
181
.. versionadded:: 0.12.0
182
+
183
+ .. versionchanged:: 0.24.0
184
+ The ``callback`` argument may be a coroutine function.
174
185
"""
175
186
if self .is_closed ():
176
187
raise exceptions .InterfaceError ('connection is closed' )
177
- self ._log_listeners .add (callback )
188
+ self ._log_listeners .add (_Callback . from_callable ( callback ) )
178
189
179
190
def remove_log_listener (self , callback ):
180
191
"""Remove a listening callback for log messages.
181
192
182
193
.. versionadded:: 0.12.0
183
194
"""
184
- self ._log_listeners .discard (callback )
195
+ self ._log_listeners .discard (_Callback . from_callable ( callback ) )
185
196
186
197
def add_termination_listener (self , callback ):
187
198
"""Add a listener that will be called when the connection is closed.
188
199
189
200
:param callable callback:
190
- A callable receiving one argument:
201
+ A callable or a coroutine function receiving one argument:
191
202
**connection**: a Connection the callback is registered with.
192
203
193
204
.. versionadded:: 0.21.0
205
+
206
+ .. versionchanged:: 0.24.0
207
+ The ``callback`` argument may be a coroutine function.
194
208
"""
195
- self ._termination_listeners .add (callback )
209
+ self ._termination_listeners .add (_Callback . from_callable ( callback ) )
196
210
197
211
def remove_termination_listener (self , callback ):
198
212
"""Remove a listening callback for connection termination.
199
213
200
214
:param callable callback:
201
- The callable that was passed to
215
+ The callable or coroutine function that was passed to
202
216
:meth:`Connection.add_termination_listener`.
203
217
204
218
.. versionadded:: 0.21.0
205
219
"""
206
- self ._termination_listeners .discard (callback )
220
+ self ._termination_listeners .discard (_Callback . from_callable ( callback ) )
207
221
208
222
def get_server_pid (self ):
209
223
"""Return the PID of the Postgres server the connection is bound to."""
@@ -1430,8 +1444,11 @@ def _process_log_message(self, fields, last_query):
1430
1444
1431
1445
con_ref = self ._unwrap ()
1432
1446
for cb in self ._log_listeners :
1433
- self ._loop .call_soon (
1434
- self ._call_log_listener , cb , con_ref , message )
1447
+ if cb .is_async :
1448
+ self ._loop .create_task (cb .cb (con_ref , message ))
1449
+ else :
1450
+ self ._loop .call_soon (
1451
+ self ._call_log_listener , cb .cb , con_ref , message )
1435
1452
1436
1453
def _call_log_listener (self , cb , con_ref , message ):
1437
1454
try :
@@ -1449,16 +1466,19 @@ def _call_termination_listeners(self):
1449
1466
1450
1467
con_ref = self ._unwrap ()
1451
1468
for cb in self ._termination_listeners :
1452
- try :
1453
- cb (con_ref )
1454
- except Exception as ex :
1455
- self ._loop .call_exception_handler ({
1456
- 'message' : (
1457
- 'Unhandled exception in asyncpg connection '
1458
- 'termination listener callback {!r}' .format (cb )
1459
- ),
1460
- 'exception' : ex
1461
- })
1469
+ if cb .is_async :
1470
+ self ._loop .create_task (cb .cb (con_ref ))
1471
+ else :
1472
+ try :
1473
+ cb .cb (con_ref )
1474
+ except Exception as ex :
1475
+ self ._loop .call_exception_handler ({
1476
+ 'message' : (
1477
+ 'Unhandled exception in asyncpg connection '
1478
+ 'termination listener callback {!r}' .format (cb )
1479
+ ),
1480
+ 'exception' : ex
1481
+ })
1462
1482
1463
1483
self ._termination_listeners .clear ()
1464
1484
@@ -1468,8 +1488,11 @@ def _process_notification(self, pid, channel, payload):
1468
1488
1469
1489
con_ref = self ._unwrap ()
1470
1490
for cb in self ._listeners [channel ]:
1471
- self ._loop .call_soon (
1472
- self ._call_listener , cb , con_ref , pid , channel , payload )
1491
+ if cb .is_async :
1492
+ self ._loop .create_task (cb .cb (con_ref , pid , channel , payload ))
1493
+ else :
1494
+ self ._loop .call_soon (
1495
+ self ._call_listener , cb .cb , con_ref , pid , channel , payload )
1473
1496
1474
1497
def _call_listener (self , cb , con_ref , pid , channel , payload ):
1475
1498
try :
@@ -2154,6 +2177,26 @@ def _maybe_cleanup(self):
2154
2177
self ._on_remove (old_entry ._statement )
2155
2178
2156
2179
2180
+ class _Callback (typing .NamedTuple ):
2181
+
2182
+ cb : typing .Callable [..., None ]
2183
+ is_async : bool
2184
+
2185
+ @classmethod
2186
+ def from_callable (cls , cb : typing .Callable [..., None ]) -> '_Callback' :
2187
+ if inspect .iscoroutinefunction (cb ):
2188
+ is_async = True
2189
+ elif callable (cb ):
2190
+ is_async = False
2191
+ else :
2192
+ raise exceptions .InterfaceError (
2193
+ 'expected a callable or an `async def` function,'
2194
+ 'got {!r}' .format (cb )
2195
+ )
2196
+
2197
+ return cls (cb , is_async )
2198
+
2199
+
2157
2200
class _Atomic :
2158
2201
__slots__ = ('_acquired' ,)
2159
2202
0 commit comments