1from __future__ import annotations
  2
  3import os
  4import sys
  5from collections.abc import Generator
  6from contextlib import contextmanager
  7
  8from psycopg import errors
  9
 10from plain.postgres.connection import DatabaseConnection
 11from plain.postgres.connections import _db_conn
 12from plain.postgres.database_url import parse_database_url, replace_database_name
 13from plain.postgres.dialect import MAX_NAME_LENGTH, quote_name
 14from plain.postgres.migrations.executor import MigrationExecutor
 15from plain.postgres.utils import names_digest
 16from plain.runtime import settings
 17
# Prepended to the runtime database name to form the test database name when
# `_compute_test_db_name` is called without an explicit prefix.
TEST_DATABASE_PREFIX = "test_"
 19
 20
 21def _log(msg: str) -> None:
 22    sys.stderr.write(msg + os.linesep)
 23
 24
 25def _compute_test_db_name(base_name: str, prefix: str = "") -> str:
 26    """Compute the test database name from the runtime DB name."""
 27    if prefix:
 28        name = f"{prefix}_{base_name}"
 29        if len(name) > MAX_NAME_LENGTH:
 30            hash_suffix = names_digest(name, length=8)
 31            name = name[: MAX_NAME_LENGTH - 9] + "_" + hash_suffix
 32        return name
 33    return TEST_DATABASE_PREFIX + base_name
 34
 35
 36def _create_database_on_server(
 37    conn: DatabaseConnection, *, name: str, verbosity: int, autoclobber: bool
 38) -> None:
 39    """CREATE DATABASE via the maintenance connection, with autoclobber fallback."""
 40    quoted = quote_name(name)
 41    with conn._maintenance_cursor() as cursor:
 42        try:
 43            cursor.execute(f"CREATE DATABASE {quoted}")
 44            return
 45        except Exception as e:
 46            cause = e.__cause__
 47            if not (cause and isinstance(cause, errors.DuplicateDatabase)):
 48                _log(f"Got an error creating the test database: {e}")
 49                sys.exit(2)
 50            # Database already exists — fall through to autoclobber handling.
 51            _log(f"Got an error creating the test database: {e}")
 52
 53        if not autoclobber:
 54            confirm = input(
 55                "Type 'yes' if you would like to try deleting the test "
 56                f"database '{name}', or 'no' to cancel: "
 57            )
 58            if confirm != "yes":
 59                _log("Tests cancelled.")
 60                sys.exit(1)
 61
 62        try:
 63            if verbosity >= 1:
 64                _log(f"Destroying old test database '{name}'...")
 65            cursor.execute(f"DROP DATABASE {quoted}")
 66            cursor.execute(f"CREATE DATABASE {quoted}")
 67        except Exception as e:
 68            _log(f"Got an error recreating the test database: {e}")
 69            sys.exit(2)
 70
 71
 72def _drop_database_on_server(conn: DatabaseConnection, name: str) -> None:
 73    with conn._maintenance_cursor() as cursor:
 74        cursor.execute(f"DROP DATABASE {quote_name(name)}")
 75
 76
 77@contextmanager
 78def use_test_database(*, verbosity: int = 1, prefix: str = "") -> Generator[str]:
 79    """Create a test database, install it as the active connection, drop on exit.
 80
 81    Inside the block, `get_connection()` returns a connection opened against
 82    the test database. Migrations and convergence run directly via their
 83    Python APIs (`MigrationExecutor`, `plan_convergence`) — not via the CLI
 84    commands — so no `POSTGRES_MANAGEMENT_URL` swap happens during setup.
 85
 86    Yields the test database name.
 87    """
 88    from plain.postgres.convergence import execute_plan, plan_convergence
 89
 90    runtime_url = str(settings.POSTGRES_URL)
 91    if not runtime_url:
 92        raise ValueError("POSTGRES_URL must be set before creating a test database.")
 93
 94    base_name = parse_database_url(runtime_url).get("DATABASE")
 95    if not base_name:
 96        raise ValueError("POSTGRES_URL must include a database name")
 97
 98    test_db_name = _compute_test_db_name(base_name, prefix)
 99
100    if verbosity >= 1:
101        _log(f"Creating test database '{test_db_name}'...")
102
103    test_url = replace_database_name(runtime_url, test_db_name)
104    test_conn = DatabaseConnection.from_url(test_url)
105
106    # Create the test database on the server via a sibling maintenance
107    # connection. `_maintenance_cursor` builds its own `postgres`-targeted
108    # connection from settings_dict, so test_conn itself is not opened yet.
109    _create_database_on_server(
110        test_conn, name=test_db_name, verbosity=verbosity, autoclobber=True
111    )
112
113    conn_token = _db_conn.set(test_conn)
114    try:
115        executor = MigrationExecutor(test_conn)
116        targets = list(executor.loader.graph.leaf_nodes())
117        executor.migrate(targets)
118
119        plan = plan_convergence()
120        result = execute_plan(plan.executable())
121        if not result.ok:
122            failed = [r for r in result.results if not r.ok]
123            raise RuntimeError(
124                f"Convergence failed during test DB setup: {failed[0].item.describe()}{failed[0].error}"
125            )
126        # A fresh DB from migrations shouldn't have undeclared objects or
127        # changed definitions — safety net so test setup follows sync policy.
128        if plan.blocked:
129            problem = plan.blocked[0]
130            raise RuntimeError(
131                f"Convergence blocked during test DB setup: {problem.describe()}"
132            )
133
134        test_conn.ensure_connection()
135
136        yield test_db_name
137    finally:
138        _db_conn.reset(conn_token)
139
140        try:
141            test_conn.close()
142        except Exception:
143            pass
144
145        if verbosity >= 1:
146            _log(f"Destroying test database '{test_db_name}'...")
147        try:
148            _drop_database_on_server(test_conn, test_db_name)
149        except Exception as e:
150            _log(f"Got an error destroying the test database: {e}")