Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import re
  4from collections.abc import Generator
  5from typing import Any
  6
  7import pytest
  8
  9from plain.models.otel import suppress_db_tracing
 10from plain.signals import request_finished, request_started
 11
 12from .. import transaction
 13from ..backends.base.base import BaseDatabaseWrapper
 14from ..db import close_old_connections, db_connection
 15from .utils import (
 16    setup_database,
 17    teardown_database,
 18)
 19
 20
 21@pytest.fixture(autouse=True)
 22def _db_disabled() -> Generator[None, None, None]:
 23    """
 24    Every test should use this fixture by default to prevent
 25    access to the normal database.
 26    """
 27
 28    def cursor_disabled(self: Any) -> None:
 29        pytest.fail("Database access not allowed without the `db` fixture")
 30
 31    # Save original cursor method and replace with disabled version
 32    setattr(BaseDatabaseWrapper, "_enabled_cursor", BaseDatabaseWrapper.cursor)
 33    BaseDatabaseWrapper.cursor = cursor_disabled  # type: ignore[method-assign]
 34
 35    yield
 36
 37    # Restore original cursor method
 38    BaseDatabaseWrapper.cursor = getattr(BaseDatabaseWrapper, "_enabled_cursor")  # type: ignore[method-assign]
 39
 40
 41@pytest.fixture(scope="session")
 42def setup_db(request: Any) -> Generator[None, None, None]:
 43    """
 44    This fixture is called automatically by `db`,
 45    so a test database will only be setup if the `db` fixture is used.
 46    """
 47    verbosity = request.config.option.verbose
 48
 49    # Set up the test db across the entire session
 50    _old_db_name = setup_database(verbosity=verbosity)
 51
 52    # Keep connections open during request client / testing
 53    request_started.disconnect(close_old_connections)
 54    request_finished.disconnect(close_old_connections)
 55
 56    yield
 57
 58    # Put the signals back...
 59    request_started.connect(close_old_connections)
 60    request_finished.connect(close_old_connections)
 61
 62    # When the test session is done, tear down the test db
 63    teardown_database(_old_db_name, verbosity=verbosity)
 64
 65
 66@pytest.fixture
 67def db(setup_db: Any, request: Any) -> Generator[None, None, None]:
 68    if "isolated_db" in request.fixturenames:
 69        pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")
 70
 71    # Set .cursor() back to the original implementation to unblock it
 72    BaseDatabaseWrapper.cursor = getattr(BaseDatabaseWrapper, "_enabled_cursor")  # type: ignore[method-assign]
 73
 74    if not db_connection.features.supports_transactions:
 75        pytest.fail("Database does not support transactions")
 76
 77    with suppress_db_tracing():
 78        atomic = transaction.atomic()
 79        atomic._from_testcase = True
 80        atomic.__enter__()
 81
 82    yield
 83
 84    with suppress_db_tracing():
 85        if (
 86            db_connection.features.can_defer_constraint_checks
 87            and not db_connection.needs_rollback
 88            and db_connection.is_usable()
 89        ):
 90            db_connection.check_constraints()
 91
 92        db_connection.set_rollback(True)
 93        atomic.__exit__(None, None, None)
 94
 95        db_connection.close()
 96
 97
 98@pytest.fixture
 99def isolated_db(request: Any) -> Generator[None, None, None]:
100    """
101    Create and destroy a unique test database for each test, using a prefix
102    derived from the test function name to ensure isolation from the default
103    test database.
104    """
105    if "db" in request.fixturenames:
106        pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")
107    # Set .cursor() back to the original implementation to unblock it
108    BaseDatabaseWrapper.cursor = getattr(BaseDatabaseWrapper, "_enabled_cursor")  # type: ignore[method-assign]
109
110    verbosity = 1
111
112    # Derive a safe prefix from the test function name
113    raw_name = request.node.name
114    prefix = re.sub(r"[^0-9A-Za-z_]+", "_", raw_name)
115
116    # Set up a fresh test database for this test, using the prefix
117    _old_db_name = setup_database(verbosity=verbosity, prefix=prefix)
118
119    yield
120
121    # Tear down the test database created for this test
122    teardown_database(_old_db_name, verbosity=verbosity)