Skip to content

Commit 01f659c

Browse files
authored
PYTHON-5071 Use one event loop for all asyncio tests (#2086)
1 parent 34ae214 commit 01f659c

File tree

2 files changed

+173
-67
lines changed

2 files changed

+173
-67
lines changed

test/__init__.py

+86-33
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import asyncio
1919
import gc
20+
import inspect
2021
import logging
2122
import multiprocessing
2223
import os
@@ -30,28 +31,6 @@
3031
import unittest
3132
import warnings
3233
from asyncio import iscoroutinefunction
33-
from test.helpers import (
34-
COMPRESSORS,
35-
IS_SRV,
36-
MONGODB_API_VERSION,
37-
MULTI_MONGOS_LB_URI,
38-
TEST_LOADBALANCER,
39-
TEST_SERVERLESS,
40-
TLS_OPTIONS,
41-
SystemCertsPatcher,
42-
client_knobs,
43-
db_pwd,
44-
db_user,
45-
global_knobs,
46-
host,
47-
is_server_resolvable,
48-
port,
49-
print_running_topology,
50-
print_thread_stacks,
51-
print_thread_tracebacks,
52-
sanitize_cmd,
53-
sanitize_reply,
54-
)
5534

5635
from pymongo.uri_parser import parse_uri
5736

@@ -63,7 +42,6 @@
6342
HAVE_IPADDRESS = False
6443
from contextlib import contextmanager
6544
from functools import partial, wraps
66-
from test.version import Version
6745
from typing import Any, Callable, Dict, Generator, overload
6846
from unittest import SkipTest
6947
from urllib.parse import quote_plus
@@ -78,6 +56,32 @@
7856
from pymongo.synchronous.database import Database
7957
from pymongo.synchronous.mongo_client import MongoClient
8058

59+
sys.path[0:0] = [""]
60+
61+
from test.helpers import (
62+
COMPRESSORS,
63+
IS_SRV,
64+
MONGODB_API_VERSION,
65+
MULTI_MONGOS_LB_URI,
66+
TEST_LOADBALANCER,
67+
TEST_SERVERLESS,
68+
TLS_OPTIONS,
69+
SystemCertsPatcher,
70+
client_knobs,
71+
db_pwd,
72+
db_user,
73+
global_knobs,
74+
host,
75+
is_server_resolvable,
76+
port,
77+
print_running_topology,
78+
print_thread_stacks,
79+
print_thread_tracebacks,
80+
sanitize_cmd,
81+
sanitize_reply,
82+
)
83+
from test.version import Version
84+
8185
_IS_SYNC = True
8286

8387

@@ -863,18 +867,66 @@ def max_message_size_bytes(self):
863867
# Reusable client context
864868
client_context = ClientContext()
865869

870+
# Global event loop for async tests.
871+
LOOP = None
866872

867-
def reset_client_context():
868-
if _IS_SYNC:
869-
# sync tests don't need to reset a client context
870-
return
871-
elif client_context.client is not None:
872-
client_context.client.close()
873-
client_context.client = None
874-
client_context._init_client()
873+
874+
def get_loop() -> asyncio.AbstractEventLoop:
875+
"""Get the test suite's global event loop."""
876+
global LOOP
877+
if LOOP is None:
878+
try:
879+
LOOP = asyncio.get_running_loop()
880+
except RuntimeError:
881+
# no running event loop, fallback to get_event_loop.
882+
try:
883+
# Ignore DeprecationWarning: There is no current event loop
884+
with warnings.catch_warnings():
885+
warnings.simplefilter("ignore", DeprecationWarning)
886+
LOOP = asyncio.get_event_loop()
887+
except RuntimeError:
888+
LOOP = asyncio.new_event_loop()
889+
asyncio.set_event_loop(LOOP)
890+
return LOOP
875891

876892

877893
class PyMongoTestCase(unittest.TestCase):
894+
if not _IS_SYNC:
895+
# An async TestCase that uses a single event loop for all tests.
896+
# Inspired by TestCase.
897+
def setUp(self):
898+
pass
899+
900+
def tearDown(self):
901+
pass
902+
903+
def addCleanup(self, func, /, *args, **kwargs):
904+
self.addCleanup(*(func, *args), **kwargs)
905+
906+
def _callSetUp(self):
907+
self.setUp()
908+
self._callAsync(self.setUp)
909+
910+
def _callTestMethod(self, method):
911+
self._callMaybeAsync(method)
912+
913+
def _callTearDown(self):
914+
self._callAsync(self.tearDown)
915+
self.tearDown()
916+
917+
def _callCleanup(self, function, *args, **kwargs):
918+
self._callMaybeAsync(function, *args, **kwargs)
919+
920+
def _callAsync(self, func, /, *args, **kwargs):
921+
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
922+
return get_loop().run_until_complete(func(*args, **kwargs))
923+
924+
def _callMaybeAsync(self, func, /, *args, **kwargs):
925+
if inspect.iscoroutinefunction(func):
926+
return get_loop().run_until_complete(func(*args, **kwargs))
927+
else:
928+
return func(*args, **kwargs)
929+
878930
def assertEqualCommand(self, expected, actual, msg=None):
879931
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
880932

@@ -1136,8 +1188,6 @@ class IntegrationTest(PyMongoTestCase):
11361188

11371189
@client_context.require_connection
11381190
def setUp(self) -> None:
1139-
if not _IS_SYNC:
1140-
reset_client_context()
11411191
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11421192
raise SkipTest("this test does not support load balancers")
11431193
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
@@ -1186,6 +1236,9 @@ def tearDown(self) -> None:
11861236

11871237

11881238
def setup():
1239+
if not _IS_SYNC:
1240+
# Set up the event loop.
1241+
get_loop()
11891242
client_context.init()
11901243
warnings.resetwarnings()
11911244
warnings.simplefilter("always")

test/asynchronous/__init__.py

+87-34
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import asyncio
1919
import gc
20+
import inspect
2021
import logging
2122
import multiprocessing
2223
import os
@@ -30,28 +31,6 @@
3031
import unittest
3132
import warnings
3233
from asyncio import iscoroutinefunction
33-
from test.helpers import (
34-
COMPRESSORS,
35-
IS_SRV,
36-
MONGODB_API_VERSION,
37-
MULTI_MONGOS_LB_URI,
38-
TEST_LOADBALANCER,
39-
TEST_SERVERLESS,
40-
TLS_OPTIONS,
41-
SystemCertsPatcher,
42-
client_knobs,
43-
db_pwd,
44-
db_user,
45-
global_knobs,
46-
host,
47-
is_server_resolvable,
48-
port,
49-
print_running_topology,
50-
print_thread_stacks,
51-
print_thread_tracebacks,
52-
sanitize_cmd,
53-
sanitize_reply,
54-
)
5534

5635
from pymongo.uri_parser import parse_uri
5736

@@ -63,7 +42,6 @@
6342
HAVE_IPADDRESS = False
6443
from contextlib import asynccontextmanager, contextmanager
6544
from functools import partial, wraps
66-
from test.version import Version
6745
from typing import Any, Callable, Dict, Generator, overload
6846
from unittest import SkipTest
6947
from urllib.parse import quote_plus
@@ -78,6 +56,32 @@
7856
from pymongo.server_api import ServerApi
7957
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
8058

59+
sys.path[0:0] = [""]
60+
61+
from test.helpers import (
62+
COMPRESSORS,
63+
IS_SRV,
64+
MONGODB_API_VERSION,
65+
MULTI_MONGOS_LB_URI,
66+
TEST_LOADBALANCER,
67+
TEST_SERVERLESS,
68+
TLS_OPTIONS,
69+
SystemCertsPatcher,
70+
client_knobs,
71+
db_pwd,
72+
db_user,
73+
global_knobs,
74+
host,
75+
is_server_resolvable,
76+
port,
77+
print_running_topology,
78+
print_thread_stacks,
79+
print_thread_tracebacks,
80+
sanitize_cmd,
81+
sanitize_reply,
82+
)
83+
from test.version import Version
84+
8185
_IS_SYNC = False
8286

8387

@@ -865,18 +869,66 @@ async def max_message_size_bytes(self):
865869
# Reusable client context
866870
async_client_context = AsyncClientContext()
867871

872+
# Global event loop for async tests.
873+
LOOP = None
874+
875+
876+
def get_loop() -> asyncio.AbstractEventLoop:
877+
"""Get the test suite's global event loop."""
878+
global LOOP
879+
if LOOP is None:
880+
try:
881+
LOOP = asyncio.get_running_loop()
882+
except RuntimeError:
883+
# no running event loop, fallback to get_event_loop.
884+
try:
885+
# Ignore DeprecationWarning: There is no current event loop
886+
with warnings.catch_warnings():
887+
warnings.simplefilter("ignore", DeprecationWarning)
888+
LOOP = asyncio.get_event_loop()
889+
except RuntimeError:
890+
LOOP = asyncio.new_event_loop()
891+
asyncio.set_event_loop(LOOP)
892+
return LOOP
893+
894+
895+
class AsyncPyMongoTestCase(unittest.TestCase):
896+
if not _IS_SYNC:
897+
# An async TestCase that uses a single event loop for all tests.
898+
# Inspired by IsolatedAsyncioTestCase.
899+
async def asyncSetUp(self):
900+
pass
868901

869-
async def reset_client_context():
870-
if _IS_SYNC:
871-
# sync tests don't need to reset a client context
872-
return
873-
elif async_client_context.client is not None:
874-
await async_client_context.client.close()
875-
async_client_context.client = None
876-
await async_client_context._init_client()
902+
async def asyncTearDown(self):
903+
pass
877904

905+
def addAsyncCleanup(self, func, /, *args, **kwargs):
906+
self.addCleanup(*(func, *args), **kwargs)
907+
908+
def _callSetUp(self):
909+
self.setUp()
910+
self._callAsync(self.asyncSetUp)
911+
912+
def _callTestMethod(self, method):
913+
self._callMaybeAsync(method)
914+
915+
def _callTearDown(self):
916+
self._callAsync(self.asyncTearDown)
917+
self.tearDown()
918+
919+
def _callCleanup(self, function, *args, **kwargs):
920+
self._callMaybeAsync(function, *args, **kwargs)
921+
922+
def _callAsync(self, func, /, *args, **kwargs):
923+
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
924+
return get_loop().run_until_complete(func(*args, **kwargs))
925+
926+
def _callMaybeAsync(self, func, /, *args, **kwargs):
927+
if inspect.iscoroutinefunction(func):
928+
return get_loop().run_until_complete(func(*args, **kwargs))
929+
else:
930+
return func(*args, **kwargs)
878931

879-
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
880932
def assertEqualCommand(self, expected, actual, msg=None):
881933
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
882934

@@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
11541206

11551207
@async_client_context.require_connection
11561208
async def asyncSetUp(self) -> None:
1157-
if not _IS_SYNC:
1158-
await reset_client_context()
11591209
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11601210
raise SkipTest("this test does not support load balancers")
11611211
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
@@ -1204,6 +1254,9 @@ async def asyncTearDown(self) -> None:
12041254

12051255

12061256
async def async_setup():
1257+
if not _IS_SYNC:
1258+
# Set up the event loop.
1259+
get_loop()
12071260
await async_client_context.init()
12081261
warnings.resetwarnings()
12091262
warnings.simplefilter("always")

0 commit comments

Comments
 (0)