diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 8e841871..9b37b1a4 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -301,6 +301,13 @@ async def executemany(self, command: str, args, *, timeout: float=None): .. versionchanged:: 0.11.0 `timeout` became a keyword-only parameter. + + .. versionchanged:: 0.19.0 + The execution now happens in an implicit transaction if there + is no explicit transaction, so that it can no longer end up with + partial success. If you still need the previous behavior of + progressively executing many args, please use a loop with a + prepared statement instead. """ self._check_open() return await self._executemany(command, args, timeout) @@ -821,6 +828,9 @@ async def _copy_in(self, copy_stmt, source, timeout): f = source elif isinstance(source, collections.abc.AsyncIterable): # assuming calling output returns an awaitable. + # copy_in() is designed to handle very large amounts of data, and + # the source async iterable is allowed to return an arbitrary + # amount of data on every iteration. reader = source else: # assuming source is an instance supporting the buffer protocol. diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 09a0a2ec..0e09a6e9 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -196,11 +196,24 @@ async def fetchrow(self, *args, timeout=None): return None return data[0] - async def __bind_execute(self, args, limit, timeout): + @connresource.guarded + async def executemany(self, args, *, timeout: float=None): + """Execute the statement for each sequence of arguments in *args*. + + :param args: An iterable containing sequences of arguments. + :param float timeout: Optional timeout value in seconds. + :return None: This method discards the results of the operations. + + .. 
versionadded:: 0.19.0 + """ + return await self.__do_execute( + lambda protocol: protocol.bind_execute_many( + self._state, args, '', timeout)) + + async def __do_execute(self, executor): protocol = self._connection._protocol try: - data, status, _ = await protocol.bind_execute( - self._state, args, '', limit, True, timeout) + return await executor(protocol) except exceptions.OutdatedSchemaCacheError: await self._connection.reload_schema_state() # We can not find all manually created prepared statements, so just @@ -209,6 +222,11 @@ async def __bind_execute(self, args, limit, timeout): # invalidate themselves (unfortunately, clearing caches again). self._state.mark_closed() raise + + async def __bind_execute(self, args, limit, timeout): + data, status, _ = await self.__do_execute( + lambda protocol: protocol.bind_execute( + self._state, args, '', limit, True, timeout)) self._last_status = status return data diff --git a/asyncpg/protocol/consts.pxi b/asyncpg/protocol/consts.pxi index 97cbbf35..e1f8726e 100644 --- a/asyncpg/protocol/consts.pxi +++ b/asyncpg/protocol/consts.pxi @@ -8,3 +8,5 @@ DEF _MAXINT32 = 2**31 - 1 DEF _COPY_BUFFER_SIZE = 524288 DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0" +DEF _EXECUTE_MANY_BUF_NUM = 4 +DEF _EXECUTE_MANY_BUF_SIZE = 32768 diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index c96b1fa5..016316e4 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -83,11 +83,6 @@ cdef class CoreProtocol: bint _skip_discard bint _discard_data - # executemany support data - object _execute_iter - str _execute_portal_name - str _execute_stmt_name - ConnectionStatus con_status ProtocolState state TransactionStatus xact_status @@ -114,6 +109,7 @@ cdef class CoreProtocol: # True - completed, False - suspended bint result_execute_completed + cpdef is_in_transaction(self) cdef _process__auth(self, char mtype) cdef _process__prepare(self, char mtype) cdef _process__bind_execute(self, char mtype) @@ 
-146,6 +142,7 @@ cdef class CoreProtocol: cdef _auth_password_message_sasl_continue(self, bytes server_response) cdef _write(self, buf) + cdef _writelines(self, list buffers) cdef _read_server_messages(self) @@ -155,9 +152,13 @@ cdef class CoreProtocol: cdef _ensure_connected(self) + cdef WriteBuffer _build_parse_message(self, str stmt_name, str query) cdef WriteBuffer _build_bind_message(self, str portal_name, str stmt_name, WriteBuffer bind_data) + cdef WriteBuffer _build_empty_bind_data(self) + cdef WriteBuffer _build_execute_message(self, str portal_name, + int32_t limit) cdef _connect(self) @@ -166,8 +167,11 @@ cdef class CoreProtocol: WriteBuffer bind_data, int32_t limit) cdef _bind_execute(self, str portal_name, str stmt_name, WriteBuffer bind_data, int32_t limit) - cdef _bind_execute_many(self, str portal_name, str stmt_name, - object bind_data) + cdef _execute_many_init(self) + cdef _execute_many_writelines(self, str portal_name, str stmt_name, + object bind_data) + cdef _execute_many_done(self, bint data_sent) + cdef _execute_many_fail(self, object error) cdef _bind(self, str portal_name, str stmt_name, WriteBuffer bind_data) cdef _execute(self, str portal_name, int32_t limit) diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index a44bc5ad..2122b56d 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -27,13 +27,13 @@ cdef class CoreProtocol: # type of `scram` is `SCRAMAuthentcation` self.scram = None - # executemany support data - self._execute_iter = None - self._execute_portal_name = None - self._execute_stmt_name = None - self._reset_result() + cpdef is_in_transaction(self): + # PQTRANS_INTRANS = idle, within transaction block + # PQTRANS_INERROR = idle, within failed transaction + return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR) + cdef _read_server_messages(self): cdef: char mtype @@ -258,22 +258,7 @@ cdef class CoreProtocol: elif mtype == b'Z': # ReadyForQuery 
self._parse_msg_ready_for_query() - if self.result_type == RESULT_FAILED: - self._push_result() - else: - try: - buf = next(self._execute_iter) - except StopIteration: - self._push_result() - except Exception as e: - self.result_type = RESULT_FAILED - self.result = e - self._push_result() - else: - # Next iteration over the executemany() arg sequence - self._send_bind_message( - self._execute_portal_name, self._execute_stmt_name, - buf, 0) + self._push_result() elif mtype == b'I': # EmptyQueryResponse @@ -775,6 +760,17 @@ cdef class CoreProtocol: if self.con_status != CONNECTION_OK: raise apg_exc.InternalClientError('not connected') + cdef WriteBuffer _build_parse_message(self, str stmt_name, str query): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'P') + buf.write_str(stmt_name, self.encoding) + buf.write_str(query, self.encoding) + buf.write_int16(0) + + buf.end_message() + return buf + cdef WriteBuffer _build_bind_message(self, str portal_name, str stmt_name, WriteBuffer bind_data): @@ -790,6 +786,25 @@ cdef class CoreProtocol: buf.end_message() return buf + cdef WriteBuffer _build_empty_bind_data(self): + cdef WriteBuffer buf + buf = WriteBuffer.new() + buf.write_int16(0) # The number of parameter format codes + buf.write_int16(0) # The number of parameter values + buf.write_int16(0) # The number of result-column format codes + return buf + + cdef WriteBuffer _build_execute_message(self, str portal_name, + int32_t limit): + cdef WriteBuffer buf + + buf = WriteBuffer.new_message(b'E') + buf.write_str(portal_name, self.encoding) # name of the portal + buf.write_int32(limit) # number of rows to return; 0 - all + + buf.end_message() + return buf + # API for subclasses cdef _connect(self): @@ -840,12 +855,7 @@ cdef class CoreProtocol: self._ensure_connected() self._set_state(PROTOCOL_PREPARE) - buf = WriteBuffer.new_message(b'P') - buf.write_str(stmt_name, self.encoding) - buf.write_str(query, self.encoding) - buf.write_int16(0) - buf.end_message() - 
packet = buf + packet = self._build_parse_message(stmt_name, query) buf = WriteBuffer.new_message(b'D') buf.write_byte(b'S') @@ -867,10 +877,7 @@ cdef class CoreProtocol: buf = self._build_bind_message(portal_name, stmt_name, bind_data) packet = buf - buf = WriteBuffer.new_message(b'E') - buf.write_str(portal_name, self.encoding) # name of the portal - buf.write_int32(limit) # number of rows to return; 0 - all - buf.end_message() + buf = self._build_execute_message(portal_name, limit) packet.write_buffer(buf) packet.write_bytes(SYNC_MESSAGE) @@ -889,30 +896,75 @@ cdef class CoreProtocol: self._send_bind_message(portal_name, stmt_name, bind_data, limit) - cdef _bind_execute_many(self, str portal_name, str stmt_name, - object bind_data): - - cdef WriteBuffer buf - + cdef _execute_many_init(self): self._ensure_connected() self._set_state(PROTOCOL_BIND_EXECUTE_MANY) self.result = None self._discard_data = True - self._execute_iter = bind_data - self._execute_portal_name = portal_name - self._execute_stmt_name = stmt_name - try: - buf = next(bind_data) - except StopIteration: - self._push_result() - except Exception as e: - self.result_type = RESULT_FAILED - self.result = e + cdef _execute_many_writelines(self, str portal_name, str stmt_name, + object bind_data): + cdef: + WriteBuffer packet + WriteBuffer buf + list buffers = [] + + if self.result_type == RESULT_FAILED: + raise StopIteration(True) + + while len(buffers) < _EXECUTE_MANY_BUF_NUM: + packet = WriteBuffer.new() + + while packet.len() < _EXECUTE_MANY_BUF_SIZE: + try: + buf = next(bind_data) + except StopIteration: + if packet.len() > 0: + buffers.append(packet) + if len(buffers) > 0: + self._writelines(buffers) + raise StopIteration(True) + else: + raise StopIteration(False) + except Exception as ex: + raise StopIteration(ex) + packet.write_buffer( + self._build_bind_message(portal_name, stmt_name, buf)) + packet.write_buffer( + self._build_execute_message(portal_name, 0)) + buffers.append(packet) + 
self._writelines(buffers) + + cdef _execute_many_done(self, bint data_sent): + if data_sent: + self._write(SYNC_MESSAGE) + else: self._push_result() + + cdef _execute_many_fail(self, object error): + cdef WriteBuffer buf + + self.result_type = RESULT_FAILED + self.result = error + + # Roll back when in an implicit transaction to prevent a partial + # commit; in an explicit transaction do nothing and leave the + # error to the user. + if self.is_in_transaction(): + self._execute_many_done(True) else: - self._send_bind_message(portal_name, stmt_name, buf, 0) + # If the implicit transaction is in `ignore_till_sync` mode here, + # the `ROLLBACK` will be ignored and `Sync` will restore the state; + # otherwise the implicit transaction will be rolled back with a + # warning saying that there was no transaction, but the rollback is + # done anyway, so we can ignore this warning. + buf = self._build_parse_message('', 'ROLLBACK') + buf.write_buffer(self._build_bind_message( + '', '', self._build_empty_bind_data())) + buf.write_buffer(self._build_execute_message('', 0)) + buf.write_bytes(SYNC_MESSAGE) + self._write(buf) cdef _execute(self, str portal_name, int32_t limit): cdef WriteBuffer buf @@ -922,10 +974,7 @@ cdef class CoreProtocol: self.result = [] - buf = WriteBuffer.new_message(b'E') - buf.write_str(portal_name, self.encoding) # name of the portal - buf.write_int32(limit) # number of rows to return; 0 - all - buf.end_message() + buf = self._build_execute_message(portal_name, limit) buf.write_bytes(SYNC_MESSAGE) @@ -1008,6 +1057,9 @@ cdef class CoreProtocol: cdef _write(self, buf): raise NotImplementedError + cdef _writelines(self, list buffers): + raise NotImplementedError + cdef _decode_row(self, const char* buf, ssize_t buf_len): pass diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index ac653bd0..880acf04 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -122,11 +122,6 @@ cdef class 
BaseProtocol(CoreProtocol): def get_settings(self): return self.settings - def is_in_transaction(self): - # PQTRANS_INTRANS = idle, within transaction block - # PQTRANS_INERROR = idle, within failed transaction - return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR) - cdef inline resume_reading(self): if not self.is_reading: self.is_reading = True @@ -207,6 +202,7 @@ cdef class BaseProtocol(CoreProtocol): self._check_state() timeout = self._get_timeout_impl(timeout) + timer = Timer(timeout) # Make sure the argument sequence is encoded lazily with # this generator expression to keep the memory pressure under @@ -216,15 +212,42 @@ cdef class BaseProtocol(CoreProtocol): waiter = self._new_waiter(timeout) try: - self._bind_execute_many( - portal_name, - state.name, - arg_bufs) # network op - + self._execute_many_init() self.last_query = state.query self.statement = state self.return_extra = False self.queries_count += 1 + + data_sent = False + while True: + self._execute_many_writelines( + portal_name, + state.name, + arg_bufs) # network op + data_sent = True + await asyncio.wait_for( + self.writing_allowed.wait(), + timeout=timer.get_remaining_budget()) + # On Windows the above event somehow won't allow context + # switch, so forcing one with sleep(0) here + await asyncio.sleep(0) + if not timer.has_budget_greater_than(0): + raise asyncio.TimeoutError + except StopIteration as ex: + reason = ex.value + if (not timer.has_budget_greater_than(0) + and not isinstance(reason, BaseException)): + reason = asyncio.TimeoutError + + if reason is True: + # there was data sent to DB + self._execute_many_done(True) # network op + elif reason is False: + # no data was sent to DB in last loop, use data_sent + self._execute_many_done(data_sent) # network op + else: + # data source raised an exception + self._execute_many_fail(reason) # network op except Exception as ex: waiter.set_exception(ex) self._coreproto_error() @@ -880,6 +903,9 @@ cdef class BaseProtocol(CoreProtocol): 
cdef _write(self, buf): self.transport.write(memoryview(buf)) + cdef _writelines(self, list buffers): + self.transport.writelines(buffers) + # asyncio callbacks: def data_received(self, data): @@ -932,6 +958,13 @@ class Timer: def get_remaining_budget(self): return self._budget + def has_budget_greater_than(self, amount): + if self._budget is None: + # Unlimited budget. + return True + else: + return self._budget > amount + class Protocol(BaseProtocol, asyncio.Protocol): pass diff --git a/tests/test_execute.py b/tests/test_execute.py index ccde0993..88ef2b36 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -9,6 +9,7 @@ import asyncpg from asyncpg import _testbase as tb +from asyncpg.exceptions import UniqueViolationError class TestExecuteScript(tb.ConnectedTestCase): @@ -97,57 +98,162 @@ async def test_execute_script_interrupted_terminate(self): self.con.terminate() - async def test_execute_many_1(self): - await self.con.execute('CREATE TEMP TABLE exmany (a text, b int)') - try: - result = await self.con.executemany(''' - INSERT INTO exmany VALUES($1, $2) - ''', [ - ('a', 1), ('b', 2), ('c', 3), ('d', 4) - ]) +class TestExecuteMany(tb.ConnectedTestCase): + def setUp(self): + super().setUp() + self.loop.run_until_complete(self.con.execute( + 'CREATE TEMP TABLE exmany (a text, b int PRIMARY KEY)')) - self.assertIsNone(result) + def tearDown(self): + self.loop.run_until_complete(self.con.execute('DROP TABLE exmany')) + super().tearDown() - result = await self.con.fetch(''' - SELECT * FROM exmany - ''') + async def test_executemany_basic(self): + result = await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) - self.assertEqual(result, [ - ('a', 1), ('b', 2), ('c', 3), ('d', 4) - ]) + self.assertIsNone(result) - # Empty set - result = await self.con.executemany(''' - INSERT INTO exmany VALUES($1, $2) - ''', ()) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') - result = 
await self.con.fetch(''' - SELECT * FROM exmany - ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) - self.assertEqual(result, [ - ('a', 1), ('b', 2), ('c', 3), ('d', 4) - ]) - finally: - await self.con.execute('DROP TABLE exmany') + # Empty set + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', ()) - async def test_execute_many_2(self): - await self.con.execute('CREATE TEMP TABLE exmany (b int)') + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') - try: - bad_data = ([1 / 0] for v in range(10)) + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) - with self.assertRaises(ZeroDivisionError): - async with self.con.transaction(): - await self.con.executemany(''' - INSERT INTO exmany VALUES($1) - ''', bad_data) + async def test_executemany_bad_input(self): + bad_data = ([1 / 0] for v in range(10)) - good_data = ([v] for v in range(10)) + with self.assertRaises(ZeroDivisionError): async with self.con.transaction(): await self.con.executemany(''' - INSERT INTO exmany VALUES($1) - ''', good_data) - finally: - await self.con.execute('DROP TABLE exmany') + INSERT INTO exmany (b)VALUES($1) + ''', bad_data) + + good_data = ([v] for v in range(10)) + async with self.con.transaction(): + await self.con.executemany(''' + INSERT INTO exmany (b)VALUES($1) + ''', good_data) + + async def test_executemany_server_failure(self): + with self.assertRaises(UniqueViolationError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', [ + ('a', 1), ('b', 2), ('c', 2), ('d', 4) + ]) + result = await self.con.fetch('SELECT * FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_server_failure_after_writes(self): + with self.assertRaises(UniqueViolationError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', [('a' * 32768, x) for x in range(10)] + [ + ('b', 12), ('c', 12), ('d', 14) + ]) + result = await self.con.fetch('SELECT b 
FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_server_failure_during_writes(self): + # failure at the beginning, server error detected in the middle + pos = 0 + + def gen(): + nonlocal pos + while pos < 128: + pos += 1 + if pos < 3: + yield ('a', 0) + else: + yield 'a' * 32768, pos + + with self.assertRaises(UniqueViolationError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', gen()) + result = await self.con.fetch('SELECT b FROM exmany') + self.assertEqual(result, []) + self.assertLess(pos, 128, 'should stop early') + + async def test_executemany_client_failure_after_writes(self): + with self.assertRaises(ZeroDivisionError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', (('a' * 32768, y + y / y) for y in range(10, -1, -1))) + result = await self.con.fetch('SELECT b FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_timeout(self): + with self.assertRaises(asyncio.TimeoutError): + await self.con.executemany(''' + INSERT INTO exmany VALUES(pg_sleep(0.1), $1) + ''', [[x] for x in range(128)], timeout=0.5) + result = await self.con.fetch('SELECT * FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_client_failure_in_transaction(self): + tx = self.con.transaction() + await tx.start() + with self.assertRaises(ZeroDivisionError): + await self.con.executemany(''' + INSERT INTO exmany VALUES($1, $2) + ''', (('a' * 32768, y + y / y) for y in range(10, -1, -1))) + result = await self.con.fetch('SELECT b FROM exmany') + # only 2 batches executed (2 x 4) + self.assertEqual( + [x[0] for x in result], [y + 1 for y in range(10, 2, -1)]) + await tx.rollback() + result = await self.con.fetch('SELECT b FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_client_server_failure_conflict(self): + self.con._transport.set_write_buffer_limits(65536 * 64, 16384 * 64) + with self.assertRaises(UniqueViolationError): + await 
self.con.executemany(''' + INSERT INTO exmany VALUES($1, 0) + ''', (('a' * 32768,) for y in range(4, -1, -1) if y / y)) + result = await self.con.fetch('SELECT b FROM exmany') + self.assertEqual(result, []) + + async def test_executemany_prepare(self): + stmt = await self.con.prepare(''' + INSERT INTO exmany VALUES($1, $2) + ''') + result = await stmt.executemany([ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + self.assertIsNone(result) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ]) + # Empty set + await stmt.executemany(()) + result = await self.con.fetch(''' + SELECT * FROM exmany + ''') + self.assertEqual(result, [ + ('a', 1), ('b', 2), ('c', 3), ('d', 4) + ])