|
17 | 17 |
|
18 | 18 | import asyncio
|
19 | 19 | import gc
|
| 20 | +import inspect |
20 | 21 | import logging
|
21 | 22 | import multiprocessing
|
22 | 23 | import os
|
|
30 | 31 | import unittest
|
31 | 32 | import warnings
|
32 | 33 | from asyncio import iscoroutinefunction
|
33 |
| -from test.helpers import ( |
34 |
| - COMPRESSORS, |
35 |
| - IS_SRV, |
36 |
| - MONGODB_API_VERSION, |
37 |
| - MULTI_MONGOS_LB_URI, |
38 |
| - TEST_LOADBALANCER, |
39 |
| - TEST_SERVERLESS, |
40 |
| - TLS_OPTIONS, |
41 |
| - SystemCertsPatcher, |
42 |
| - client_knobs, |
43 |
| - db_pwd, |
44 |
| - db_user, |
45 |
| - global_knobs, |
46 |
| - host, |
47 |
| - is_server_resolvable, |
48 |
| - port, |
49 |
| - print_running_topology, |
50 |
| - print_thread_stacks, |
51 |
| - print_thread_tracebacks, |
52 |
| - sanitize_cmd, |
53 |
| - sanitize_reply, |
54 |
| -) |
55 | 34 |
|
56 | 35 | from pymongo.uri_parser import parse_uri
|
57 | 36 |
|
|
63 | 42 | HAVE_IPADDRESS = False
|
64 | 43 | from contextlib import asynccontextmanager, contextmanager
|
65 | 44 | from functools import partial, wraps
|
66 |
| -from test.version import Version |
67 | 45 | from typing import Any, Callable, Dict, Generator, overload
|
68 | 46 | from unittest import SkipTest
|
69 | 47 | from urllib.parse import quote_plus
|
|
78 | 56 | from pymongo.server_api import ServerApi
|
79 | 57 | from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
80 | 58 |
|
| 59 | +sys.path[0:0] = [""] |
| 60 | + |
| 61 | +from test.helpers import ( |
| 62 | + COMPRESSORS, |
| 63 | + IS_SRV, |
| 64 | + MONGODB_API_VERSION, |
| 65 | + MULTI_MONGOS_LB_URI, |
| 66 | + TEST_LOADBALANCER, |
| 67 | + TEST_SERVERLESS, |
| 68 | + TLS_OPTIONS, |
| 69 | + SystemCertsPatcher, |
| 70 | + client_knobs, |
| 71 | + db_pwd, |
| 72 | + db_user, |
| 73 | + global_knobs, |
| 74 | + host, |
| 75 | + is_server_resolvable, |
| 76 | + port, |
| 77 | + print_running_topology, |
| 78 | + print_thread_stacks, |
| 79 | + print_thread_tracebacks, |
| 80 | + sanitize_cmd, |
| 81 | + sanitize_reply, |
| 82 | +) |
| 83 | +from test.version import Version |
| 84 | + |
81 | 85 | _IS_SYNC = False
|
82 | 86 |
|
83 | 87 |
|
@@ -865,18 +869,66 @@ async def max_message_size_bytes(self):
|
865 | 869 | # Reusable client context
|
866 | 870 | async_client_context = AsyncClientContext()
|
867 | 871 |
|
| 872 | +# Global event loop for async tests. |
| 873 | +LOOP = None |
| 874 | + |
| 875 | + |
| 876 | +def get_loop() -> asyncio.AbstractEventLoop: |
| 877 | + """Get the test suite's global event loop.""" |
| 878 | + global LOOP |
| 879 | + if LOOP is None: |
| 880 | + try: |
| 881 | + LOOP = asyncio.get_running_loop() |
| 882 | + except RuntimeError: |
| 883 | + # no running event loop, fallback to get_event_loop. |
| 884 | + try: |
| 885 | + # Ignore DeprecationWarning: There is no current event loop |
| 886 | + with warnings.catch_warnings(): |
| 887 | + warnings.simplefilter("ignore", DeprecationWarning) |
| 888 | + LOOP = asyncio.get_event_loop() |
| 889 | + except RuntimeError: |
| 890 | + LOOP = asyncio.new_event_loop() |
| 891 | + asyncio.set_event_loop(LOOP) |
| 892 | + return LOOP |
| 893 | + |
| 894 | + |
| 895 | +class AsyncPyMongoTestCase(unittest.TestCase): |
| 896 | + if not _IS_SYNC: |
| 897 | + # An async TestCase that uses a single event loop for all tests. |
| 898 | + # Inspired by IsolatedAsyncioTestCase. |
| 899 | + async def asyncSetUp(self): |
| 900 | + pass |
868 | 901 |
|
869 |
| -async def reset_client_context(): |
870 |
| - if _IS_SYNC: |
871 |
| - # sync tests don't need to reset a client context |
872 |
| - return |
873 |
| - elif async_client_context.client is not None: |
874 |
| - await async_client_context.client.close() |
875 |
| - async_client_context.client = None |
876 |
| - await async_client_context._init_client() |
| 902 | + async def asyncTearDown(self): |
| 903 | + pass |
877 | 904 |
|
| 905 | + def addAsyncCleanup(self, func, /, *args, **kwargs): |
| 906 | + self.addCleanup(*(func, *args), **kwargs) |
| 907 | + |
| 908 | + def _callSetUp(self): |
| 909 | + self.setUp() |
| 910 | + self._callAsync(self.asyncSetUp) |
| 911 | + |
| 912 | + def _callTestMethod(self, method): |
| 913 | + self._callMaybeAsync(method) |
| 914 | + |
| 915 | + def _callTearDown(self): |
| 916 | + self._callAsync(self.asyncTearDown) |
| 917 | + self.tearDown() |
| 918 | + |
| 919 | + def _callCleanup(self, function, *args, **kwargs): |
| 920 | + self._callMaybeAsync(function, *args, **kwargs) |
| 921 | + |
| 922 | + def _callAsync(self, func, /, *args, **kwargs): |
| 923 | + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" |
| 924 | + return get_loop().run_until_complete(func(*args, **kwargs)) |
| 925 | + |
| 926 | + def _callMaybeAsync(self, func, /, *args, **kwargs): |
| 927 | + if inspect.iscoroutinefunction(func): |
| 928 | + return get_loop().run_until_complete(func(*args, **kwargs)) |
| 929 | + else: |
| 930 | + return func(*args, **kwargs) |
878 | 931 |
|
879 |
| -class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): |
880 | 932 | def assertEqualCommand(self, expected, actual, msg=None):
|
881 | 933 | self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
882 | 934 |
|
@@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
1154 | 1206 |
|
1155 | 1207 | @async_client_context.require_connection
|
1156 | 1208 | async def asyncSetUp(self) -> None:
|
1157 |
| - if not _IS_SYNC: |
1158 |
| - await reset_client_context() |
1159 | 1209 | if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
1160 | 1210 | raise SkipTest("this test does not support load balancers")
|
1161 | 1211 | if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
@@ -1204,6 +1254,9 @@ async def asyncTearDown(self) -> None:
|
1204 | 1254 |
|
1205 | 1255 |
|
1206 | 1256 | async def async_setup():
|
| 1257 | + if not _IS_SYNC: |
| 1258 | + # Set up the event loop. |
| 1259 | + get_loop() |
1207 | 1260 | await async_client_context.init()
|
1208 | 1261 | warnings.resetwarnings()
|
1209 | 1262 | warnings.simplefilter("always")
|
|
0 commit comments