1from __future__ import annotations
2
3import os
4import sys
5from collections.abc import Generator
6from contextlib import contextmanager
7
8from psycopg import errors
9
10from plain.postgres.connection import DatabaseConnection
11from plain.postgres.connections import _db_conn
12from plain.postgres.database_url import parse_database_url, replace_database_name
13from plain.postgres.dialect import MAX_NAME_LENGTH, quote_name
14from plain.postgres.migrations.executor import MigrationExecutor
15from plain.postgres.utils import names_digest
16from plain.runtime import settings
17
# Prepended to the runtime database name when no explicit prefix is supplied.
TEST_DATABASE_PREFIX = "test_"
19
20
def _log(msg: str) -> None:
    """Write ``msg`` to stderr followed by a single line terminator."""
    # sys.stderr is a text-mode stream that already translates "\n" to the
    # platform line ending; writing os.linesep there would emit "\r\r\n" on
    # Windows.
    sys.stderr.write(msg + "\n")
23
24
def _compute_test_db_name(base_name: str, prefix: str = "") -> str:
    """Compute the test database name from the runtime DB name.

    With a ``prefix`` the candidate is ``{prefix}_{base_name}``; without one
    it is ``TEST_DATABASE_PREFIX + base_name``. In both cases a candidate
    longer than ``MAX_NAME_LENGTH`` is cut down and suffixed with ``_`` plus
    ``names_digest(candidate, length=8)`` so the result fits the limit while
    still depending on the full candidate name.

    Args:
        base_name: Database name taken from the runtime connection URL.
        prefix: Optional custom prefix; empty selects the default prefix.

    Returns:
        The name to use for the test database.
    """
    name = f"{prefix}_{base_name}" if prefix else TEST_DATABASE_PREFIX + base_name
    if len(name) > MAX_NAME_LENGTH:
        # Reserve 9 characters: the "_" separator plus the digest requested
        # with length=8.
        hash_suffix = names_digest(name, length=8)
        name = name[: MAX_NAME_LENGTH - 9] + "_" + hash_suffix
    return name
34
35
def _create_database_on_server(
    conn: DatabaseConnection, *, name: str, verbosity: int, autoclobber: bool
) -> None:
    """CREATE DATABASE via the maintenance connection, with autoclobber fallback.

    Tries to create ``name``. If that fails because the database already
    exists, the existing database is dropped and recreated — after asking for
    confirmation on stdin unless ``autoclobber`` is true.

    Args:
        conn: Connection whose ``_maintenance_cursor()`` executes the statements.
        name: Unquoted database name; passed through ``quote_name`` before use.
        verbosity: At ``>= 1`` the "Destroying old test database" notice is logged.
        autoclobber: When true, skip the interactive confirmation before dropping.

    Raises:
        SystemExit: Exit code 2 when creation or recreation fails; exit code 1
            when the user declines the drop.
    """
    quoted = quote_name(name)
    with conn._maintenance_cursor() as cursor:
        try:
            cursor.execute(f"CREATE DATABASE {quoted}")
            return
        except Exception as e:
            # The driver's DuplicateDatabase may be raised directly or be
            # attached as the __cause__ of a wrapping exception; accept both
            # so an existing database always reaches the autoclobber path.
            already_exists = isinstance(e, errors.DuplicateDatabase) or isinstance(
                e.__cause__, errors.DuplicateDatabase
            )
            _log(f"Got an error creating the test database: {e}")
            if not already_exists:
                sys.exit(2)
            # Database already exists — fall through to autoclobber handling.

        if not autoclobber:
            confirm = input(
                "Type 'yes' if you would like to try deleting the test "
                f"database '{name}', or 'no' to cancel: "
            )
            if confirm != "yes":
                _log("Tests cancelled.")
                sys.exit(1)

        try:
            if verbosity >= 1:
                _log(f"Destroying old test database '{name}'...")
            cursor.execute(f"DROP DATABASE {quoted}")
            cursor.execute(f"CREATE DATABASE {quoted}")
        except Exception as e:
            _log(f"Got an error recreating the test database: {e}")
            sys.exit(2)
70
71
def _drop_database_on_server(conn: DatabaseConnection, name: str) -> None:
    """Issue DROP DATABASE for ``name`` through the maintenance cursor of ``conn``."""
    with conn._maintenance_cursor() as maintenance:
        quoted = quote_name(name)
        maintenance.execute(f"DROP DATABASE {quoted}")
75
76
@contextmanager
def use_test_database(*, verbosity: int = 1, prefix: str = "") -> Generator[str]:
    """Create a test database, install it as the active connection, drop on exit.

    Inside the block, `get_connection()` returns a connection opened against
    the test database. Migrations and convergence run directly via their
    Python APIs (`MigrationExecutor`, `plan_convergence`) — not via the CLI
    commands — so no `POSTGRES_MANAGEMENT_URL` swap happens during setup.

    Args:
        verbosity: At ``>= 1`` the create/destroy progress messages are logged.
        prefix: Optional name prefix forwarded to ``_compute_test_db_name``.

    Yields the test database name.

    Raises:
        ValueError: ``POSTGRES_URL`` is unset/empty or carries no database name.
        RuntimeError: Convergence failed or was blocked during setup.
    """
    from plain.postgres.convergence import execute_plan, plan_convergence

    # str(None) is the truthy string "None", so reject a missing setting
    # before stringifying it.
    configured_url = settings.POSTGRES_URL
    runtime_url = "" if configured_url is None else str(configured_url)
    if not runtime_url:
        raise ValueError("POSTGRES_URL must be set before creating a test database.")

    base_name = parse_database_url(runtime_url).get("DATABASE")
    if not base_name:
        raise ValueError("POSTGRES_URL must include a database name")

    test_db_name = _compute_test_db_name(base_name, prefix)

    if verbosity >= 1:
        _log(f"Creating test database '{test_db_name}'...")

    test_url = replace_database_name(runtime_url, test_db_name)
    test_conn = DatabaseConnection.from_url(test_url)

    # Create the test database on the server via a sibling maintenance
    # connection. `_maintenance_cursor` builds its own `postgres`-targeted
    # connection from settings_dict, so test_conn itself is not opened yet.
    _create_database_on_server(
        test_conn, name=test_db_name, verbosity=verbosity, autoclobber=True
    )

    conn_token = _db_conn.set(test_conn)
    try:
        executor = MigrationExecutor(test_conn)
        targets = list(executor.loader.graph.leaf_nodes())
        executor.migrate(targets)

        plan = plan_convergence()
        result = execute_plan(plan.executable())
        if not result.ok:
            failed = [r for r in result.results if not r.ok]
            raise RuntimeError(
                f"Convergence failed during test DB setup: {failed[0].item.describe()} — {failed[0].error}"
            )
        # A fresh DB from migrations shouldn't have undeclared objects or
        # changed definitions — safety net so test setup follows sync policy.
        if plan.blocked:
            problem = plan.blocked[0]
            raise RuntimeError(
                f"Convergence blocked during test DB setup: {problem.describe()}"
            )

        test_conn.ensure_connection()

        yield test_db_name
    finally:
        try:
            _db_conn.reset(conn_token)
        finally:
            # Teardown runs even if restoring the previous connection raised,
            # so the test database is not left behind on the server.
            try:
                test_conn.close()
            except Exception as e:
                # Best-effort, but report it: a connection that failed to
                # close is a likely reason for the DROP below to fail.
                _log(f"Got an error closing the test database connection: {e}")

            if verbosity >= 1:
                _log(f"Destroying test database '{test_db_name}'...")
            try:
                _drop_database_on_server(test_conn, test_db_name)
            except Exception as e:
                _log(f"Got an error destroying the test database: {e}")