1from __future__ import annotations
2
3import re
4from collections.abc import Generator
5from typing import Any
6
7import pytest
8
9from plain.postgres.otel import suppress_db_tracing
10from plain.signals import request_finished, request_started
11
12from .. import transaction
13from ..connection import DatabaseConnection
14from ..db import close_old_connections, get_connection
15from .utils import (
16 setup_database,
17 teardown_database,
18)
19
20
@pytest.fixture(autouse=True)
def _db_disabled() -> Generator[None]:
    """
    Block database access for every test that has not opted in.

    Autouse fixture: replaces ``DatabaseConnection.cursor`` with a stub that
    fails the test, after stashing the real method on the class as
    ``_enabled_cursor`` so the ``db`` / ``isolated_db`` fixtures can put it
    back. The stashed method is reinstalled once the test has finished.
    """

    def _blocked_cursor(self: Any) -> None:
        pytest.fail("Database access not allowed without the `db` fixture")  # type: ignore[invalid-argument-type]

    # Stash the real implementation under a private name; the opt-in
    # fixtures read it back from there.
    real_cursor = DatabaseConnection.cursor
    setattr(DatabaseConnection, "_enabled_cursor", real_cursor)
    DatabaseConnection.cursor = _blocked_cursor  # type: ignore[assignment]

    yield

    # Undo the patch, whether or not an opt-in fixture already did so.
    DatabaseConnection.cursor = getattr(DatabaseConnection, "_enabled_cursor")
39
40
@pytest.fixture(scope="session")
def setup_db(request: Any) -> Generator[None]:
    """
    Create the shared test database once for the whole test session.

    Requested by the ``db`` fixture, so the database is only created when at
    least one test in the session uses ``db``. While it is active,
    ``close_old_connections`` is detached from the request started/finished
    signals so the connection survives test-client requests. Everything is
    restored and the test database torn down when the session ends.
    """
    level = request.config.option.verbose

    previous_name = setup_database(verbosity=level)

    # Keep the connection open across request/response cycles driven by the
    # test client.
    for sig in (request_started, request_finished):
        sig.disconnect(close_old_connections)

    yield

    # Reattach the signal handlers that were detached above.
    for sig in (request_started, request_finished):
        sig.connect(close_old_connections)

    # Session finished: drop the test database.
    teardown_database(previous_name, verbosity=level)
64
65
@pytest.fixture
def db(setup_db: Any, request: Any) -> Generator[None]:
    """
    Give a test access to the shared test database, rolled back afterwards.

    Depends on the session-scoped ``setup_db`` fixture and re-enables
    ``DatabaseConnection.cursor``, which ``_db_disabled`` blocks by default.
    The test runs inside an ``atomic`` block that is always rolled back on
    teardown. Before the rollback, ``check_constraints()`` is run on a usable,
    non-failed connection, because PostgreSQL may defer constraint checks
    until commit and this transaction never commits.

    Fails immediately if combined with the ``isolated_db`` fixture.
    """
    if "isolated_db" in request.fixturenames:
        pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")  # type: ignore[invalid-argument-type]

    # Set .cursor() back to the original implementation to unblock it.
    DatabaseConnection.cursor = getattr(DatabaseConnection, "_enabled_cursor")

    with suppress_db_tracing():
        atomic = transaction.atomic()
        atomic._from_testcase = True
        atomic.__enter__()

    yield

    with suppress_db_tracing():
        conn = get_connection()

    try:
        with suppress_db_tracing():
            try:
                # PostgreSQL can defer constraint checks; surface violations
                # here so they fail the test that caused them.
                if not conn.needs_rollback and conn.is_usable():
                    conn.check_constraints()
            finally:
                # Always roll back and leave the atomic block, even when the
                # constraint check raises. Otherwise the open transaction
                # would leak into every subsequent test.
                conn.set_rollback(True)
                atomic.__exit__(None, None, None)
    finally:
        conn.close()
91
92
@pytest.fixture
def isolated_db(request: Any) -> Generator[None]:
    """
    Run a test against its own freshly created database.

    A new test database is created before the test and torn down afterwards.
    Its name prefix is the test's node name with every run of characters
    outside ``[0-9A-Za-z_]`` replaced by ``_``, keeping it separate from the
    shared database used by the ``db`` fixture. Fails immediately if combined
    with ``db``.
    """
    if "db" in request.fixturenames:
        pytest.fail("The 'db' and 'isolated_db' fixtures cannot be used together")  # type: ignore[invalid-argument-type]
    # Lift the `_db_disabled` block so `.cursor()` works again.
    DatabaseConnection.cursor = getattr(DatabaseConnection, "_enabled_cursor")

    level = 1

    # Sanitize the test name so it can be used as a database-name prefix.
    db_prefix = re.sub(r"[^0-9A-Za-z_]+", "_", request.node.name)

    previous_name = setup_database(verbosity=level, prefix=db_prefix)

    yield

    # Drop the per-test database.
    teardown_database(previous_name, verbosity=level)