Skip to content

PYTHON-5071 Use one event loop for all asyncio tests #2086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 30, 2025
121 changes: 87 additions & 34 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import asyncio
import gc
import inspect
import logging
import multiprocessing
import os
Expand All @@ -30,28 +31,6 @@
import unittest
import warnings
from asyncio import iscoroutinefunction
from test.helpers import (
COMPRESSORS,
IS_SRV,
MONGODB_API_VERSION,
MULTI_MONGOS_LB_URI,
TEST_LOADBALANCER,
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
client_knobs,
db_pwd,
db_user,
global_knobs,
host,
is_server_resolvable,
port,
print_running_topology,
print_thread_stacks,
print_thread_tracebacks,
sanitize_cmd,
sanitize_reply,
)

from pymongo.uri_parser import parse_uri

Expand All @@ -63,7 +42,6 @@
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
Expand All @@ -78,6 +56,32 @@
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient

sys.path[0:0] = [""]

from test.helpers import (
COMPRESSORS,
IS_SRV,
MONGODB_API_VERSION,
MULTI_MONGOS_LB_URI,
TEST_LOADBALANCER,
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
client_knobs,
db_pwd,
db_user,
global_knobs,
host,
is_server_resolvable,
port,
print_running_topology,
print_thread_stacks,
print_thread_tracebacks,
sanitize_cmd,
sanitize_reply,
)
from test.version import Version

_IS_SYNC = True


Expand Down Expand Up @@ -864,17 +868,44 @@ def max_message_size_bytes(self):
client_context = ClientContext()


def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif client_context.client is not None:
client_context.client.close()
client_context.client = None
client_context._init_client()
class PyMongoTestCase(unittest.TestCase):
if not _IS_SYNC:
# An async TestCase that uses a single event loop for all tests.
# Inspired by TestCase.
def setUp(self):
pass

def tearDown(self):
pass

# See TestCase.addCleanup.
def addCleanup(self, func, /, *args, **kwargs):
self.addCleanup(*(func, *args), **kwargs)

def _callSetUp(self):
self.setUp()
self._callAsync(self.setUp)

def _callTestMethod(self, method):
self._callMaybeAsync(method)

def _callTearDown(self):
self._callAsync(self.tearDown)
self.tearDown()

def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)

def _callAsync(self, func, /, *args, **kwargs):
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
return get_loop().run_until_complete(func(*args, **kwargs))

def _callMaybeAsync(self, func, /, *args, **kwargs):
if inspect.iscoroutinefunction(func):
return get_loop().run_until_complete(func(*args, **kwargs))
else:
return func(*args, **kwargs)

class PyMongoTestCase(unittest.TestCase):
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)

Expand Down Expand Up @@ -1136,8 +1167,6 @@ class IntegrationTest(PyMongoTestCase):

@client_context.require_connection
def setUp(self) -> None:
if not _IS_SYNC:
reset_client_context()
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
Expand Down Expand Up @@ -1185,7 +1214,31 @@ def tearDown(self) -> None:
super().tearDown()


LOOP = None


def get_loop() -> asyncio.AbstractEventLoop:
global LOOP
if LOOP is None:
try:
LOOP = asyncio.get_running_loop()
except RuntimeError:
# no running event loop, fallback to get_event_loop.
try:
# Ignore DeprecationWarning: There is no current event loop
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
LOOP = asyncio.get_event_loop()
except RuntimeError:
LOOP = asyncio.new_event_loop()
asyncio.set_event_loop(LOOP)
return LOOP


def setup():
if not _IS_SYNC:
global LOOP
LOOP = asyncio.get_running_loop()
client_context.init()
warnings.resetwarnings()
warnings.simplefilter("always")
Expand Down
121 changes: 87 additions & 34 deletions test/asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import asyncio
import gc
import inspect
import logging
import multiprocessing
import os
Expand All @@ -30,28 +31,6 @@
import unittest
import warnings
from asyncio import iscoroutinefunction
from test.helpers import (
COMPRESSORS,
IS_SRV,
MONGODB_API_VERSION,
MULTI_MONGOS_LB_URI,
TEST_LOADBALANCER,
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
client_knobs,
db_pwd,
db_user,
global_knobs,
host,
is_server_resolvable,
port,
print_running_topology,
print_thread_stacks,
print_thread_tracebacks,
sanitize_cmd,
sanitize_reply,
)

from pymongo.uri_parser import parse_uri

Expand All @@ -63,7 +42,6 @@
HAVE_IPADDRESS = False
from contextlib import asynccontextmanager, contextmanager
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
Expand All @@ -78,6 +56,32 @@
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]

sys.path[0:0] = [""]

from test.helpers import (
COMPRESSORS,
IS_SRV,
MONGODB_API_VERSION,
MULTI_MONGOS_LB_URI,
TEST_LOADBALANCER,
TEST_SERVERLESS,
TLS_OPTIONS,
SystemCertsPatcher,
client_knobs,
db_pwd,
db_user,
global_knobs,
host,
is_server_resolvable,
port,
print_running_topology,
print_thread_stacks,
print_thread_tracebacks,
sanitize_cmd,
sanitize_reply,
)
from test.version import Version

_IS_SYNC = False


Expand Down Expand Up @@ -866,17 +870,44 @@ async def max_message_size_bytes(self):
async_client_context = AsyncClientContext()


async def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif async_client_context.client is not None:
await async_client_context.client.close()
async_client_context.client = None
await async_client_context._init_client()
class AsyncPyMongoTestCase(unittest.TestCase):
if not _IS_SYNC:
# An async TestCase that uses a single event loop for all tests.
# Inspired by IsolatedAsyncioTestCase.
async def asyncSetUp(self):
pass

async def asyncTearDown(self):
pass

# See IsolatedAsyncioTestCase.addAsyncCleanup.
def addAsyncCleanup(self, func, /, *args, **kwargs):
self.addCleanup(*(func, *args), **kwargs)

def _callSetUp(self):
self.setUp()
self._callAsync(self.asyncSetUp)

def _callTestMethod(self, method):
self._callMaybeAsync(method)

def _callTearDown(self):
self._callAsync(self.asyncTearDown)
self.tearDown()

def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)

def _callAsync(self, func, /, *args, **kwargs):
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
return get_loop().run_until_complete(func(*args, **kwargs))

def _callMaybeAsync(self, func, /, *args, **kwargs):
if inspect.iscoroutinefunction(func):
return get_loop().run_until_complete(func(*args, **kwargs))
else:
return func(*args, **kwargs)

class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)

Expand Down Expand Up @@ -1154,8 +1185,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):

@async_client_context.require_connection
async def asyncSetUp(self) -> None:
if not _IS_SYNC:
await reset_client_context()
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
Expand Down Expand Up @@ -1203,7 +1232,31 @@ async def asyncTearDown(self) -> None:
await super().asyncTearDown()


LOOP = None


def get_loop() -> asyncio.AbstractEventLoop:
global LOOP
if LOOP is None:
try:
LOOP = asyncio.get_running_loop()
except RuntimeError:
# no running event loop, fallback to get_event_loop.
try:
# Ignore DeprecationWarning: There is no current event loop
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
LOOP = asyncio.get_event_loop()
except RuntimeError:
LOOP = asyncio.new_event_loop()
asyncio.set_event_loop(LOOP)
return LOOP


async def async_setup():
if not _IS_SYNC:
global LOOP
LOOP = asyncio.get_running_loop()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this supposed to be get_loop instead?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is the correct method to call within an async function.

Copy link
Contributor

@NoahStapp NoahStapp Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I think we still have a loop mismatch here, although I might be misunderstanding: this will use the loop pytest creates for our package-scoped test_setup_and_teardown fixture. Then, our first async testcase will create its own separate loop that will be used for each test. So our async_client_context will be on one pytest-managed loop, and all of our test cases will be on one self-managed loop created by get_loop.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our pyproject.toml also sets the default loop_scope to be session instead of package FYI.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either one is fine because get_loop calls get_running_loop first but I updated to use get_loop to be more consistent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attempting to use "package" in pyproject.toml results in this error:

____________________________________________________ ERROR at setup of TestCommon.test_mongo_client ____________________________________________________
file /Users/shane/git/mongo-python-driver/test/asynchronous/test_common.py, line 135
      async def test_mongo_client(self):
          pair = await async_client_context.pair
          m = await self.async_rs_or_single_client(w=0)
          coll = m.pymongo_test.write_concern_test
          await coll.drop()
          doc = {"_id": ObjectId()}
          await coll.insert_one(doc)
          self.assertTrue(await coll.insert_one(doc))
          coll = coll.with_options(write_concern=WriteConcern(w=1))
          with self.assertRaises(OperationFailure):
              await coll.insert_one(doc)

          m = await self.async_rs_or_single_client()
          coll = m.pymongo_test.write_concern_test
          new_coll = coll.with_options(write_concern=WriteConcern(w=0))
          self.assertTrue(await new_coll.insert_one(doc))
          with self.assertRaises(OperationFailure):
              await coll.insert_one(doc)

          m = await self.async_rs_or_single_client(
              f"mongodb://{pair}/", replicaSet=async_client_context.replica_set_name
          )

          coll = m.pymongo_test.write_concern_test
          with self.assertRaises(OperationFailure):
              await coll.insert_one(doc)
          m = await self.async_rs_or_single_client(
              f"mongodb://{pair}/?w=0", replicaSet=async_client_context.replica_set_name
          )

          coll = m.pymongo_test.write_concern_test
          await coll.insert_one(doc)

          # Equality tests
          direct = await connected(await self.async_single_client(w=0))
          direct2 = await connected(
              await self.async_single_client(f"mongodb://{pair}/?w=0", **self.credentials)
          )
          self.assertEqual(direct, direct2)
          self.assertFalse(direct != direct2)
file /Users/shane/git/mongo-python-driver/test/asynchronous/conftest.py, line 25
  @pytest_asyncio.fixture(scope="package", autouse=True)
  async def test_setup_and_teardown():
      await async_setup()
      yield
      await async_teardown()
E       fixture 'test/asynchronous::<event_loop>' not found
>       available fixtures: _session_event_loop, _unittest_setUpClass_fixture_TestCommon, anyio_backend, anyio_backend_name, anyio_backend_options, cache, capfd, capfdbinary, caplog, capsys, capsysbinary, doctest_namespace, event_loop, event_loop_policy, monkeypatch, pytestconfig, record_property, record_testsuite_property, record_xml_attribute, recwarn, test/asynchronous/test_common.py::<event_loop>, test/asynchronous/test_common.py::TestCommon::<event_loop>, test_setup_and_teardown, tmp_path, tmp_path_factory, tmpdir, tmpdir_factory, unused_tcp_port, unused_tcp_port_factory, unused_udp_port, unused_udp_port_factory
>       use 'pytest --fixtures [testpath]' for help on them.

"session" seems to be the right thing for what we want there.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened pytest-dev/pytest-asyncio#1052 about that exact error, it persists even if you explicitly mark a fixture as package.

With any pytest-asyncio scope, we'll still have two separate loops. That doesn't seem to be causing test failures, but it's something to be aware of for future debugging 😅

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks I subscribed to that issue. However I did verify that the current code only ever uses 1 loop for the whole async test suite. Printing the loop that pytest-asyncio creates shows this: event_loop_fixture_id='_session_event_loop' which means the asyncio_default_fixture_loop_scope="session" in pyproject is behaving as expected.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean that get_loop is picking up and using pytest's _session_event_loop? Or that that loop is separate from the loop the tests use that get_loop creates?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first one.

await async_client_context.init()
warnings.resetwarnings()
warnings.simplefilter("always")
Expand Down
Loading