1from __future__ import annotations
2
3import re
4from collections.abc import Generator
5from typing import Any
6
7import pytest
8
9from plain.models.otel import suppress_db_tracing
10from plain.signals import request_finished, request_started
11
12from .. import transaction
13from ..backends.base.base import BaseDatabaseWrapper
14from ..db import close_old_connections, db_connection
15from .utils import (
16 setup_database,
17 teardown_database,
18)
19
20
@pytest.fixture(autouse=True)
def _db_disabled() -> Generator[None, None, None]:
    """
    Block database cursor access for every test unless it opts in.

    While this autouse fixture is active, `BaseDatabaseWrapper.cursor` is
    swapped for a stub that fails the test. The real method is kept on the
    class under `_enabled_cursor` so that the `db` and `isolated_db`
    fixtures can put it back, and it is reinstated once the test is over.
    """

    def cursor_disabled(self: Any) -> None:
        pytest.fail("Database access not allowed without the `db` fixture")

    # Stash the real cursor method on the class, then install the failing stub
    real_cursor = BaseDatabaseWrapper.cursor
    setattr(BaseDatabaseWrapper, "_enabled_cursor", real_cursor)
    BaseDatabaseWrapper.cursor = cursor_disabled  # type: ignore[method-assign]

    yield

    # Reinstate whatever is stored as the enabled cursor
    enabled_cursor = getattr(BaseDatabaseWrapper, "_enabled_cursor")
    BaseDatabaseWrapper.cursor = enabled_cursor  # type: ignore[method-assign]
39
40
@pytest.fixture(scope="session")
def setup_db(request: Any) -> Generator[None, None, None]:
    """
    Session-scoped lifecycle of the test database.

    The `db` fixture requests this one, so a test database is only set up
    when at least one test actually uses `db`.
    """
    verbosity = request.config.option.verbose

    # One test database is shared by the whole session
    previous_db_name = setup_database(verbosity=verbosity)

    # Detach the connection-closing handlers so connections stay open
    # while tests (including request client usage) are running
    lifecycle_signals = (request_started, request_finished)
    for signal in lifecycle_signals:
        signal.disconnect(close_old_connections)

    yield

    # Reattach the handlers that were detached above
    for signal in lifecycle_signals:
        signal.connect(close_old_connections)

    # The session is over, so tear the test database down
    teardown_database(previous_db_name, verbosity=verbosity)
64
65
@pytest.fixture
def db(setup_db: Any, request: Any) -> Generator[None, None, None]:
    """
    Allow database access for a test, wrapped in a transaction that is
    rolled back afterwards.

    Depends on `setup_db`, so the session test database is set up first.
    Fails the test if `isolated_db` is also requested or if the database
    does not report transaction support.

    On teardown, deferred constraints are checked (when the backend can
    defer them and the connection is still healthy), then the transaction
    is rolled back and the connection closed. Rollback and close run even
    if the constraint check raises, so a failing check cannot leave an open
    transaction or connection behind for later tests.
    """
    if "isolated_db" in request.fixturenames:
        pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")

    # Set .cursor() back to the original implementation to unblock it
    BaseDatabaseWrapper.cursor = getattr(BaseDatabaseWrapper, "_enabled_cursor")  # type: ignore[method-assign]

    if not db_connection.features.supports_transactions:
        pytest.fail("Database does not support transactions")

    # Enter the atomic block manually: it has to stay open across the yield
    with suppress_db_tracing():
        atomic = transaction.atomic()
        atomic._from_testcase = True
        atomic.__enter__()

    yield

    with suppress_db_tracing():
        try:
            try:
                if (
                    db_connection.features.can_defer_constraint_checks
                    and not db_connection.needs_rollback
                    and db_connection.is_usable()
                ):
                    db_connection.check_constraints()
            finally:
                # Always discard the test's changes, even when the
                # constraint check above raised
                db_connection.set_rollback(True)
                atomic.__exit__(None, None, None)
        finally:
            # Always release the connection, even if the rollback failed
            db_connection.close()
96
97
@pytest.fixture
def isolated_db(request: Any) -> Generator[None, None, None]:
    """
    Give a single test its own test database, set up before the test and
    torn down after it.

    The database is set up with a prefix built from the test's node name,
    where every run of characters other than ASCII letters, digits and
    underscores is replaced by a single underscore. Cannot be combined
    with the `db` fixture.
    """
    uses_shared_db = "db" in request.fixturenames
    if uses_shared_db:
        pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")

    # Put the real .cursor() back so this test may reach the database
    enabled_cursor = getattr(BaseDatabaseWrapper, "_enabled_cursor")
    BaseDatabaseWrapper.cursor = enabled_cursor  # type: ignore[method-assign]

    verbosity = 1

    # Reduce the test name to characters that are safe in a prefix and
    # set up a fresh test database under it
    db_prefix = re.sub(r"[^0-9A-Za-z_]+", "_", request.node.name)
    previous_db_name = setup_database(verbosity=verbosity, prefix=db_prefix)

    yield

    # Tear down the database that was set up for this test
    teardown_database(previous_db_name, verbosity=verbosity)