Skip to content

Commit 0cc43a2

Browse files
committed
Fix handling of OIDs >= 2**31
Currently asyncpg (incorrectly) assumes OIDs to be signed 32-bit integers, whereas in reality they are unsigned. As a result, things would crash once the OID sequence reaches 2**31. Fix this by decoding OID values as unsigned longs. Fixes: #279
1 parent cf523be commit 0cc43a2

File tree

12 files changed

+170
-61
lines changed

12 files changed

+170
-61
lines changed

asyncpg/_testbase/__init__.py

+38-16
Original file line numberDiff line numberDiff line change
@@ -205,16 +205,21 @@ def _format_loop_exception(self, context, n):
205205
_default_cluster = None
206206

207207

208-
def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
209-
initdb_options=None):
208+
def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
210209
cluster = ClusterCls(**cluster_kwargs)
211210
cluster.init(**(initdb_options or {}))
212211
cluster.trust_local_connections()
213-
cluster.start(port='dynamic', server_settings=server_settings)
214212
atexit.register(_shutdown_cluster, cluster)
215213
return cluster
216214

217215

216+
def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
217+
initdb_options=None):
218+
cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
219+
cluster.start(port='dynamic', server_settings=server_settings)
220+
return cluster
221+
222+
218223
def _get_initdb_options(initdb_options=None):
219224
if not initdb_options:
220225
initdb_options = {}
@@ -228,7 +233,7 @@ def _get_initdb_options(initdb_options=None):
228233
return initdb_options
229234

230235

231-
def _start_default_cluster(server_settings={}, initdb_options=None):
236+
def _init_default_cluster(initdb_options=None):
232237
global _default_cluster
233238

234239
if _default_cluster is None:
@@ -237,9 +242,8 @@ def _start_default_cluster(server_settings={}, initdb_options=None):
237242
# Using existing cluster, assuming it is initialized and running
238243
_default_cluster = pg_cluster.RunningCluster()
239244
else:
240-
_default_cluster = _start_cluster(
245+
_default_cluster = _init_cluster(
241246
pg_cluster.TempCluster, cluster_kwargs={},
242-
server_settings=server_settings,
243247
initdb_options=_get_initdb_options(initdb_options))
244248

245249
return _default_cluster
@@ -248,7 +252,8 @@ def _start_default_cluster(server_settings={}, initdb_options=None):
248252
def _shutdown_cluster(cluster):
249253
if cluster.get_status() == 'running':
250254
cluster.stop()
251-
cluster.destroy()
255+
if cluster.get_status() != 'not-initialized':
256+
cluster.destroy()
252257

253258

254259
def create_pool(dsn=None, *,
@@ -278,15 +283,40 @@ def get_server_settings(cls):
278283
'log_connections': 'on'
279284
}
280285

286+
@classmethod
287+
def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}):
288+
cluster = _init_cluster(ClusterCls, cluster_kwargs,
289+
_get_initdb_options(initdb_options))
290+
cls._clusters.append(cluster)
291+
return cluster
292+
293+
@classmethod
294+
def start_cluster(cls, cluster, *, server_settings={}):
295+
cluster.start(port='dynamic', server_settings=server_settings)
296+
281297
@classmethod
282298
def setup_cluster(cls):
283-
cls.cluster = _start_default_cluster(cls.get_server_settings())
299+
cls.cluster = _init_default_cluster()
300+
301+
if cls.cluster.get_status() != 'running':
302+
cls.cluster.start(
303+
port='dynamic', server_settings=cls.get_server_settings())
284304

285305
@classmethod
286306
def setUpClass(cls):
287307
super().setUpClass()
308+
cls._clusters = []
288309
cls.setup_cluster()
289310

311+
@classmethod
312+
def tearDownClass(cls):
313+
super().tearDownClass()
314+
for cluster in cls._clusters:
315+
if cluster is not _default_cluster:
316+
cluster.stop()
317+
cluster.destroy()
318+
cls._clusters = []
319+
290320
@classmethod
291321
def get_connection_spec(cls, kwargs={}):
292322
conn_spec = cls.cluster.get_connection_spec()
@@ -309,14 +339,6 @@ def connect(cls, **kwargs):
309339
conn_spec = cls.get_connection_spec(kwargs)
310340
return pg_connection.connect(**conn_spec, loop=cls.loop)
311341

312-
@classmethod
313-
def start_cluster(cls, ClusterCls, *,
314-
cluster_kwargs={}, server_settings={},
315-
initdb_options={}):
316-
return _start_cluster(
317-
ClusterCls, cluster_kwargs,
318-
server_settings, _get_initdb_options(initdb_options))
319-
320342

321343
class ProxiedClusterTestCase(ClusterTestCase):
322344
@classmethod

asyncpg/cluster.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def get_status(self):
106106
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
107107
stdout, stderr = process.stdout, process.stderr
108108

109-
if process.returncode == 4 or not os.listdir(self._data_dir):
109+
if (process.returncode == 4 or not os.path.exists(self._data_dir) or
110+
not os.listdir(self._data_dir)):
110111
return 'not-initialized'
111112
elif process.returncode == 3:
112113
return 'stopped'
@@ -299,6 +300,42 @@ def get_connection_spec(self):
299300
def override_connection_spec(self, **kwargs):
300301
self._connection_spec_override = kwargs
301302

303+
def reset_wal(self, *, oid=None, xid=None):
304+
status = self.get_status()
305+
if status == 'not-initialized':
306+
raise ClusterError(
307+
'cannot modify WAL status: cluster is not initialized')
308+
309+
if status == 'running':
310+
raise ClusterError(
311+
'cannot modify WAL status: cluster is running')
312+
313+
opts = []
314+
if oid is not None:
315+
opts.extend(['-o', str(oid)])
316+
if xid is not None:
317+
opts.extend(['-x', str(xid)])
318+
if not opts:
319+
return
320+
321+
opts.append(self._data_dir)
322+
323+
try:
324+
reset_wal = self._find_pg_binary('pg_resetwal')
325+
except ClusterError:
326+
reset_wal = self._find_pg_binary('pg_resetxlog')
327+
328+
process = subprocess.run(
329+
[reset_wal] + opts,
330+
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
331+
332+
stderr = process.stderr
333+
334+
if process.returncode != 0:
335+
raise ClusterError(
336+
'pg_resetwal exited with status {:d}: {}'.format(
337+
process.returncode, stderr.decode()))
338+
302339
def reset_hba(self):
303340
"""Remove all records from pg_hba.conf."""
304341
status = self.get_status()

asyncpg/connection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,7 @@ async def set_type_codec(self, typename, *,
912912
raise ValueError('unknown type: {}.{}'.format(schema, typename))
913913

914914
oid = typeinfo['oid']
915-
if typeinfo['kind'] != b'b' or typeinfo['elemtype']:
915+
if typeinfo['kind'] not in (b'b', b'd') or typeinfo['elemtype']:
916916
raise ValueError(
917917
'cannot use custom codec on non-scalar type {}.{}'.format(
918918
schema, typename))

asyncpg/protocol/buffer.pxd

+3-3
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,13 @@ cdef class ReadBuffer:
9797
cdef feed_data(self, data)
9898
cdef inline _ensure_first_buf(self)
9999
cdef _switch_to_next_buf(self)
100-
cdef inline read_byte(self)
100+
cdef inline char read_byte(self) except? -1
101101
cdef inline const char* _try_read_bytes(self, ssize_t nbytes)
102102
cdef inline _read(self, char *buf, ssize_t nbytes)
103103
cdef read(self, ssize_t nbytes)
104104
cdef inline const char* read_bytes(self, ssize_t n) except NULL
105-
cdef inline read_int32(self)
106-
cdef inline read_int16(self)
105+
cdef inline int32_t read_int32(self) except? -1
106+
cdef inline int16_t read_int16(self) except? -1
107107
cdef inline read_cstr(self)
108108
cdef int32_t has_message(self) except -1
109109
cdef inline int32_t has_message_type(self, char mtype) except -1

asyncpg/protocol/buffer.pyx

+3-3
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ cdef class ReadBuffer:
376376

377377
return Memory.new(buf, result, nbytes)
378378

379-
cdef inline read_byte(self):
379+
cdef inline char read_byte(self) except? -1:
380380
cdef const char *first_byte
381381

382382
if ASYNCPG_DEBUG:
@@ -404,7 +404,7 @@ cdef class ReadBuffer:
404404
mem = <Memory>(self.read(n))
405405
return mem.buf
406406

407-
cdef inline read_int32(self):
407+
cdef inline int32_t read_int32(self) except? -1:
408408
cdef:
409409
Memory mem
410410
const char *cbuf
@@ -417,7 +417,7 @@ cdef class ReadBuffer:
417417
mem = <Memory>(self.read(4))
418418
return hton.unpack_int32(mem.buf)
419419

420-
cdef inline read_int16(self):
420+
cdef inline int16_t read_int16(self) except? -1:
421421
cdef:
422422
Memory mem
423423
const char *cbuf

asyncpg/protocol/codecs/base.pyx

+23-7
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,22 @@ cdef codec_decode_func_ex(ConnectionSettings settings, FastReadBuffer buf,
373373
return (<Codec>arg).decode(settings, buf)
374374

375375

376+
cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFF:
377+
cdef:
378+
int64_t oid = 0
379+
bint overflow = False
380+
381+
try:
382+
oid = cpython.PyLong_AsLongLong(val)
383+
except OverflowError:
384+
overflow = True
385+
386+
if overflow or (oid < 0 or oid > 4294967295L):
387+
raise OverflowError('OID value too large: {!r}'.format(val))
388+
389+
return <uint32_t>val
390+
391+
376392
cdef class DataCodecConfig:
377393
def __init__(self, cache_key):
378394
try:
@@ -523,9 +539,10 @@ cdef class DataCodecConfig:
523539
Codec core_codec
524540
encode_func c_encoder = NULL
525541
decode_func c_decoder = NULL
542+
uint32_t oid = pylong_as_oid(typeoid)
526543

527544
if xformat == PG_XFORMAT_TUPLE:
528-
core_codec = get_any_core_codec(typeoid, format, xformat)
545+
core_codec = get_any_core_codec(oid, format, xformat)
529546
if core_codec is None:
530547
raise ValueError(
531548
"{} type does not support 'tuple' exchange format".format(
@@ -538,7 +555,7 @@ cdef class DataCodecConfig:
538555
self.remove_python_codec(typeoid, typename, typeschema)
539556

540557
self._local_type_codecs[typeoid] = \
541-
Codec.new_python_codec(typeoid, typename, typeschema, typekind,
558+
Codec.new_python_codec(oid, typename, typeschema, typekind,
542559
encoder, decoder, c_encoder, c_decoder,
543560
format, xformat)
544561

@@ -551,19 +568,18 @@ cdef class DataCodecConfig:
551568
cdef:
552569
Codec codec
553570
Codec target_codec
571+
uint32_t oid = pylong_as_oid(typeoid)
572+
uint32_t alias_pid
554573

555574
if format == PG_FORMAT_ANY:
556575
formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT)
557576
else:
558577
formats = (format,)
559578

560579
for format in formats:
561-
if self.get_codec(typeoid, format) is not None:
562-
raise ValueError('cannot override codec for type {}'.format(
563-
typeoid))
564-
565580
if isinstance(alias_to, int):
566-
target_codec = self.get_codec(alias_to, format)
581+
alias_oid = pylong_as_oid(alias_to)
582+
target_codec = self.get_codec(alias_oid, format)
567583
else:
568584
target_codec = get_extra_codec(alias_to, format)
569585

asyncpg/protocol/codecs/int.pyx

+5
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ cdef int4_decode(ConnectionSettings settings, FastReadBuffer buf):
6262
return cpython.PyLong_FromLong(hton.unpack_int32(buf.read(4)))
6363

6464

65+
cdef uint4_decode(ConnectionSettings settings, FastReadBuffer buf):
66+
return cpython.PyLong_FromUnsignedLong(
67+
<uint32_t>hton.unpack_int32(buf.read(4)))
68+
69+
6570
cdef int8_encode(ConnectionSettings settings, WriteBuffer buf, obj):
6671
cdef int overflow = 0
6772
cdef long long val

asyncpg/protocol/codecs/misc.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ cdef init_pseudo_codecs():
3737
for oid_type in oid_types:
3838
register_core_codec(oid_type,
3939
<encode_func>&int4_encode,
40-
<decode_func>&int4_decode,
40+
<decode_func>&uint4_decode,
4141
PG_FORMAT_BINARY)
4242

4343
# reg* types -- these are really system catalog OIDs, but

asyncpg/protocol/prepared_stmt.pyx

+10-10
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ cdef class PreparedStatementState:
163163
list cols_names
164164
object cols_mapping
165165
tuple row
166-
int oid
166+
uint32_t oid
167167
Codec codec
168168
list codecs
169169

@@ -183,7 +183,7 @@ cdef class PreparedStatementState:
183183
cols_mapping[col_name] = i
184184
cols_names.append(col_name)
185185
oid = row[3]
186-
codec = self.settings.get_data_codec(<uint32_t>oid)
186+
codec = self.settings.get_data_codec(oid)
187187
if codec is None or not codec.has_decoder():
188188
raise RuntimeError('no decoder for OID {}'.format(oid))
189189
if not codec.is_binary():
@@ -198,7 +198,7 @@ cdef class PreparedStatementState:
198198

199199
cdef _ensure_args_encoder(self):
200200
cdef:
201-
int p_oid
201+
uint32_t p_oid
202202
Codec codec
203203
list codecs = []
204204

@@ -207,7 +207,7 @@ cdef class PreparedStatementState:
207207

208208
for i from 0 <= i < self.args_num:
209209
p_oid = self.parameters_desc[i]
210-
codec = self.settings.get_data_codec(<uint32_t>p_oid)
210+
codec = self.settings.get_data_codec(p_oid)
211211
if codec is None or not codec.has_encoder():
212212
raise RuntimeError('no encoder for OID {}'.format(p_oid))
213213
if codec.type not in {}:
@@ -290,14 +290,14 @@ cdef _decode_parameters_desc(object desc):
290290
cdef:
291291
ReadBuffer reader
292292
int16_t nparams
293-
int32_t p_oid
293+
uint32_t p_oid
294294
list result = []
295295

296296
reader = ReadBuffer.new_message_parser(desc)
297297
nparams = reader.read_int16()
298298

299299
for i from 0 <= i < nparams:
300-
p_oid = reader.read_int32()
300+
p_oid = <uint32_t>reader.read_int32()
301301
result.append(p_oid)
302302

303303
return result
@@ -310,9 +310,9 @@ cdef _decode_row_desc(object desc):
310310
int16_t nfields
311311

312312
bytes f_name
313-
int32_t f_table_oid
313+
uint32_t f_table_oid
314314
int16_t f_column_num
315-
int32_t f_dt_oid
315+
uint32_t f_dt_oid
316316
int16_t f_dt_size
317317
int32_t f_dt_mod
318318
int16_t f_format
@@ -325,9 +325,9 @@ cdef _decode_row_desc(object desc):
325325

326326
for i from 0 <= i < nfields:
327327
f_name = reader.read_cstr()
328-
f_table_oid = reader.read_int32()
328+
f_table_oid = <uint32_t>reader.read_int32()
329329
f_column_num = reader.read_int16()
330-
f_dt_oid = reader.read_int32()
330+
f_dt_oid = <uint32_t>reader.read_int32()
331331
f_dt_size = reader.read_int16()
332332
f_dt_mod = reader.read_int32()
333333
f_format = reader.read_int16()

0 commit comments

Comments
 (0)