1from __future__ import annotations
   2
   3import _thread
   4import datetime
   5import logging
   6import os
   7import signal
   8import subprocess
   9import sys
  10import threading
  11import time
  12import warnings
  13import zoneinfo
  14from collections import deque
  15from collections.abc import Generator, Sequence
  16from contextlib import contextmanager
  17from functools import cached_property, lru_cache
  18from typing import TYPE_CHECKING, Any, LiteralString, NamedTuple, cast
  19
  20import psycopg as Database
  21from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors
  22from psycopg import sql as psycopg_sql
  23from psycopg.abc import Buffer, PyFormat
  24from psycopg.postgres import types as pg_types
  25from psycopg.pq import Format
  26from psycopg.types.datetime import TimestamptzLoader
  27from psycopg.types.range import BaseRangeDumper, Range, RangeDumper
  28from psycopg.types.string import TextLoader
  29
  30from plain.exceptions import ImproperlyConfigured
  31from plain.models.db import (
  32    DatabaseError,
  33    DatabaseErrorWrapper,
  34    NotSupportedError,
  35    db_connection,
  36)
  37from plain.models.db import DatabaseError as WrappedDatabaseError
  38from plain.models.indexes import Index
  39from plain.models.postgres import utils
  40from plain.models.postgres.schema import DatabaseSchemaEditor
  41from plain.models.postgres.sql import MAX_NAME_LENGTH, quote_name
  42from plain.models.postgres.utils import CursorDebugWrapper as BaseCursorDebugWrapper
  43from plain.models.postgres.utils import CursorWrapper, debug_transaction
  44from plain.models.transaction import TransactionManagementError
  45from plain.runtime import settings
  46
  47if TYPE_CHECKING:
  48    from psycopg import Connection as PsycopgConnection
  49
  50    from plain.models.connections import DatabaseConfig
  51    from plain.models.fields import Field
  52
  53RAN_DB_VERSION_CHECK = False
  54
  55logger = logging.getLogger("plain.models.postgres")
  56
  57# The prefix to put on the default database name when creating
  58# the test database.
  59TEST_DATABASE_PREFIX = "test_"
  60
  61
  62def get_migratable_models() -> Generator[Any, None, None]:
  63    """Return all models that should be included in migrations."""
  64    from plain.models import models_registry
  65    from plain.packages import packages_registry
  66
  67    return (
  68        model
  69        for package_config in packages_registry.get_package_configs()
  70        for model in models_registry.get_models(
  71            package_label=package_config.package_label
  72        )
  73    )
  74
  75
class TableInfo(NamedTuple):
    """Structure returned by DatabaseWrapper.get_table_list()."""

    # Table name.
    name: str
    # Relation kind marker (presumably table vs. view — confirm in get_table_list).
    type: str
    # Comment attached to the table, or None when there is none.
    comment: str | None
  83
  84# Type OIDs
  85TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
  86TSRANGE_OID = pg_types["tsrange"].oid
  87TSTZRANGE_OID = pg_types["tstzrange"].oid
  88
  89
class BaseTzLoader(TimestamptzLoader):
    """
    Load a PostgreSQL timestamptz using a specific timezone.
    The timezone can be None too, in which case it will be chopped.
    """

    # tzinfo applied to every loaded value; None strips the tzinfo entirely.
    timezone: datetime.tzinfo | None = None

    def load(self, data: Buffer) -> datetime.datetime:
        """Parse `data` with psycopg's loader, then force our timezone onto it."""
        res = super().load(data)
        # replace() swaps the tzinfo without converting the wall-clock value.
        return res.replace(tzinfo=self.timezone)
 101
 102
 103def register_tzloader(tz: datetime.tzinfo | None, context: Any) -> None:
 104    class SpecificTzLoader(BaseTzLoader):
 105        timezone = tz
 106
 107    context.adapters.register_loader("timestamptz", SpecificTzLoader)
 108
 109
class PlainRangeDumper(RangeDumper):
    """A Range dumper customized for Plain."""

    def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper:
        """Upgrade the dumper for `obj`, mapping tsrange OIDs to tstzrange.

        Plain works with aware datetimes, so ranges of datetimes should be
        sent as tstzrange rather than the naive tsrange psycopg picks.
        """
        dumper = super().upgrade(obj, format)
        # `dumper is self` means no specialization happened; leave it untouched.
        if dumper is not self and dumper.oid == TSRANGE_OID:
            dumper.oid = TSTZRANGE_OID
        return dumper
 118
 119
@lru_cache
def get_adapters_template(timezone: datetime.tzinfo | None) -> adapt.AdaptersMap:
    """Build (and cache, one per timezone) the adapters map for new connections."""
    ctx = adapt.AdaptersMap(adapters)
    # No-op JSON loader to avoid psycopg3 round trips
    ctx.register_loader("jsonb", TextLoader)
    # Treat inet/cidr as text
    ctx.register_loader("inet", TextLoader)
    ctx.register_loader("cidr", TextLoader)
    # Dump Python Range objects with the tstzrange-aware dumper above.
    ctx.register_dumper(Range, PlainRangeDumper)
    # Load timestamptz values using the connection's timezone.
    register_tzloader(timezone, ctx)
    return ctx
 131
 132
 133def _psql_settings_to_cmd_args_env(
 134    settings_dict: DatabaseConfig, parameters: list[str]
 135) -> tuple[list[str], dict[str, str] | None]:
 136    """Build psql command-line arguments from database settings."""
 137    args = ["psql"]
 138    options = settings_dict.get("OPTIONS", {})
 139
 140    host = settings_dict.get("HOST")
 141    port = settings_dict.get("PORT")
 142    dbname = settings_dict.get("NAME")
 143    user = settings_dict.get("USER")
 144    passwd = settings_dict.get("PASSWORD")
 145    passfile = options.get("passfile")
 146    service = options.get("service")
 147    sslmode = options.get("sslmode")
 148    sslrootcert = options.get("sslrootcert")
 149    sslcert = options.get("sslcert")
 150    sslkey = options.get("sslkey")
 151
 152    if not dbname and not service:
 153        # Connect to the default 'postgres' db.
 154        dbname = "postgres"
 155    if user:
 156        args += ["-U", user]
 157    if host:
 158        args += ["-h", host]
 159    if port:
 160        args += ["-p", str(port)]
 161    args.extend(parameters)
 162    if dbname:
 163        args += [dbname]
 164
 165    env = {}
 166    if passwd:
 167        env["PGPASSWORD"] = str(passwd)
 168    if service:
 169        env["PGSERVICE"] = str(service)
 170    if sslmode:
 171        env["PGSSLMODE"] = str(sslmode)
 172    if sslrootcert:
 173        env["PGSSLROOTCERT"] = str(sslrootcert)
 174    if sslcert:
 175        env["PGSSLCERT"] = str(sslcert)
 176    if sslkey:
 177        env["PGSSLKEY"] = str(sslkey)
 178    if passfile:
 179        env["PGPASSFILE"] = str(passfile)
 180    return args, (env or None)
 181
 182
class DatabaseWrapper:
    """
    PostgreSQL database connection wrapper.

    This is the only database backend supported by Plain.
    """

    # Maximum number of entries kept in `queries_log` (deque maxlen).
    queries_limit: int = 9000
    # Name of the command-line client binary.
    executable_name: str = "psql"

    # Default PostgreSQL index access method.
    index_default_access_method = "btree"
    # Tables this backend should skip; empty by default.
    ignored_tables: list[str] = []

    # PostgreSQL backend-specific attributes.
    # Monotonic counter used to build unique server-side cursor names.
    _named_cursor_idx = 0
 198
    def __init__(self, settings_dict: DatabaseConfig):
        """Initialize connection, transaction, and bookkeeping state.

        No database connection is opened here; connect() does that lazily.
        """
        # Connection related attributes.
        # The underlying database connection (from the database library, not a wrapper).
        self.connection: PsycopgConnection[Any] | None = None
        # `settings_dict` should be a dictionary containing keys such as
        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
        # to disambiguate it from Plain settings modules.
        self.settings_dict: DatabaseConfig = settings_dict
        # Query logging in debug mode or when explicitly enabled.
        self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
        self.force_debug_cursor: bool = False

        # Transaction related attributes.
        # Tracks if the connection is in autocommit mode. Per PEP 249, by
        # default, it isn't.
        self.autocommit: bool = False
        # Tracks if the connection is in a transaction managed by 'atomic'.
        self.in_atomic_block: bool = False
        # Increment to generate unique savepoint ids.
        self.savepoint_state: int = 0
        # List of savepoints created by 'atomic'.
        self.savepoint_ids: list[str] = []
        # Stack of active 'atomic' blocks.
        self.atomic_blocks: list[Any] = []
        # Tracks if the outermost 'atomic' block should commit on exit,
        # ie. if autocommit was active on entry.
        self.commit_on_exit: bool = True
        # Tracks if the transaction should be rolled back to the next
        # available savepoint because of an exception in an inner block.
        self.needs_rollback: bool = False
        self.rollback_exc: Exception | None = None

        # Connection termination related attributes.
        self.close_at: float | None = None
        self.closed_in_transaction: bool = False
        self.errors_occurred: bool = False
        self.health_check_enabled: bool = False
        self.health_check_done: bool = False

        # Thread-safety related attributes.
        self._thread_sharing_lock: threading.Lock = threading.Lock()
        self._thread_sharing_count: int = 0
        self._thread_ident: int = _thread.get_ident()

        # A list of no-argument functions to run when the transaction commits.
        # Each entry is an (sids, func, robust) tuple, where sids is a set of
        # the active savepoint IDs when this function was registered and robust
        # specifies whether it's allowed for the function to fail.
        self.run_on_commit: list[tuple[set[str], Any, bool]] = []

        # Should we run the on-commit hooks the next time set_autocommit(True)
        # is called?
        self.run_commit_hooks_on_set_autocommit_on: bool = False

        # A stack of wrappers to be invoked around execute()/executemany()
        # calls. Each entry is a function taking five arguments: execute, sql,
        # params, many, and context. It's the function's responsibility to
        # call execute(sql, params, many, context).
        self.execute_wrappers: list[Any] = []
 258
 259    def __repr__(self) -> str:
 260        return f"<{self.__class__.__qualname__} vendor='postgresql'>"
 261
 262    @cached_property
 263    def timezone(self) -> datetime.tzinfo:
 264        """
 265        Return a tzinfo of the database connection time zone.
 266
 267        When a datetime is read from the database, it is returned in this time
 268        zone. Since PostgreSQL supports time zones, it doesn't matter which
 269        time zone Plain uses, as long as aware datetimes are used everywhere.
 270        Other users connecting to the database can choose their own time zone.
 271        """
 272        if self.settings_dict["TIME_ZONE"] is None:
 273            return datetime.UTC
 274        else:
 275            return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
 276
 277    @cached_property
 278    def timezone_name(self) -> str:
 279        """
 280        Name of the time zone of the database connection.
 281        """
 282        if self.settings_dict["TIME_ZONE"] is None:
 283            return "UTC"
 284        else:
 285            return self.settings_dict["TIME_ZONE"]
 286
 287    @property
 288    def queries_logged(self) -> bool:
 289        return self.force_debug_cursor or settings.DEBUG
 290
 291    @property
 292    def queries(self) -> list[dict[str, Any]]:
 293        if len(self.queries_log) == self.queries_log.maxlen:
 294            warnings.warn(
 295                f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
 296                "will be returned."
 297            )
 298        return list(self.queries_log)
 299
 300    def check_database_version_supported(self) -> None:
 301        """
 302        Raise an error if the database version isn't supported by this
 303        version of Plain. PostgreSQL 12+ is required.
 304        """
 305        major, minor = divmod(self.pg_version, 10000)
 306        if major < 12:
 307            raise NotSupportedError(
 308                f"PostgreSQL 12 or later is required (found {major}.{minor})."
 309            )
 310
 311    # ##### Connection and cursor methods #####
 312
 313    def get_connection_params(self) -> dict[str, Any]:
 314        """Return a dict of parameters suitable for get_new_connection."""
 315        settings_dict = self.settings_dict
 316        options = settings_dict.get("OPTIONS", {})
 317        # None may be used to connect to the default 'postgres' db
 318        if settings_dict.get("NAME") == "" and not options.get("service"):
 319            raise ImproperlyConfigured(
 320                "settings.DATABASE is improperly configured. "
 321                "Please supply the NAME or OPTIONS['service'] value."
 322            )
 323        db_name = settings_dict.get("NAME")
 324        if len(db_name or "") > MAX_NAME_LENGTH:
 325            raise ImproperlyConfigured(
 326                "The database name '%s' (%d characters) is longer than "  # noqa: UP031
 327                "PostgreSQL's limit of %d characters. Supply a shorter NAME "
 328                "in settings.DATABASE."
 329                % (
 330                    db_name,
 331                    len(db_name or ""),
 332                    MAX_NAME_LENGTH,
 333                )
 334            )
 335        conn_params: dict[str, Any] = {"client_encoding": "UTF8"}
 336        if db_name:
 337            conn_params = {
 338                "dbname": db_name,
 339                **options,
 340            }
 341        elif db_name is None:
 342            # Connect to the default 'postgres' db.
 343            options.pop("service", None)
 344            conn_params = {"dbname": "postgres", **options}
 345        else:
 346            conn_params = {**options}
 347
 348        conn_params.pop("assume_role", None)
 349        conn_params.pop("isolation_level", None)
 350        conn_params.pop("server_side_binding", None)
 351        if settings_dict["USER"]:
 352            conn_params["user"] = settings_dict["USER"]
 353        if settings_dict["PASSWORD"]:
 354            conn_params["password"] = settings_dict["PASSWORD"]
 355        if settings_dict["HOST"]:
 356            conn_params["host"] = settings_dict["HOST"]
 357        if settings_dict["PORT"]:
 358            conn_params["port"] = settings_dict["PORT"]
 359        conn_params["context"] = get_adapters_template(self.timezone)
 360        # Disable prepared statements by default to keep connection poolers
 361        # working. Can be reenabled via OPTIONS in the settings dict.
 362        conn_params["prepare_threshold"] = conn_params.pop("prepare_threshold", None)
 363        return conn_params
 364
    def get_new_connection(self, conn_params: dict[str, Any]) -> PsycopgConnection[Any]:
        """Open a connection to the database.

        Raises ImproperlyConfigured if OPTIONS['isolation_level'] is not a
        valid psycopg IsolationLevel value.
        """
        # self.isolation_level must be set:
        # - after connecting to the database in order to obtain the database's
        #   default when no value is explicitly specified in options.
        # - before calling _set_autocommit() because if autocommit is on, that
        #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
        options = self.settings_dict.get("OPTIONS", {})
        set_isolation_level = False
        try:
            isolation_level_value = options["isolation_level"]
        except KeyError:
            # No explicit choice: track READ COMMITTED but don't push it to
            # the server (set_isolation_level stays False).
            self.isolation_level = IsolationLevel.READ_COMMITTED
        else:
            # Set the isolation level to the value from OPTIONS.
            try:
                self.isolation_level = IsolationLevel(isolation_level_value)
                set_isolation_level = True
            except ValueError:
                raise ImproperlyConfigured(
                    f"Invalid transaction isolation level {isolation_level_value} "
                    f"specified. Use one of the psycopg.IsolationLevel values."
                )
        connection = Database.connect(**conn_params)
        if set_isolation_level:
            connection.isolation_level = self.isolation_level
        # Use server-side binding cursor if requested, otherwise standard cursor
        connection.cursor_factory = (
            ServerBindingCursor
            if options.get("server_side_binding") is True
            else Cursor
        )
        return connection
 398
 399    def ensure_timezone(self) -> bool:
 400        """
 401        Ensure the connection's timezone is set to `self.timezone_name` and
 402        return whether it changed or not.
 403        """
 404        if self.connection is None:
 405            return False
 406        conn_timezone_name = self.connection.info.parameter_status("TimeZone")
 407        timezone_name = self.timezone_name
 408        if timezone_name and conn_timezone_name != timezone_name:
 409            with self.connection.cursor() as cursor:
 410                cursor.execute(
 411                    "SELECT set_config('TimeZone', %s, false)", [timezone_name]
 412                )
 413            return True
 414        return False
 415
 416    def ensure_role(self) -> bool:
 417        if self.connection is None:
 418            return False
 419        if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
 420            with self.connection.cursor() as cursor:
 421                sql_str = self.compose_sql("SET ROLE %s", [new_role])
 422                cursor.execute(sql_str)  # type: ignore[arg-type]
 423            return True
 424        return False
 425
    def init_connection_state(self) -> None:
        """Initialize the database connection settings."""
        # The version check hits the server, so only do it once per process.
        global RAN_DB_VERSION_CHECK
        if not RAN_DB_VERSION_CHECK:
            self.check_database_version_supported()
            RAN_DB_VERSION_CHECK = True

        # Commit after setting the time zone.
        commit_tz = self.ensure_timezone()
        # Set the role on the connection. This is useful if the credential used
        # to login is not the same as the role that owns database resources. As
        # can be the case when using temporary or ephemeral credentials.
        commit_role = self.ensure_role()

        # Outside autocommit, the SET statements above opened a transaction
        # that must be committed so the settings stick.
        if (commit_role or commit_tz) and not self.get_autocommit():
            assert self.connection is not None
            self.connection.commit()
 443
    def create_cursor(self, name: str | None = None) -> Any:
        """Create a cursor. Assume that a connection is established.

        A non-None `name` creates a server-side (named) cursor.
        """
        assert self.connection is not None
        if name:
            # In autocommit mode, the cursor will be used outside of a
            # transaction, hence use a holdable cursor.
            cursor = self.connection.cursor(
                name, scrollable=False, withhold=self.connection.autocommit
            )
        else:
            cursor = self.connection.cursor()

        # Register the cursor timezone only if the connection disagrees, to avoid copying the adapter map.
        tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
        if self.timezone != tzloader.timezone:  # type: ignore[union-attr]
            register_tzloader(self.timezone, cursor)
        return cursor
 461
 462    def chunked_cursor(self) -> utils.CursorWrapper:
 463        """
 464        Return a server-side cursor that avoids caching results in memory.
 465        """
 466        self._named_cursor_idx += 1
 467        # Get the current async task
 468        # Note that right now this is behind @async_unsafe, so this is
 469        # unreachable, but in future we'll start loosening this restriction.
 470        # For now, it's here so that every use of "threading" is
 471        # also async-compatible.
 472        task_ident = "sync"
 473        # Use that and the thread ident to get a unique name
 474        return self._cursor(
 475            name="_plain_curs_%d_%s_%d"  # noqa: UP031
 476            % (
 477                # Avoid reusing name in other threads / tasks
 478                threading.current_thread().ident,
 479                task_ident,
 480                self._named_cursor_idx,
 481            )
 482        )
 483
    def _set_autocommit(self, autocommit: bool) -> None:
        """Backend-specific implementation to enable or disable autocommit."""
        assert self.connection is not None
        # Wrap so psycopg errors surface as Plain database errors.
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit
 489
    def check_constraints(self, table_names: list[str] | None = None) -> None:
        """
        Check constraints by setting them to immediate. Return them to deferred
        afterward.

        Note: `table_names` is accepted for interface compatibility but the
        statements below apply to ALL deferred constraints.
        """
        with self.cursor() as cursor:
            # IMMEDIATE forces any pending deferred constraint checks to run now.
            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
            cursor.execute("SET CONSTRAINTS ALL DEFERRED")
 498
 499    def is_usable(self) -> bool:
 500        """
 501        Test if the database connection is usable.
 502
 503        This method may assume that self.connection is not None.
 504
 505        Actual implementations should take care not to raise exceptions
 506        as that may prevent Plain from recycling unusable connections.
 507        """
 508        assert self.connection is not None
 509        try:
 510            # Use a psycopg cursor directly, bypassing Plain's utilities.
 511            with self.connection.cursor() as cursor:
 512                cursor.execute("SELECT 1")
 513        except Database.Error:
 514            return False
 515        else:
 516            return True
 517
    @contextmanager
    def _nodb_cursor(self) -> Generator[utils.CursorWrapper, None, None]:
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.
        """
        cursor = None
        try:
            # NAME=None makes get_connection_params() fall back to the
            # default 'postgres' database.
            conn = self.__class__({**self.settings_dict, "NAME": None})
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
        except (Database.DatabaseError, WrappedDatabaseError):
            # `cursor` being set means the failure happened in the caller's
            # block, not while connecting -- propagate it untouched.
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Plain will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Plain was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            # Fall back to the configured database for the default connection.
            conn = self.__class__(
                {
                    **self.settings_dict,
                    "NAME": db_connection.settings_dict["NAME"],
                },
            )
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
 557
    @cached_property
    def pg_version(self) -> int:
        """Server version as an integer, major * 10000 + minor (e.g. 140002)."""
        with self.temporary_connection():
            assert self.connection is not None
            return self.connection.info.server_version
 563
    def make_debug_cursor(self, cursor: Any) -> CursorDebugWrapper:
        """Wrap `cursor` in a CursorDebugWrapper for query logging."""
        return CursorDebugWrapper(cursor, self)
 566
 567    # ##### Connection lifecycle #####
 568
    def connect(self) -> None:
        """Connect to the database. Assume that the connection is closed."""
        # In case the previous connection was closed while in an atomic block
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.atomic_blocks = []
        self.needs_rollback = False
        # Reset parameters defining when to close/health-check the connection.
        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
        max_age = self.settings_dict["CONN_MAX_AGE"]
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # New connections are healthy.
        self.health_check_done = True
        # Establish the connection
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        # Apply the configured autocommit mode before initializing state.
        self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
        self.init_connection_state()

        # Start with a clean on-commit callback list.
        self.run_on_commit = []
 591
 592    def ensure_connection(self) -> None:
 593        """Guarantee that a connection to the database is established."""
 594        if self.connection is None:
 595            with self.wrap_database_errors:
 596                self.connect()
 597
 598    # ##### PEP-249 connection method wrappers #####
 599
 600    def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
 601        """
 602        Validate the connection is usable and perform database cursor wrapping.
 603        """
 604        self.validate_thread_sharing()
 605        if self.queries_logged:
 606            wrapped_cursor = self.make_debug_cursor(cursor)
 607        else:
 608            wrapped_cursor = self.make_cursor(cursor)
 609        return wrapped_cursor
 610
    def _cursor(self, name: str | None = None) -> utils.CursorWrapper:
        """Open the connection if needed and return a wrapped cursor.

        A non-None `name` requests a server-side (named) cursor.
        """
        self.close_if_health_check_failed()
        self.ensure_connection()
        with self.wrap_database_errors:
            return self._prepare_cursor(self.create_cursor(name))
 616
 617    def _commit(self) -> None:
 618        if self.connection is not None:
 619            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
 620                return self.connection.commit()
 621
 622    def _rollback(self) -> None:
 623        if self.connection is not None:
 624            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
 625                return self.connection.rollback()
 626
 627    def _close(self) -> None:
 628        if self.connection is not None:
 629            with self.wrap_database_errors:
 630                return self.connection.close()
 631
 632    # ##### Generic wrappers for PEP-249 connection methods #####
 633
    def cursor(self) -> utils.CursorWrapper:
        """Create a cursor, opening a connection if necessary."""
        # Delegates to _cursor(), which runs health checks and wraps the cursor.
        return self._cursor()
 637
    def commit(self) -> None:
        """Commit a transaction and reset the dirty flag.

        Raises TransactionManagementError inside an 'atomic' block.
        """
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._commit()
        # A successful commit means that the database connection works.
        self.errors_occurred = False
        # Run registered on-commit hooks the next time autocommit turns on.
        self.run_commit_hooks_on_set_autocommit_on = True
 646
    def rollback(self) -> None:
        """Roll back a transaction and reset the dirty flag.

        Raises TransactionManagementError inside an 'atomic' block.
        """
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._rollback()
        # A successful rollback means that the database connection works.
        self.errors_occurred = False
        self.needs_rollback = False
        # Discard on-commit callbacks; their transaction is gone.
        self.run_on_commit = []
 656
    def close(self) -> None:
        """Close the connection to the database."""
        self.validate_thread_sharing()
        # Pending on-commit callbacks can never run once the connection closes.
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            # Closing mid-atomic poisons the block: remember it so the outer
            # 'atomic' rolls back instead of committing.
            if self.in_atomic_block:
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None
 675
 676    # ##### Savepoint management #####
 677
 678    def _savepoint(self, sid: str) -> None:
 679        with self.cursor() as cursor:
 680            cursor.execute(f"SAVEPOINT {quote_name(sid)}")
 681
 682    def _savepoint_rollback(self, sid: str) -> None:
 683        with self.cursor() as cursor:
 684            cursor.execute(f"ROLLBACK TO SAVEPOINT {quote_name(sid)}")
 685
 686    def _savepoint_commit(self, sid: str) -> None:
 687        with self.cursor() as cursor:
 688            cursor.execute(f"RELEASE SAVEPOINT {quote_name(sid)}")
 689
 690    # ##### Generic savepoint management methods #####
 691
 692    def savepoint(self) -> str | None:
 693        """
 694        Create a savepoint inside the current transaction. Return an
 695        identifier for the savepoint that will be used for the subsequent
 696        rollback or commit. Return None if in autocommit mode (no transaction).
 697        """
 698        if self.get_autocommit():
 699            return None
 700
 701        thread_ident = _thread.get_ident()
 702        tid = str(thread_ident).replace("-", "")
 703
 704        self.savepoint_state += 1
 705        sid = "s%s_x%d" % (tid, self.savepoint_state)  # noqa: UP031
 706
 707        self.validate_thread_sharing()
 708        self._savepoint(sid)
 709
 710        return sid
 711
 712    def savepoint_rollback(self, sid: str) -> None:
 713        """
 714        Roll back to a savepoint. Do nothing if in autocommit mode.
 715        """
 716        if self.get_autocommit():
 717            return
 718
 719        self.validate_thread_sharing()
 720        self._savepoint_rollback(sid)
 721
 722        # Remove any callbacks registered while this savepoint was active.
 723        self.run_on_commit = [
 724            (sids, func, robust)
 725            for (sids, func, robust) in self.run_on_commit
 726            if sid not in sids
 727        ]
 728
 729    def savepoint_commit(self, sid: str) -> None:
 730        """
 731        Release a savepoint. Do nothing if in autocommit mode.
 732        """
 733        if self.get_autocommit():
 734            return
 735
 736        self.validate_thread_sharing()
 737        self._savepoint_commit(sid)
 738
    def clean_savepoints(self) -> None:
        """
        Reset the counter used to generate unique savepoint ids in this thread.
        """
        # savepoint() derives its ids from this counter.
        self.savepoint_state = 0
 744
 745    # ##### Generic transaction management methods #####
 746
    def get_autocommit(self) -> bool:
        """Get the autocommit state."""
        # Connect first so the flag reflects a live connection.
        self.ensure_connection()
        return self.autocommit
 751
    def set_autocommit(self, autocommit: bool) -> None:
        """Enable or disable autocommit.

        Raises TransactionManagementError inside an 'atomic' block.
        """
        self.validate_no_atomic_block()
        self.close_if_health_check_failed()
        self.ensure_connection()

        if autocommit:
            self._set_autocommit(autocommit)
        else:
            # Turning autocommit off implicitly begins a transaction; wrap in
            # debug_transaction so the BEGIN is visible to query logging.
            with debug_transaction(self, "BEGIN"):
                self._set_autocommit(autocommit)
        self.autocommit = autocommit

        # Hooks registered by commit() run once autocommit is back on.
        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False
 768
 769    def get_rollback(self) -> bool:
 770        """Get the "needs rollback" flag -- for *advanced use* only."""
 771        if not self.in_atomic_block:
 772            raise TransactionManagementError(
 773                "The rollback flag doesn't work outside of an 'atomic' block."
 774            )
 775        return self.needs_rollback
 776
 777    def set_rollback(self, rollback: bool) -> None:
 778        """
 779        Set or unset the "needs rollback" flag -- for *advanced use* only.
 780        """
 781        if not self.in_atomic_block:
 782            raise TransactionManagementError(
 783                "The rollback flag doesn't work outside of an 'atomic' block."
 784            )
 785        self.needs_rollback = rollback
 786
 787    def validate_no_atomic_block(self) -> None:
 788        """Raise an error if an atomic block is active."""
 789        if self.in_atomic_block:
 790            raise TransactionManagementError(
 791                "This is forbidden when an 'atomic' block is active."
 792            )
 793
 794    def validate_no_broken_transaction(self) -> None:
 795        if self.needs_rollback:
 796            raise TransactionManagementError(
 797                "An error occurred in the current transaction. You can't "
 798                "execute queries until the end of the 'atomic' block."
 799            ) from self.rollback_exc
 800
 801    # ##### Connection termination handling #####
 802
 803    def close_if_health_check_failed(self) -> None:
 804        """Close existing connection if it fails a health check."""
 805        if (
 806            self.connection is None
 807            or not self.health_check_enabled
 808            or self.health_check_done
 809        ):
 810            return
 811
 812        if not self.is_usable():
 813            self.close()
 814        self.health_check_done = True
 815
 816    def close_if_unusable_or_obsolete(self) -> None:
 817        """
 818        Close the current connection if unrecoverable errors have occurred
 819        or if it outlived its maximum age.
 820        """
 821        if self.connection is not None:
 822            self.health_check_done = False
 823            # If the application didn't restore the original autocommit setting,
 824            # don't take chances, drop the connection.
 825            if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
 826                self.close()
 827                return
 828
 829            # If an exception other than DataError or IntegrityError occurred
 830            # since the last commit / rollback, check if the connection works.
 831            if self.errors_occurred:
 832                if self.is_usable():
 833                    self.errors_occurred = False
 834                    self.health_check_done = True
 835                else:
 836                    self.close()
 837                    return
 838
 839            if self.close_at is not None and time.monotonic() >= self.close_at:
 840                self.close()
 841                return
 842
 843    # ##### Thread safety handling #####
 844
 845    @property
 846    def allow_thread_sharing(self) -> bool:
 847        with self._thread_sharing_lock:
 848            return self._thread_sharing_count > 0
 849
 850    def validate_thread_sharing(self) -> None:
 851        """
 852        Validate that the connection isn't accessed by another thread than the
 853        one which originally created it. Raise an exception if the validation
 854        fails.
 855        """
 856        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
 857            raise DatabaseError(
 858                "DatabaseWrapper objects created in a "
 859                "thread can only be used in that same thread. The connection "
 860                f"was created in thread id {self._thread_ident} and this is "
 861                f"thread id {_thread.get_ident()}."
 862            )
 863
 864    # ##### Miscellaneous #####
 865
 866    @cached_property
 867    def wrap_database_errors(self) -> DatabaseErrorWrapper:
 868        """
 869        Context manager and decorator that re-throws backend-specific database
 870        exceptions using Plain's common wrappers.
 871        """
 872        return DatabaseErrorWrapper(self)
 873
 874    def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
 875        """Create a cursor without debug logging."""
 876        return utils.CursorWrapper(cursor, self)
 877
 878    @contextmanager
 879    def temporary_connection(self) -> Generator[utils.CursorWrapper, None, None]:
 880        """
 881        Context manager that ensures that a connection is established, and
 882        if it opened one, closes it to avoid leaving a dangling connection.
 883        This is useful for operations outside of the request-response cycle.
 884
 885        Provide a cursor: with self.temporary_connection() as cursor: ...
 886        """
 887        must_close = self.connection is None
 888        try:
 889            with self.cursor() as cursor:
 890                yield cursor
 891        finally:
 892            if must_close:
 893                self.close()
 894
 895    def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
 896        """Return a new instance of the schema editor."""
 897        return DatabaseSchemaEditor(self, *args, **kwargs)
 898
 899    def runshell(self, parameters: list[str]) -> None:
 900        """Run an interactive psql shell."""
 901        args, env = _psql_settings_to_cmd_args_env(self.settings_dict, parameters)
 902        env = {**os.environ, **env} if env else None
 903        sigint_handler = signal.getsignal(signal.SIGINT)
 904        try:
 905            # Allow SIGINT to pass to psql to abort queries.
 906            signal.signal(signal.SIGINT, signal.SIG_IGN)
 907            subprocess.run(args, env=env, check=True)
 908        finally:
 909            # Restore the original SIGINT handler.
 910            signal.signal(signal.SIGINT, sigint_handler)
 911
 912    def on_commit(self, func: Any, robust: bool = False) -> None:
 913        if not callable(func):
 914            raise TypeError("on_commit()'s callback must be a callable.")
 915        if self.in_atomic_block:
 916            # Transaction in progress; save for execution on commit.
 917            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
 918        elif not self.get_autocommit():
 919            raise TransactionManagementError(
 920                "on_commit() cannot be used in manual transaction management"
 921            )
 922        else:
 923            # No transaction in progress and in autocommit mode; execute
 924            # immediately.
 925            if robust:
 926                try:
 927                    func()
 928                except Exception as e:
 929                    logger.error(
 930                        f"Error calling {func.__qualname__} in on_commit() (%s).",
 931                        e,
 932                        exc_info=True,
 933                    )
 934            else:
 935                func()
 936
 937    def run_and_clear_commit_hooks(self) -> None:
 938        self.validate_no_atomic_block()
 939        current_run_on_commit = self.run_on_commit
 940        self.run_on_commit = []
 941        while current_run_on_commit:
 942            _, func, robust = current_run_on_commit.pop(0)
 943            if robust:
 944                try:
 945                    func()
 946                except Exception as e:
 947                    logger.error(
 948                        f"Error calling {func.__qualname__} in on_commit() during "
 949                        f"transaction (%s).",
 950                        e,
 951                        exc_info=True,
 952                    )
 953            else:
 954                func()
 955
 956    @contextmanager
 957    def execute_wrapper(self, wrapper: Any) -> Generator[None, None, None]:
 958        """
 959        Return a context manager under which the wrapper is applied to suitable
 960        database query executions.
 961        """
 962        self.execute_wrappers.append(wrapper)
 963        try:
 964            yield
 965        finally:
 966            self.execute_wrappers.pop()
 967
 968    # ##### SQL generation methods that require connection state #####
 969
 970    def compose_sql(self, query: str, params: Any) -> str:
 971        """
 972        Compose a SQL query with parameters using psycopg's mogrify.
 973
 974        This requires an active connection because it uses the connection's
 975        cursor to properly format parameters.
 976        """
 977        assert self.connection is not None
 978        return ClientCursor(self.connection).mogrify(
 979            psycopg_sql.SQL(cast(LiteralString, query)), params
 980        )
 981
 982    def last_executed_query(
 983        self,
 984        cursor: utils.CursorWrapper,
 985        sql: str,
 986        params: Any,
 987    ) -> str | None:
 988        """
 989        Return a string of the query last executed by the given cursor, with
 990        placeholders replaced with actual values.
 991        """
 992        try:
 993            return self.compose_sql(sql, params)
 994        except errors.DataError:
 995            return None
 996
 997    def unification_cast_sql(self, output_field: Field) -> str:
 998        """
 999        Given a field instance, return the SQL that casts the result of a union
1000        to that type. The resulting string should contain a '%s' placeholder
1001        for the expression being cast.
1002        """
1003        internal_type = output_field.get_internal_type()
1004        if internal_type in (
1005            "GenericIPAddressField",
1006            "TimeField",
1007            "UUIDField",
1008        ):
1009            # PostgreSQL will resolve a union as type 'text' if input types are
1010            # 'unknown'.
1011            # https://www.postgresql.org/docs/current/typeconv-union-case.html
1012            # These fields cannot be implicitly cast back in the default
1013            # PostgreSQL configuration so we need to explicitly cast them.
1014            # We must also remove components of the type within brackets:
1015            # varchar(255) -> varchar.
1016            db_type = output_field.db_type()
1017            if db_type:
1018                return "CAST(%s AS {})".format(db_type.split("(")[0])
1019        return "%s"
1020
1021    # ##### Introspection methods #####
1022
1023    def table_names(
1024        self, cursor: CursorWrapper | None = None, include_views: bool = False
1025    ) -> list[str]:
1026        """
1027        Return a list of names of all tables that exist in the database.
1028        Sort the returned table list by Python's default sorting. Do NOT use
1029        the database's ORDER BY here to avoid subtle differences in sorting
1030        order between databases.
1031        """
1032
1033        def get_names(cursor: CursorWrapper) -> list[str]:
1034            return sorted(
1035                ti.name
1036                for ti in self.get_table_list(cursor)
1037                if include_views or ti.type == "t"
1038            )
1039
1040        if cursor is None:
1041            with self.cursor() as cursor:
1042                return get_names(cursor)
1043        return get_names(cursor)
1044
1045    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
1046        """
1047        Return an unsorted list of TableInfo named tuples of all tables and
1048        views that exist in the database.
1049        """
1050        cursor.execute(
1051            """
1052            SELECT
1053                c.relname,
1054                CASE
1055                    WHEN c.relispartition THEN 'p'
1056                    WHEN c.relkind IN ('m', 'v') THEN 'v'
1057                    ELSE 't'
1058                END,
1059                obj_description(c.oid, 'pg_class')
1060            FROM pg_catalog.pg_class c
1061            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
1062            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
1063                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
1064                AND pg_catalog.pg_table_is_visible(c.oid)
1065        """
1066        )
1067        return [
1068            TableInfo(*row)
1069            for row in cursor.fetchall()
1070            if row[0] not in self.ignored_tables
1071        ]
1072
1073    def plain_table_names(
1074        self, only_existing: bool = False, include_views: bool = True
1075    ) -> list[str]:
1076        """
1077        Return a list of all table names that have associated Plain models and
1078        are in INSTALLED_PACKAGES.
1079
1080        If only_existing is True, include only the tables in the database.
1081        """
1082        tables = set()
1083        for model in get_migratable_models():
1084            tables.add(model.model_options.db_table)
1085            tables.update(
1086                f.m2m_db_table() for f in model._model_meta.local_many_to_many
1087            )
1088        tables = list(tables)
1089        if only_existing:
1090            existing_tables = set(self.table_names(include_views=include_views))
1091            tables = [t for t in tables if t in existing_tables]
1092        return tables
1093
1094    def get_sequences(
1095        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
1096    ) -> list[dict[str, Any]]:
1097        """
1098        Return a list of introspected sequences for table_name. Each sequence
1099        is a dict: {'table': <table_name>, 'column': <column_name>, 'name': <sequence_name>}.
1100        """
1101        cursor.execute(
1102            """
1103            SELECT
1104                s.relname AS sequence_name,
1105                a.attname AS colname
1106            FROM
1107                pg_class s
1108                JOIN pg_depend d ON d.objid = s.oid
1109                    AND d.classid = 'pg_class'::regclass
1110                    AND d.refclassid = 'pg_class'::regclass
1111                JOIN pg_attribute a ON d.refobjid = a.attrelid
1112                    AND d.refobjsubid = a.attnum
1113                JOIN pg_class tbl ON tbl.oid = d.refobjid
1114                    AND tbl.relname = %s
1115                    AND pg_catalog.pg_table_is_visible(tbl.oid)
1116            WHERE
1117                s.relkind = 'S';
1118        """,
1119            [table_name],
1120        )
1121        return [
1122            {"name": row[0], "table": table_name, "column": row[1]}
1123            for row in cursor.fetchall()
1124        ]
1125
    def get_constraints(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, dict[str, Any]]:
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns. Also retrieve the definition of expression-based
        indexes.

        Returns a mapping of constraint/index name to a dict describing it
        (columns, primary_key, unique, foreign_key, check, index, orders,
        type, definition, options).
        """
        constraints: dict[str, dict[str, Any]] = {}
        # Loop over the key table, collecting things as constraints. The column
        # array must return column names in the same order in which they were
        # created.
        # The WITH ORDINALITY subquery preserves the conkey column order.
        cursor.execute(
            """
            SELECT
                c.conname,
                array(
                    SELECT attname
                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
                    WHERE ca.attrelid = c.conrelid
                    ORDER BY cols.arridx
                ),
                c.contype,
                (SELECT fkc.relname || '.' || fka.attname
                FROM pg_attribute AS fka
                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
                cl.reloptions
            FROM pg_constraint AS c
            JOIN pg_class AS cl ON c.conrelid = cl.oid
            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
        """,
            [table_name],
        )
        # contype codes: 'p' primary key, 'u' unique, 'f' foreign key,
        # 'c' check constraint.
        for constraint, columns, kind, used_cols, options in cursor.fetchall():
            constraints[constraint] = {
                "columns": columns,
                "primary_key": kind == "p",
                "unique": kind in ["p", "u"],
                # used_cols is 'reftable.refcolumn' for foreign keys.
                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
                "check": kind == "c",
                "index": False,
                "definition": None,
                "options": options,
            }
        # Now get indexes
        # First %s: the default index access method, so ASC/DESC ordering is
        # only derived for indexes of that method. Second %s: the table name.
        cursor.execute(
            """
            SELECT
                indexname,
                array_agg(attname ORDER BY arridx),
                indisunique,
                indisprimary,
                array_agg(ordering ORDER BY arridx),
                amname,
                exprdef,
                s2.attoptions
            FROM (
                SELECT
                    c2.relname as indexname, idx.*, attr.attname, am.amname,
                    CASE
                        WHEN idx.indexprs IS NOT NULL THEN
                            pg_get_indexdef(idx.indexrelid)
                    END AS exprdef,
                    CASE am.amname
                        WHEN %s THEN
                            CASE (option & 1)
                                WHEN 1 THEN 'DESC' ELSE 'ASC'
                            END
                    END as ordering,
                    c2.reloptions as attoptions
                FROM (
                    SELECT *
                    FROM
                        pg_index i,
                        unnest(i.indkey, i.indoption)
                            WITH ORDINALITY koi(key, option, arridx)
                ) idx
                LEFT JOIN pg_class c ON idx.indrelid = c.oid
                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
                LEFT JOIN pg_am am ON c2.relam = am.oid
                LEFT JOIN
                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
            ) s2
            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
        """,
            [self.index_default_access_method, table_name],
        )
        for (
            index,
            columns,
            unique,
            primary,
            orders,
            type_,
            definition,
            options,
        ) in cursor.fetchall():
            # Skip indexes already reported as constraints above.
            if index not in constraints:
                # A "basic" index uses the default access method with no
                # storage options; report it with the generic Index suffix.
                basic_index = (
                    type_ == self.index_default_access_method and options is None
                )
                constraints[index] = {
                    # Expression indexes yield [None] for columns/orders.
                    "columns": columns if columns != [None] else [],
                    "orders": orders if orders != [None] else [],
                    "primary_key": primary,
                    "unique": unique,
                    "foreign_key": None,
                    "check": False,
                    "index": True,
                    "type": Index.suffix if basic_index else type_,
                    "definition": definition,
                    "options": options,
                }
        return constraints
1243
1244    # ##### Test database creation methods (merged from DatabaseCreation) #####
1245
1246    def _log(self, msg: str) -> None:
1247        sys.stderr.write(msg + os.linesep)
1248
    def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
        """
        Create a test database, prompting the user for confirmation if the
        database already exists. Return the name of the test database created.

        If prefix is provided, it will be prepended to the database name
        to isolate it from other test databases.
        """
        # Imported here rather than at module level, presumably to avoid an
        # import cycle -- NOTE(review): confirm before hoisting.
        from plain.models.cli.migrations import apply

        test_database_name = self._get_test_db_name(prefix)

        if verbosity >= 1:
            self._log(f"Creating test database '{test_database_name}'...")

        # autoclobber=True: drop and recreate an existing test database
        # without prompting.
        self._create_test_db(
            test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
        )

        # Repoint both the global settings and this connection's settings at
        # the new database before running migrations against it.
        self.close()
        settings.DATABASE["NAME"] = test_database_name
        self.settings_dict["NAME"] = test_database_name

        # Apply all migrations to the freshly created database.
        apply.callback(
            package_label=None,
            migration_name=None,
            fake=False,
            plan=False,
            check_unapplied=False,
            backup=False,
            no_input=True,
            atomic_batch=False,  # No need for atomic batch when creating test database
            quiet=verbosity < 2,  # Show migration output when verbosity is 2+
        )

        # Ensure a connection for the side effect of initializing the test database.
        self.ensure_connection()

        return test_database_name
1288
1289    def _get_test_db_name(self, prefix: str = "") -> str:
1290        """
1291        Internal implementation - return the name of the test DB that will be
1292        created. Only useful when called from create_test_db() and
1293        _create_test_db() and when no external munging is done with the 'NAME'
1294        settings.
1295
1296        If prefix is provided, it will be prepended to the database name.
1297        """
1298        # Determine the base name: explicit TEST.NAME overrides base NAME.
1299        base_name = self.settings_dict["TEST"]["NAME"] or self.settings_dict["NAME"]
1300        if prefix:
1301            return f"{prefix}_{base_name}"
1302        if self.settings_dict["TEST"]["NAME"]:
1303            return self.settings_dict["TEST"]["NAME"]
1304        name = self.settings_dict["NAME"]
1305        assert name is not None, "DATABASE NAME must be set"
1306        return TEST_DATABASE_PREFIX + name
1307
1308    def _get_database_create_suffix(
1309        self, encoding: str | None = None, template: str | None = None
1310    ) -> str:
1311        """Return PostgreSQL-specific CREATE DATABASE suffix."""
1312        suffix = ""
1313        if encoding:
1314            suffix += f" ENCODING '{encoding}'"
1315        if template:
1316            suffix += f" TEMPLATE {quote_name(template)}"
1317        return suffix and "WITH" + suffix
1318
1319    def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
1320        try:
1321            cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
1322        except Exception as e:
1323            cause = e.__cause__
1324            if cause and not isinstance(cause, errors.DuplicateDatabase):
1325                # All errors except "database already exists" cancel tests.
1326                self._log(f"Got an error creating the test database: {e}")
1327                sys.exit(2)
1328            else:
1329                raise
1330
1331    def _create_test_db(
1332        self, *, test_database_name: str, verbosity: int, autoclobber: bool
1333    ) -> str:
1334        """
1335        Internal implementation - create the test db tables.
1336        """
1337        test_db_params = {
1338            "dbname": quote_name(test_database_name),
1339            "suffix": self.sql_table_creation_suffix(),
1340        }
1341        # Create the test database and connect to it.
1342        with self._nodb_cursor() as cursor:
1343            try:
1344                self._execute_create_test_db(cursor, test_db_params)
1345            except Exception as e:
1346                self._log(f"Got an error creating the test database: {e}")
1347                if not autoclobber:
1348                    confirm = input(
1349                        "Type 'yes' if you would like to try deleting the test "
1350                        f"database '{test_database_name}', or 'no' to cancel: "
1351                    )
1352                if autoclobber or confirm == "yes":
1353                    try:
1354                        if verbosity >= 1:
1355                            self._log(
1356                                f"Destroying old test database '{test_database_name}'..."
1357                            )
1358                        cursor.execute(
1359                            "DROP DATABASE {dbname}".format(**test_db_params)
1360                        )
1361                        self._execute_create_test_db(cursor, test_db_params)
1362                    except Exception as e:
1363                        self._log(f"Got an error recreating the test database: {e}")
1364                        sys.exit(2)
1365                else:
1366                    self._log("Tests cancelled.")
1367                    sys.exit(1)
1368
1369        return test_database_name
1370
1371    def destroy_test_db(
1372        self, old_database_name: str | None = None, verbosity: int = 1
1373    ) -> None:
1374        """
1375        Destroy a test database, prompting the user for confirmation if the
1376        database already exists.
1377        """
1378        self.close()
1379
1380        test_database_name = self.settings_dict["NAME"]
1381        assert test_database_name is not None, "Test database NAME must be set"
1382
1383        if verbosity >= 1:
1384            self._log(f"Destroying test database '{test_database_name}'...")
1385        self._destroy_test_db(test_database_name, verbosity)
1386
1387        # Restore the original database name
1388        if old_database_name is not None:
1389            settings.DATABASE["NAME"] = old_database_name
1390            self.settings_dict["NAME"] = old_database_name
1391
1392    def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
1393        """
1394        Internal implementation - remove the test db tables.
1395        """
1396        # Remove the test database to clean up after
1397        # ourselves. Connect to the previous database (not the test database)
1398        # to do so, because it's not allowed to delete a database while being
1399        # connected to it.
1400        with self._nodb_cursor() as cursor:
1401            cursor.execute(f"DROP DATABASE {quote_name(test_database_name)}")
1402
1403    def sql_table_creation_suffix(self) -> str:
1404        """
1405        SQL to append to the end of the test table creation statements.
1406        """
1407        test_settings = self.settings_dict["TEST"]
1408        return self._get_database_create_suffix(
1409            encoding=test_settings.get("CHARSET"),
1410            template=test_settings.get("TEMPLATE"),
1411        )
1412
1413
class CursorMixin:
    """
    A subclass of psycopg cursor implementing callproc.
    """

    def callproc(
        self, name: str | psycopg_sql.Identifier, args: list[Any] | None = None
    ) -> list[Any] | None:
        """
        Execute ``SELECT * FROM name(args)`` and return ``args`` unchanged.
        """
        if not isinstance(name, psycopg_sql.Identifier):
            name = psycopg_sql.Identifier(name)

        parts: list[psycopg_sql.Composable] = [
            psycopg_sql.SQL("SELECT * FROM "),
            name,
            psycopg_sql.SQL("("),
        ]
        if args:
            literals = [psycopg_sql.Literal(item) for item in args]
            # Interleave literal arguments with comma separators.
            for literal in literals[:-1]:
                parts.append(literal)
                parts.append(psycopg_sql.SQL(","))
            parts.append(literals[-1])
        parts.append(psycopg_sql.SQL(")"))

        self.execute(psycopg_sql.Composed(parts))  # type: ignore[attr-defined]
        return args
1440
1441
class ServerBindingCursor(CursorMixin, Database.Cursor):
    # psycopg's default cursor (server-side parameter binding) extended with
    # the callproc() support from CursorMixin.
    pass
1444
1445
class Cursor(CursorMixin, Database.ClientCursor):
    # psycopg's ClientCursor (client-side parameter binding) extended with
    # the callproc() support from CursorMixin.
    pass
1448
1449
class CursorDebugWrapper(BaseCursorDebugWrapper):
    # Extends the base debug wrapper so COPY operations are also logged.
    def copy(self, statement: Any) -> Any:
        """Run cursor.copy() under debug SQL logging."""
        with self.debug_sql(statement):
            return self.cursor.copy(statement)