1from __future__ import annotations
   2
   3import _thread
   4import datetime
   5import logging
   6import os
   7import signal
   8import subprocess
   9import sys
  10import threading
  11import time
  12import warnings
  13import zoneinfo
  14from collections import deque
  15from collections.abc import Generator, Sequence
  16from contextlib import contextmanager
  17from functools import cached_property, lru_cache
  18from typing import TYPE_CHECKING, Any, LiteralString, NamedTuple, cast
  19
  20import psycopg as Database
  21from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors
  22from psycopg import sql as psycopg_sql
  23from psycopg.abc import Buffer, PyFormat
  24from psycopg.postgres import types as pg_types
  25from psycopg.pq import Format
  26from psycopg.types.datetime import TimestamptzLoader
  27from psycopg.types.range import BaseRangeDumper, Range, RangeDumper
  28from psycopg.types.string import TextLoader
  29
  30from plain.exceptions import ImproperlyConfigured
  31from plain.models.db import (
  32    DatabaseError,
  33    DatabaseErrorWrapper,
  34    NotSupportedError,
  35    db_connection,
  36)
  37from plain.models.indexes import Index
  38from plain.models.postgres import utils
  39from plain.models.postgres.schema import DatabaseSchemaEditor
  40from plain.models.postgres.sql import MAX_NAME_LENGTH, quote_name
  41from plain.models.postgres.utils import CursorDebugWrapper as BaseCursorDebugWrapper
  42from plain.models.postgres.utils import CursorWrapper, debug_transaction
  43from plain.models.transaction import TransactionManagementError
  44from plain.runtime import settings
  45
  46if TYPE_CHECKING:
  47    from psycopg import Connection as PsycopgConnection
  48
  49    from plain.models.connections import DatabaseConfig
  50    from plain.models.fields import Field
  51
  52RAN_DB_VERSION_CHECK = False
  53
  54logger = logging.getLogger("plain.models.postgres")
  55
  56# The prefix to put on the default database name when creating
  57# the test database.
  58TEST_DATABASE_PREFIX = "test_"
  59
  60
  61def get_migratable_models() -> Generator[Any, None, None]:
  62    """Return all models that should be included in migrations."""
  63    from plain.models import models_registry
  64    from plain.packages import packages_registry
  65
  66    return (
  67        model
  68        for package_config in packages_registry.get_package_configs()
  69        for model in models_registry.get_models(
  70            package_label=package_config.package_label
  71        )
  72    )
  73
  74
  75class TableInfo(NamedTuple):
  76    """Structure returned by DatabaseWrapper.get_table_list()."""
  77
  78    name: str
  79    type: str
  80    comment: str | None
  81
  82
# Type OIDs
# PostgreSQL type OIDs, resolved once at import time from psycopg's type
# registries. TIMESTAMPTZ_OID is used to look up the connection's
# timestamptz loader in create_cursor(); the two range OIDs drive
# PlainRangeDumper's tsrange -> tstzrange retagging.
TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
TSRANGE_OID = pg_types["tsrange"].oid
TSTZRANGE_OID = pg_types["tstzrange"].oid
  87
  88
  89class BaseTzLoader(TimestamptzLoader):
  90    """
  91    Load a PostgreSQL timestamptz using a specific timezone.
  92    The timezone can be None too, in which case it will be chopped.
  93    """
  94
  95    timezone: datetime.tzinfo | None = None
  96
  97    def load(self, data: Buffer) -> datetime.datetime:
  98        res = super().load(data)
  99        return res.replace(tzinfo=self.timezone)
 100
 101
 102def register_tzloader(tz: datetime.tzinfo | None, context: Any) -> None:
 103    class SpecificTzLoader(BaseTzLoader):
 104        timezone = tz
 105
 106    context.adapters.register_loader("timestamptz", SpecificTzLoader)
 107
 108
 109class PlainRangeDumper(RangeDumper):
 110    """A Range dumper customized for Plain."""
 111
 112    def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper:
 113        dumper = super().upgrade(obj, format)
 114        if dumper is not self and dumper.oid == TSRANGE_OID:
 115            dumper.oid = TSTZRANGE_OID
 116        return dumper
 117
 118
 119@lru_cache
 120def get_adapters_template(timezone: datetime.tzinfo | None) -> adapt.AdaptersMap:
 121    ctx = adapt.AdaptersMap(adapters)
 122    # No-op JSON loader to avoid psycopg3 round trips
 123    ctx.register_loader("jsonb", TextLoader)
 124    # Treat inet/cidr as text
 125    ctx.register_loader("inet", TextLoader)
 126    ctx.register_loader("cidr", TextLoader)
 127    ctx.register_dumper(Range, PlainRangeDumper)
 128    register_tzloader(timezone, ctx)
 129    return ctx
 130
 131
 132def _psql_settings_to_cmd_args_env(
 133    settings_dict: DatabaseConfig, parameters: list[str]
 134) -> tuple[list[str], dict[str, str] | None]:
 135    """Build psql command-line arguments from database settings."""
 136    args = ["psql"]
 137    options = settings_dict.get("OPTIONS", {})
 138
 139    if user := settings_dict.get("USER"):
 140        args += ["-U", user]
 141    if host := settings_dict.get("HOST"):
 142        args += ["-h", host]
 143    if port := settings_dict.get("PORT"):
 144        args += ["-p", str(port)]
 145    args.extend(parameters)
 146    args += [settings_dict.get("DATABASE") or "postgres"]
 147
 148    env: dict[str, str] = {}
 149    if password := settings_dict.get("PASSWORD"):
 150        env["PGPASSWORD"] = str(password)
 151
 152    # Map OPTIONS keys to their corresponding environment variables.
 153    option_env_vars = {
 154        "passfile": "PGPASSFILE",
 155        "sslmode": "PGSSLMODE",
 156        "sslrootcert": "PGSSLROOTCERT",
 157        "sslcert": "PGSSLCERT",
 158        "sslkey": "PGSSLKEY",
 159    }
 160    for option_key, env_var in option_env_vars.items():
 161        if value := options.get(option_key):
 162            env[env_var] = str(value)
 163
 164    return args, (env or None)
 165
 166
 167class DatabaseWrapper:
 168    """
 169    PostgreSQL database connection wrapper.
 170
 171    This is the only database backend supported by Plain.
 172    """
 173
 174    queries_limit: int = 9000
 175    executable_name: str = "psql"
 176
 177    index_default_access_method = "btree"
 178    ignored_tables: list[str] = []
 179
 180    # PostgreSQL backend-specific attributes.
 181    _named_cursor_idx = 0
 182
 183    def __init__(self, settings_dict: DatabaseConfig):
 184        # Connection related attributes.
 185        # The underlying database connection (from the database library, not a wrapper).
 186        self.connection: PsycopgConnection[Any] | None = None
 187        # `settings_dict` should be a dictionary containing keys such as
 188        # DATABASE, USER, etc. It's called `settings_dict` instead of `settings`
 189        # to disambiguate it from Plain settings modules.
 190        self.settings_dict: DatabaseConfig = settings_dict
 191        # Query logging in debug mode or when explicitly enabled.
 192        self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
 193        self.force_debug_cursor: bool = False
 194
 195        # Transaction related attributes.
 196        # Tracks if the connection is in autocommit mode. Per PEP 249, by
 197        # default, it isn't.
 198        self.autocommit: bool = False
 199        # Tracks if the connection is in a transaction managed by 'atomic'.
 200        self.in_atomic_block: bool = False
 201        # Increment to generate unique savepoint ids.
 202        self.savepoint_state: int = 0
 203        # List of savepoints created by 'atomic'.
 204        self.savepoint_ids: list[str] = []
 205        # Stack of active 'atomic' blocks.
 206        self.atomic_blocks: list[Any] = []
 207        # Tracks if the transaction should be rolled back to the next
 208        # available savepoint because of an exception in an inner block.
 209        self.needs_rollback: bool = False
 210        self.rollback_exc: Exception | None = None
 211
 212        # Connection termination related attributes.
 213        self.close_at: float | None = None
 214        self.closed_in_transaction: bool = False
 215        self.errors_occurred: bool = False
 216        self.health_check_enabled: bool = False
 217        self.health_check_done: bool = False
 218
 219        # Thread-safety related attributes.
 220        self._thread_sharing_lock: threading.Lock = threading.Lock()
 221        self._thread_sharing_count: int = 0
 222        self._thread_ident: int = _thread.get_ident()
 223
 224        # A list of no-argument functions to run when the transaction commits.
 225        # Each entry is an (sids, func, robust) tuple, where sids is a set of
 226        # the active savepoint IDs when this function was registered and robust
 227        # specifies whether it's allowed for the function to fail.
 228        self.run_on_commit: list[tuple[set[str], Any, bool]] = []
 229
 230        # Should we run the on-commit hooks the next time set_autocommit(True)
 231        # is called?
 232        self.run_commit_hooks_on_set_autocommit_on: bool = False
 233
 234        # A stack of wrappers to be invoked around execute()/executemany()
 235        # calls. Each entry is a function taking five arguments: execute, sql,
 236        # params, many, and context. It's the function's responsibility to
 237        # call execute(sql, params, many, context).
 238        self.execute_wrappers: list[Any] = []
 239
 240    def __repr__(self) -> str:
 241        return f"<{self.__class__.__qualname__} vendor='postgresql'>"
 242
 243    @cached_property
 244    def timezone(self) -> datetime.tzinfo:
 245        """
 246        Return a tzinfo of the database connection time zone.
 247
 248        When a datetime is read from the database, it is returned in this time
 249        zone. Since PostgreSQL supports time zones, it doesn't matter which
 250        time zone Plain uses, as long as aware datetimes are used everywhere.
 251        Other users connecting to the database can choose their own time zone.
 252        """
 253        if self.settings_dict["TIME_ZONE"] is None:
 254            return datetime.UTC
 255        return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
 256
 257    @cached_property
 258    def timezone_name(self) -> str:
 259        """
 260        Name of the time zone of the database connection.
 261        """
 262        if self.settings_dict["TIME_ZONE"] is None:
 263            return "UTC"
 264        return self.settings_dict["TIME_ZONE"]
 265
 266    @property
 267    def queries_logged(self) -> bool:
 268        return self.force_debug_cursor or settings.DEBUG
 269
 270    @property
 271    def queries(self) -> list[dict[str, Any]]:
 272        if len(self.queries_log) == self.queries_log.maxlen:
 273            warnings.warn(
 274                f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
 275                "will be returned."
 276            )
 277        return list(self.queries_log)
 278
 279    def check_database_version_supported(self) -> None:
 280        """
 281        Raise an error if the database version isn't supported by this
 282        version of Plain. PostgreSQL 12+ is required.
 283        """
 284        major, minor = divmod(self.pg_version, 10000)
 285        if major < 12:
 286            raise NotSupportedError(
 287                f"PostgreSQL 12 or later is required (found {major}.{minor})."
 288            )
 289
 290    # ##### Connection and cursor methods #####
 291
 292    def get_connection_params(self) -> dict[str, Any]:
 293        """Return a dict of parameters suitable for get_new_connection."""
 294        settings_dict = self.settings_dict
 295        options = settings_dict.get("OPTIONS", {})
 296        db_name = settings_dict.get("DATABASE")
 297        if db_name == "":
 298            raise ImproperlyConfigured(
 299                "PostgreSQL database is not configured. "
 300                "Set DATABASE_URL or the POSTGRES_DATABASE setting."
 301            )
 302        if len(db_name or "") > MAX_NAME_LENGTH:
 303            raise ImproperlyConfigured(
 304                "The database name '%s' (%d characters) is longer than "  # noqa: UP031
 305                "PostgreSQL's limit of %d characters. Supply a shorter "
 306                "POSTGRES_DATABASE setting."
 307                % (
 308                    db_name,
 309                    len(db_name or ""),
 310                    MAX_NAME_LENGTH,
 311                )
 312            )
 313        if db_name is None:
 314            # None is used to connect to the default 'postgres' db.
 315            db_name = "postgres"
 316        conn_params: dict[str, Any] = {
 317            "dbname": db_name,
 318            **options,
 319        }
 320
 321        conn_params.pop("assume_role", None)
 322        conn_params.pop("isolation_level", None)
 323        conn_params.pop("server_side_binding", None)
 324        if settings_dict["USER"]:
 325            conn_params["user"] = settings_dict["USER"]
 326        if settings_dict["PASSWORD"]:
 327            conn_params["password"] = settings_dict["PASSWORD"]
 328        if settings_dict["HOST"]:
 329            conn_params["host"] = settings_dict["HOST"]
 330        if settings_dict["PORT"]:
 331            conn_params["port"] = settings_dict["PORT"]
 332        conn_params["context"] = get_adapters_template(self.timezone)
 333        # Disable prepared statements by default to keep connection poolers
 334        # working. Can be reenabled via OPTIONS in the settings dict.
 335        conn_params["prepare_threshold"] = conn_params.pop("prepare_threshold", None)
 336        return conn_params
 337
 338    def get_new_connection(self, conn_params: dict[str, Any]) -> PsycopgConnection[Any]:
 339        """Open a connection to the database."""
 340        # self.isolation_level must be set:
 341        # - after connecting to the database in order to obtain the database's
 342        #   default when no value is explicitly specified in options.
 343        # - before calling _set_autocommit() because if autocommit is on, that
 344        #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
 345        options = self.settings_dict.get("OPTIONS", {})
 346        set_isolation_level = False
 347        try:
 348            isolation_level_value = options["isolation_level"]
 349        except KeyError:
 350            self.isolation_level = IsolationLevel.READ_COMMITTED
 351        else:
 352            # Set the isolation level to the value from OPTIONS.
 353            try:
 354                self.isolation_level = IsolationLevel(isolation_level_value)
 355                set_isolation_level = True
 356            except ValueError:
 357                raise ImproperlyConfigured(
 358                    f"Invalid transaction isolation level {isolation_level_value} "
 359                    f"specified. Use one of the psycopg.IsolationLevel values."
 360                )
 361        connection = Database.connect(**conn_params)
 362        if set_isolation_level:
 363            connection.isolation_level = self.isolation_level
 364        # Use server-side binding cursor if requested, otherwise standard cursor
 365        connection.cursor_factory = (
 366            ServerBindingCursor
 367            if options.get("server_side_binding") is True
 368            else Cursor
 369        )
 370        return connection
 371
 372    def ensure_timezone(self) -> bool:
 373        """
 374        Ensure the connection's timezone is set to `self.timezone_name` and
 375        return whether it changed or not.
 376        """
 377        if self.connection is None:
 378            return False
 379        conn_timezone_name = self.connection.info.parameter_status("TimeZone")
 380        timezone_name = self.timezone_name
 381        if timezone_name and conn_timezone_name != timezone_name:
 382            with self.connection.cursor() as cursor:
 383                cursor.execute(
 384                    "SELECT set_config('TimeZone', %s, false)", [timezone_name]
 385                )
 386            return True
 387        return False
 388
 389    def ensure_role(self) -> bool:
 390        if self.connection is None:
 391            return False
 392        if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
 393            with self.connection.cursor() as cursor:
 394                sql_str = self.compose_sql("SET ROLE %s", [new_role])
 395                cursor.execute(sql_str)  # type: ignore[arg-type]
 396            return True
 397        return False
 398
 399    def init_connection_state(self) -> None:
 400        """Initialize the database connection settings."""
 401        global RAN_DB_VERSION_CHECK
 402        if not RAN_DB_VERSION_CHECK:
 403            self.check_database_version_supported()
 404            RAN_DB_VERSION_CHECK = True
 405
 406        self.ensure_timezone()
 407        # Set the role on the connection. This is useful if the credential used
 408        # to login is not the same as the role that owns database resources. As
 409        # can be the case when using temporary or ephemeral credentials.
 410        self.ensure_role()
 411
 412    def create_cursor(self, name: str | None = None) -> Any:
 413        """Create a cursor. Assume that a connection is established."""
 414        assert self.connection is not None
 415        if name:
 416            # In autocommit mode, the cursor will be used outside of a
 417            # transaction, hence use a holdable cursor.
 418            cursor = self.connection.cursor(
 419                name, scrollable=False, withhold=self.connection.autocommit
 420            )
 421        else:
 422            cursor = self.connection.cursor()
 423
 424        # Register the cursor timezone only if the connection disagrees, to avoid copying the adapter map.
 425        tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
 426        if self.timezone != tzloader.timezone:  # type: ignore[union-attr]
 427            register_tzloader(self.timezone, cursor)
 428        return cursor
 429
 430    def chunked_cursor(self) -> utils.CursorWrapper:
 431        """
 432        Return a server-side cursor that avoids caching results in memory.
 433        """
 434        self._named_cursor_idx += 1
 435        # Get the current async task
 436        # Note that right now this is behind @async_unsafe, so this is
 437        # unreachable, but in future we'll start loosening this restriction.
 438        # For now, it's here so that every use of "threading" is
 439        # also async-compatible.
 440        task_ident = "sync"
 441        # Use that and the thread ident to get a unique name
 442        return self._cursor(
 443            name="_plain_curs_%d_%s_%d"  # noqa: UP031
 444            % (
 445                # Avoid reusing name in other threads / tasks
 446                threading.current_thread().ident,
 447                task_ident,
 448                self._named_cursor_idx,
 449            )
 450        )
 451
 452    def _set_autocommit(self, autocommit: bool) -> None:
 453        """Backend-specific implementation to enable or disable autocommit."""
 454        assert self.connection is not None
 455        with self.wrap_database_errors:
 456            self.connection.autocommit = autocommit
 457
 458    def check_constraints(self, table_names: list[str] | None = None) -> None:
 459        """
 460        Check constraints by setting them to immediate. Return them to deferred
 461        afterward.
 462        """
 463        with self.cursor() as cursor:
 464            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
 465            cursor.execute("SET CONSTRAINTS ALL DEFERRED")
 466
 467    def is_usable(self) -> bool:
 468        """
 469        Test if the database connection is usable.
 470
 471        This method may assume that self.connection is not None.
 472
 473        Actual implementations should take care not to raise exceptions
 474        as that may prevent Plain from recycling unusable connections.
 475        """
 476        assert self.connection is not None
 477        try:
 478            # Use a psycopg cursor directly, bypassing Plain's utilities.
 479            with self.connection.cursor() as cursor:
 480                cursor.execute("SELECT 1")
 481        except Database.Error:
 482            return False
 483        else:
 484            return True
 485
    @contextmanager
    def _nodb_cursor(self) -> Generator[utils.CursorWrapper, None, None]:
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.

        The alternative connection is a new wrapper built from this one's
        settings with DATABASE set to None, which get_connection_params() maps
        to the 'postgres' database. If that connection cannot be opened, a
        RuntimeWarning is emitted and a second wrapper pointed at
        ``db_connection``'s DATABASE is used instead. Whichever temporary
        wrapper is used is closed when the context exits.
        """
        # Stays None until a cursor has been handed to the caller. It is used
        # below to tell "could not connect" apart from errors raised by the
        # caller's own code inside the `with` block.
        cursor = None
        try:
            conn = self.__class__({**self.settings_dict, "DATABASE": None})
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
        except (Database.DatabaseError, DatabaseError):
            # A cursor was already yielded, so the failure came from the
            # caller's work: propagate it instead of retrying elsewhere.
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Plain will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Plain was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            # Fallback: connect to the default connection's configured database.
            conn = self.__class__(
                {
                    **self.settings_dict,
                    "DATABASE": db_connection.settings_dict["DATABASE"],
                },
            )
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
 525
 526    @cached_property
 527    def pg_version(self) -> int:
 528        with self.temporary_connection():
 529            assert self.connection is not None
 530            return self.connection.info.server_version
 531
 532    def make_debug_cursor(self, cursor: Any) -> CursorDebugWrapper:
 533        return CursorDebugWrapper(cursor, self)
 534
 535    # ##### Connection lifecycle #####
 536
    def connect(self) -> None:
        """
        Connect to the database. Assume that the connection is closed.

        Resets per-connection transaction and lifetime state, opens a new
        psycopg connection, turns autocommit on, runs init_connection_state()
        (version check, time zone, role), and finally clears pending on-commit
        callbacks.
        """
        # In case the previous connection was closed while in an atomic block
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.atomic_blocks = []
        self.needs_rollback = False
        # Reset parameters defining when to close/health-check the connection.
        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
        max_age = self.settings_dict["CONN_MAX_AGE"]
        # CONN_MAX_AGE=None means the connection never expires by age.
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # New connections are healthy.
        self.health_check_done = True
        # Establish the connection
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        # New connections start in autocommit mode; atomic() switches it off
        # for the duration of a transaction.
        self.set_autocommit(True)
        self.init_connection_state()

        # Drop any on-commit callbacks left over from a previous connection.
        self.run_on_commit = []
 559
 560    def ensure_connection(self) -> None:
 561        """Guarantee that a connection to the database is established."""
 562        if self.connection is None:
 563            with self.wrap_database_errors:
 564                self.connect()
 565
 566    # ##### PEP-249 connection method wrappers #####
 567
 568    def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
 569        """
 570        Validate the connection is usable and perform database cursor wrapping.
 571        """
 572        self.validate_thread_sharing()
 573        if self.queries_logged:
 574            wrapped_cursor = self.make_debug_cursor(cursor)
 575        else:
 576            wrapped_cursor = self.make_cursor(cursor)
 577        return wrapped_cursor
 578
 579    def _cursor(self, name: str | None = None) -> utils.CursorWrapper:
 580        self.close_if_health_check_failed()
 581        self.ensure_connection()
 582        with self.wrap_database_errors:
 583            return self._prepare_cursor(self.create_cursor(name))
 584
 585    def _commit(self) -> None:
 586        if self.connection is not None:
 587            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
 588                return self.connection.commit()
 589
 590    def _rollback(self) -> None:
 591        if self.connection is not None:
 592            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
 593                return self.connection.rollback()
 594
 595    def _close(self) -> None:
 596        if self.connection is not None:
 597            with self.wrap_database_errors:
 598                return self.connection.close()
 599
 600    # ##### Generic wrappers for PEP-249 connection methods #####
 601
 602    def cursor(self) -> utils.CursorWrapper:
 603        """Create a cursor, opening a connection if necessary."""
 604        return self._cursor()
 605
    def commit(self) -> None:
        """
        Commit a transaction and reset the dirty flag.

        Refused with TransactionManagementError inside an atomic block.
        On-commit hooks are not run here. They run the next time autocommit is
        switched back on via set_autocommit(True).
        """
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._commit()
        # A successful commit means that the database connection works.
        self.errors_occurred = False
        # Defer on-commit hooks until set_autocommit(True).
        self.run_commit_hooks_on_set_autocommit_on = True
 614
    def rollback(self) -> None:
        """
        Roll back a transaction and reset the dirty flag.

        Refused with TransactionManagementError inside an atomic block. Also
        clears ``needs_rollback`` and discards pending on-commit callbacks.
        """
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._rollback()
        # A successful rollback means that the database connection works.
        self.errors_occurred = False
        self.needs_rollback = False
        # Callbacks registered for the rolled-back transaction must never run.
        self.run_on_commit = []
 624
    def close(self) -> None:
        """
        Close the connection to the database.

        Pending on-commit callbacks are always discarded. If called inside an
        atomic block, the wrapper keeps its (now closed) connection object and
        marks the transaction as needing rollback. Otherwise
        ``self.connection`` is reset to None.
        """
        self.validate_thread_sharing()
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            # Runs even if closing failed, so the state stays consistent.
            if self.in_atomic_block:
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None
 643
 644    # ##### Savepoint management #####
 645
 646    def _savepoint(self, sid: str) -> None:
 647        with self.cursor() as cursor:
 648            cursor.execute(f"SAVEPOINT {quote_name(sid)}")
 649
 650    def _savepoint_rollback(self, sid: str) -> None:
 651        with self.cursor() as cursor:
 652            cursor.execute(f"ROLLBACK TO SAVEPOINT {quote_name(sid)}")
 653
 654    def _savepoint_commit(self, sid: str) -> None:
 655        with self.cursor() as cursor:
 656            cursor.execute(f"RELEASE SAVEPOINT {quote_name(sid)}")
 657
 658    # ##### Generic savepoint management methods #####
 659
 660    def savepoint(self) -> str | None:
 661        """
 662        Create a savepoint inside the current transaction. Return an
 663        identifier for the savepoint that will be used for the subsequent
 664        rollback or commit. Return None if in autocommit mode (no transaction).
 665        """
 666        if self.get_autocommit():
 667            return None
 668
 669        thread_ident = _thread.get_ident()
 670        tid = str(thread_ident).replace("-", "")
 671
 672        self.savepoint_state += 1
 673        sid = "s%s_x%d" % (tid, self.savepoint_state)  # noqa: UP031
 674
 675        self.validate_thread_sharing()
 676        self._savepoint(sid)
 677
 678        return sid
 679
 680    def savepoint_rollback(self, sid: str) -> None:
 681        """
 682        Roll back to a savepoint. Do nothing if in autocommit mode.
 683        """
 684        if self.get_autocommit():
 685            return
 686
 687        self.validate_thread_sharing()
 688        self._savepoint_rollback(sid)
 689
 690        # Remove any callbacks registered while this savepoint was active.
 691        self.run_on_commit = [
 692            (sids, func, robust)
 693            for (sids, func, robust) in self.run_on_commit
 694            if sid not in sids
 695        ]
 696
 697    def savepoint_commit(self, sid: str) -> None:
 698        """
 699        Release a savepoint. Do nothing if in autocommit mode.
 700        """
 701        if self.get_autocommit():
 702            return
 703
 704        self.validate_thread_sharing()
 705        self._savepoint_commit(sid)
 706
 707    def clean_savepoints(self) -> None:
 708        """
 709        Reset the counter used to generate unique savepoint ids in this thread.
 710        """
 711        self.savepoint_state = 0
 712
 713    # ##### Generic transaction management methods #####
 714
 715    def get_autocommit(self) -> bool:
 716        """Get the autocommit state."""
 717        self.ensure_connection()
 718        return self.autocommit
 719
    def set_autocommit(self, autocommit: bool) -> None:
        """
        Enable or disable autocommit.

        Used internally by atomic() to manage transactions. Don't call this
        directly — use atomic() instead.

        Raises TransactionManagementError inside an atomic block. Turning
        autocommit back on also runs any on-commit hooks deferred by commit().
        """
        self.validate_no_atomic_block()
        self.close_if_health_check_failed()
        self.ensure_connection()

        if autocommit:
            self._set_autocommit(autocommit)
        else:
            # Turning autocommit off opens a transaction; the switch is wrapped
            # in debug_transaction labelled "BEGIN".
            with debug_transaction(self, "BEGIN"):
                self._set_autocommit(autocommit)
        self.autocommit = autocommit

        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False
 741
 742    def get_rollback(self) -> bool:
 743        """Get the "needs rollback" flag -- for *advanced use* only."""
 744        if not self.in_atomic_block:
 745            raise TransactionManagementError(
 746                "The rollback flag doesn't work outside of an 'atomic' block."
 747            )
 748        return self.needs_rollback
 749
 750    def set_rollback(self, rollback: bool) -> None:
 751        """
 752        Set or unset the "needs rollback" flag -- for *advanced use* only.
 753        """
 754        if not self.in_atomic_block:
 755            raise TransactionManagementError(
 756                "The rollback flag doesn't work outside of an 'atomic' block."
 757            )
 758        self.needs_rollback = rollback
 759
 760    def validate_no_atomic_block(self) -> None:
 761        """Raise an error if an atomic block is active."""
 762        if self.in_atomic_block:
 763            raise TransactionManagementError(
 764                "This is forbidden when an 'atomic' block is active."
 765            )
 766
 767    def validate_no_broken_transaction(self) -> None:
 768        if self.needs_rollback:
 769            raise TransactionManagementError(
 770                "An error occurred in the current transaction. You can't "
 771                "execute queries until the end of the 'atomic' block."
 772            ) from self.rollback_exc
 773
 774    # ##### Connection termination handling #####
 775
 776    def close_if_health_check_failed(self) -> None:
 777        """Close existing connection if it fails a health check."""
 778        if (
 779            self.connection is None
 780            or not self.health_check_enabled
 781            or self.health_check_done
 782        ):
 783            return
 784
 785        if not self.is_usable():
 786            self.close()
 787        self.health_check_done = True
 788
 789    def close_if_unusable_or_obsolete(self) -> None:
 790        """
 791        Close the current connection if unrecoverable errors have occurred
 792        or if it outlived its maximum age.
 793        """
 794        if self.connection is not None:
 795            self.health_check_done = False
 796            # If autocommit was not restored (e.g. a transaction was not
 797            # properly closed), don't take chances, drop the connection.
 798            if not self.get_autocommit():
 799                self.close()
 800                return
 801
 802            # If an exception other than DataError or IntegrityError occurred
 803            # since the last commit / rollback, check if the connection works.
 804            if self.errors_occurred:
 805                if self.is_usable():
 806                    self.errors_occurred = False
 807                    self.health_check_done = True
 808                else:
 809                    self.close()
 810                    return
 811
 812            if self.close_at is not None and time.monotonic() >= self.close_at:
 813                self.close()
 814                return
 815
 816    # ##### Thread safety handling #####
 817
 818    @property
 819    def allow_thread_sharing(self) -> bool:
 820        with self._thread_sharing_lock:
 821            return self._thread_sharing_count > 0
 822
 823    def validate_thread_sharing(self) -> None:
 824        """
 825        Validate that the connection isn't accessed by another thread than the
 826        one which originally created it. Raise an exception if the validation
 827        fails.
 828        """
 829        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
 830            raise DatabaseError(
 831                "DatabaseWrapper objects created in a "
 832                "thread can only be used in that same thread. The connection "
 833                f"was created in thread id {self._thread_ident} and this is "
 834                f"thread id {_thread.get_ident()}."
 835            )
 836
 837    # ##### Miscellaneous #####
 838
 839    @cached_property
 840    def wrap_database_errors(self) -> DatabaseErrorWrapper:
 841        """
 842        Context manager and decorator that re-throws backend-specific database
 843        exceptions using Plain's common wrappers.
 844        """
 845        return DatabaseErrorWrapper(self)
 846
 847    def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
 848        """Create a cursor without debug logging."""
 849        return utils.CursorWrapper(cursor, self)
 850
 851    @contextmanager
 852    def temporary_connection(self) -> Generator[utils.CursorWrapper, None, None]:
 853        """
 854        Context manager that ensures that a connection is established, and
 855        if it opened one, closes it to avoid leaving a dangling connection.
 856        This is useful for operations outside of the request-response cycle.
 857
 858        Provide a cursor: with self.temporary_connection() as cursor: ...
 859        """
 860        must_close = self.connection is None
 861        try:
 862            with self.cursor() as cursor:
 863                yield cursor
 864        finally:
 865            if must_close:
 866                self.close()
 867
 868    def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
 869        """Return a new instance of the schema editor."""
 870        return DatabaseSchemaEditor(self, *args, **kwargs)
 871
 872    def runshell(self, parameters: list[str]) -> None:
 873        """Run an interactive psql shell."""
 874        args, env = _psql_settings_to_cmd_args_env(self.settings_dict, parameters)
 875        env = {**os.environ, **env} if env else None
 876        sigint_handler = signal.getsignal(signal.SIGINT)
 877        try:
 878            # Allow SIGINT to pass to psql to abort queries.
 879            signal.signal(signal.SIGINT, signal.SIG_IGN)
 880            subprocess.run(args, env=env, check=True)
 881        finally:
 882            # Restore the original SIGINT handler.
 883            signal.signal(signal.SIGINT, sigint_handler)
 884
 885    def on_commit(self, func: Any, robust: bool = False) -> None:
 886        if not callable(func):
 887            raise TypeError("on_commit()'s callback must be a callable.")
 888        if self.in_atomic_block:
 889            # Transaction in progress; save for execution on commit.
 890            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
 891        else:
 892            # No transaction in progress; execute immediately.
 893            if robust:
 894                try:
 895                    func()
 896                except Exception as e:
 897                    logger.error(
 898                        f"Error calling {func.__qualname__} in on_commit() (%s).",
 899                        e,
 900                        exc_info=True,
 901                    )
 902            else:
 903                func()
 904
 905    def run_and_clear_commit_hooks(self) -> None:
 906        self.validate_no_atomic_block()
 907        current_run_on_commit = self.run_on_commit
 908        self.run_on_commit = []
 909        while current_run_on_commit:
 910            _, func, robust = current_run_on_commit.pop(0)
 911            if robust:
 912                try:
 913                    func()
 914                except Exception as e:
 915                    logger.error(
 916                        f"Error calling {func.__qualname__} in on_commit() during "
 917                        f"transaction (%s).",
 918                        e,
 919                        exc_info=True,
 920                    )
 921            else:
 922                func()
 923
 924    @contextmanager
 925    def execute_wrapper(self, wrapper: Any) -> Generator[None, None, None]:
 926        """
 927        Return a context manager under which the wrapper is applied to suitable
 928        database query executions.
 929        """
 930        self.execute_wrappers.append(wrapper)
 931        try:
 932            yield
 933        finally:
 934            self.execute_wrappers.pop()
 935
 936    # ##### SQL generation methods that require connection state #####
 937
 938    def compose_sql(self, query: str, params: Any) -> str:
 939        """
 940        Compose a SQL query with parameters using psycopg's mogrify.
 941
 942        This requires an active connection because it uses the connection's
 943        cursor to properly format parameters.
 944        """
 945        assert self.connection is not None
 946        return ClientCursor(self.connection).mogrify(
 947            psycopg_sql.SQL(cast(LiteralString, query)), params
 948        )
 949
 950    def last_executed_query(
 951        self,
 952        cursor: utils.CursorWrapper,
 953        sql: str,
 954        params: Any,
 955    ) -> str | None:
 956        """
 957        Return a string of the query last executed by the given cursor, with
 958        placeholders replaced with actual values.
 959        """
 960        try:
 961            return self.compose_sql(sql, params)
 962        except errors.DataError:
 963            return None
 964
 965    def unification_cast_sql(self, output_field: Field) -> str:
 966        """
 967        Given a field instance, return the SQL that casts the result of a union
 968        to that type. The resulting string should contain a '%s' placeholder
 969        for the expression being cast.
 970        """
 971        internal_type = output_field.get_internal_type()
 972        if internal_type in (
 973            "GenericIPAddressField",
 974            "TimeField",
 975            "UUIDField",
 976        ):
 977            # PostgreSQL will resolve a union as type 'text' if input types are
 978            # 'unknown'.
 979            # https://www.postgresql.org/docs/current/typeconv-union-case.html
 980            # These fields cannot be implicitly cast back in the default
 981            # PostgreSQL configuration so we need to explicitly cast them.
 982            # We must also remove components of the type within brackets:
 983            # varchar(255) -> varchar.
 984            db_type = output_field.db_type()
 985            if db_type:
 986                return "CAST(%s AS {})".format(db_type.split("(")[0])
 987        return "%s"
 988
 989    # ##### Introspection methods #####
 990
 991    def table_names(
 992        self, cursor: CursorWrapper | None = None, include_views: bool = False
 993    ) -> list[str]:
 994        """
 995        Return a list of names of all tables that exist in the database.
 996        Sort the returned table list by Python's default sorting. Do NOT use
 997        the database's ORDER BY here to avoid subtle differences in sorting
 998        order between databases.
 999        """
1000
1001        def get_names(cursor: CursorWrapper) -> list[str]:
1002            return sorted(
1003                ti.name
1004                for ti in self.get_table_list(cursor)
1005                if include_views or ti.type == "t"
1006            )
1007
1008        if cursor is None:
1009            with self.cursor() as cursor:
1010                return get_names(cursor)
1011        return get_names(cursor)
1012
1013    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
1014        """
1015        Return an unsorted list of TableInfo named tuples of all tables and
1016        views that exist in the database.
1017        """
1018        cursor.execute(
1019            """
1020            SELECT
1021                c.relname,
1022                CASE
1023                    WHEN c.relispartition THEN 'p'
1024                    WHEN c.relkind IN ('m', 'v') THEN 'v'
1025                    ELSE 't'
1026                END,
1027                obj_description(c.oid, 'pg_class')
1028            FROM pg_catalog.pg_class c
1029            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
1030            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
1031                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
1032                AND pg_catalog.pg_table_is_visible(c.oid)
1033        """
1034        )
1035        return [
1036            TableInfo(*row)
1037            for row in cursor.fetchall()
1038            if row[0] not in self.ignored_tables
1039        ]
1040
1041    def plain_table_names(
1042        self, only_existing: bool = False, include_views: bool = True
1043    ) -> list[str]:
1044        """
1045        Return a list of all table names that have associated Plain models and
1046        are in INSTALLED_PACKAGES.
1047
1048        If only_existing is True, include only the tables in the database.
1049        """
1050        tables = set()
1051        for model in get_migratable_models():
1052            tables.add(model.model_options.db_table)
1053            tables.update(
1054                f.m2m_db_table() for f in model._model_meta.local_many_to_many
1055            )
1056        tables = list(tables)
1057        if only_existing:
1058            existing_tables = set(self.table_names(include_views=include_views))
1059            tables = [t for t in tables if t in existing_tables]
1060        return tables
1061
1062    def get_sequences(
1063        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
1064    ) -> list[dict[str, Any]]:
1065        """
1066        Return a list of introspected sequences for table_name. Each sequence
1067        is a dict: {'table': <table_name>, 'column': <column_name>, 'name': <sequence_name>}.
1068        """
1069        cursor.execute(
1070            """
1071            SELECT
1072                s.relname AS sequence_name,
1073                a.attname AS colname
1074            FROM
1075                pg_class s
1076                JOIN pg_depend d ON d.objid = s.oid
1077                    AND d.classid = 'pg_class'::regclass
1078                    AND d.refclassid = 'pg_class'::regclass
1079                JOIN pg_attribute a ON d.refobjid = a.attrelid
1080                    AND d.refobjsubid = a.attnum
1081                JOIN pg_class tbl ON tbl.oid = d.refobjid
1082                    AND tbl.relname = %s
1083                    AND pg_catalog.pg_table_is_visible(tbl.oid)
1084            WHERE
1085                s.relkind = 'S';
1086        """,
1087            [table_name],
1088        )
1089        return [
1090            {"name": row[0], "table": table_name, "column": row[1]}
1091            for row in cursor.fetchall()
1092        ]
1093
    def get_constraints(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, dict[str, Any]]:
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns. Also retrieve the definition of expression-based
        indexes.

        Only relations visible on the search path (pg_table_is_visible) are
        inspected.

        Returns a dict keyed by constraint/index name. Every entry has the keys
        "columns", "primary_key", "unique", "foreign_key", "check", "index",
        "definition" and "options". Entries that come only from the index
        query also carry "orders" and "type".
        """
        constraints: dict[str, dict[str, Any]] = {}
        # Loop over the key table, collecting things as constraints. The column
        # array must return column names in the same order in which they were
        # created.
        # For foreign keys the query also returns "<table>.<column>" for the
        # referenced column. Only the first referenced column (confkey[1]) is
        # looked up, so a multi-column foreign key reports a single target.
        cursor.execute(
            """
            SELECT
                c.conname,
                array(
                    SELECT attname
                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
                    WHERE ca.attrelid = c.conrelid
                    ORDER BY cols.arridx
                ),
                c.contype,
                (SELECT fkc.relname || '.' || fka.attname
                FROM pg_attribute AS fka
                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
                cl.reloptions
            FROM pg_constraint AS c
            JOIN pg_class AS cl ON c.conrelid = cl.oid
            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
        """,
            [table_name],
        )
        for constraint, columns, kind, used_cols, options in cursor.fetchall():
            # contype codes mapped here: 'p' primary key (also unique),
            # 'u' unique, 'f' foreign key, 'c' check.
            constraints[constraint] = {
                "columns": columns,
                "primary_key": kind == "p",
                "unique": kind in ["p", "u"],
                # NOTE(review): splitting "<table>.<column>" on the first "."
                # mis-splits a referenced table whose name contains a dot --
                # confirm whether such names need to be supported.
                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
                "check": kind == "c",
                "index": False,
                "definition": None,
                "options": options,
            }
        # Now get indexes. The first query parameter is the default access
        # method: ASC/DESC ordering (low bit of indoption set -> DESC) is only
        # derived for indexes using that method; otherwise "ordering" is NULL.
        # exprdef (pg_get_indexdef) is only filled in for expression indexes.
        cursor.execute(
            """
            SELECT
                indexname,
                array_agg(attname ORDER BY arridx),
                indisunique,
                indisprimary,
                array_agg(ordering ORDER BY arridx),
                amname,
                exprdef,
                s2.attoptions
            FROM (
                SELECT
                    c2.relname as indexname, idx.*, attr.attname, am.amname,
                    CASE
                        WHEN idx.indexprs IS NOT NULL THEN
                            pg_get_indexdef(idx.indexrelid)
                    END AS exprdef,
                    CASE am.amname
                        WHEN %s THEN
                            CASE (option & 1)
                                WHEN 1 THEN 'DESC' ELSE 'ASC'
                            END
                    END as ordering,
                    c2.reloptions as attoptions
                FROM (
                    SELECT *
                    FROM
                        pg_index i,
                        unnest(i.indkey, i.indoption)
                            WITH ORDINALITY koi(key, option, arridx)
                ) idx
                LEFT JOIN pg_class c ON idx.indrelid = c.oid
                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
                LEFT JOIN pg_am am ON c2.relam = am.oid
                LEFT JOIN
                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
            ) s2
            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
        """,
            [self.index_default_access_method, table_name],
        )
        for (
            index,
            columns,
            unique,
            primary,
            orders,
            type_,
            definition,
            options,
        ) in cursor.fetchall():
            # Names already collected from pg_constraint above are not
            # overwritten by their index row.
            if index not in constraints:
                # Default-access-method indexes without reloptions are reported
                # with the generic Index.suffix type; all others report their
                # access method name.
                basic_index = (
                    type_ == self.index_default_access_method and options is None
                )
                constraints[index] = {
                    # array_agg yields [None] when no plain column matched
                    # (presumably expression-only indexes); normalize to [].
                    "columns": columns if columns != [None] else [],
                    "orders": orders if orders != [None] else [],
                    "primary_key": primary,
                    "unique": unique,
                    "foreign_key": None,
                    "check": False,
                    "index": True,
                    "type": Index.suffix if basic_index else type_,
                    "definition": definition,
                    "options": options,
                }
        return constraints
1211
1212    # ##### Test database creation methods (merged from DatabaseCreation) #####
1213
1214    def _log(self, msg: str) -> None:
1215        sys.stderr.write(msg + os.linesep)
1216
1217    def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
1218        """
1219        Create a test database, prompting the user for confirmation if the
1220        database already exists. Return the name of the test database created.
1221
1222        If prefix is provided, it will be prepended to the database name
1223        to isolate it from other test databases.
1224        """
1225        from plain.models.cli.migrations import apply
1226
1227        test_database_name = self._get_test_db_name(prefix)
1228
1229        if verbosity >= 1:
1230            self._log(f"Creating test database '{test_database_name}'...")
1231
1232        self._create_test_db(
1233            test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
1234        )
1235
1236        self.close()
1237        settings.POSTGRES_DATABASE = test_database_name
1238        self.settings_dict["DATABASE"] = test_database_name
1239
1240        apply.callback(
1241            package_label=None,
1242            migration_name=None,
1243            fake=False,
1244            plan=False,
1245            check_unapplied=False,
1246            backup=False,
1247            no_input=True,
1248            atomic_batch=False,  # No need for atomic batch when creating test database
1249            quiet=verbosity < 2,  # Show migration output when verbosity is 2+
1250        )
1251
1252        # Ensure a connection for the side effect of initializing the test database.
1253        self.ensure_connection()
1254
1255        return test_database_name
1256
1257    def _get_test_db_name(self, prefix: str = "") -> str:
1258        """
1259        Internal implementation - return the name of the test DB that will be
1260        created. Only useful when called from create_test_db() and
1261        _create_test_db() and when no external munging is done with the 'DATABASE'
1262        settings.
1263
1264        If prefix is provided, it will be prepended to the database name.
1265        """
1266        # Determine the base name: explicit TEST.DATABASE overrides base DATABASE.
1267        base_name = (
1268            self.settings_dict["TEST"]["DATABASE"] or self.settings_dict["DATABASE"]
1269        )
1270        if prefix:
1271            return f"{prefix}_{base_name}"
1272        if self.settings_dict["TEST"]["DATABASE"]:
1273            return self.settings_dict["TEST"]["DATABASE"]
1274        name = self.settings_dict["DATABASE"]
1275        if name is None:
1276            raise ValueError("POSTGRES_DATABASE must be set")
1277        return TEST_DATABASE_PREFIX + name
1278
1279    def _get_database_create_suffix(
1280        self, encoding: str | None = None, template: str | None = None
1281    ) -> str:
1282        """Return PostgreSQL-specific CREATE DATABASE suffix."""
1283        suffix = ""
1284        if encoding:
1285            suffix += f" ENCODING '{encoding}'"
1286        if template:
1287            suffix += f" TEMPLATE {quote_name(template)}"
1288        return suffix and "WITH" + suffix
1289
1290    def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
1291        try:
1292            cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
1293        except Exception as e:
1294            cause = e.__cause__
1295            if cause and not isinstance(cause, errors.DuplicateDatabase):
1296                # All errors except "database already exists" cancel tests.
1297                self._log(f"Got an error creating the test database: {e}")
1298                sys.exit(2)
1299            else:
1300                raise
1301
1302    def _create_test_db(
1303        self, *, test_database_name: str, verbosity: int, autoclobber: bool
1304    ) -> str:
1305        """
1306        Internal implementation - create the test db tables.
1307        """
1308        test_db_params = {
1309            "dbname": quote_name(test_database_name),
1310            "suffix": self.sql_table_creation_suffix(),
1311        }
1312        # Create the test database and connect to it.
1313        with self._nodb_cursor() as cursor:
1314            try:
1315                self._execute_create_test_db(cursor, test_db_params)
1316            except Exception as e:
1317                self._log(f"Got an error creating the test database: {e}")
1318                if not autoclobber:
1319                    confirm = input(
1320                        "Type 'yes' if you would like to try deleting the test "
1321                        f"database '{test_database_name}', or 'no' to cancel: "
1322                    )
1323                if autoclobber or confirm == "yes":
1324                    try:
1325                        if verbosity >= 1:
1326                            self._log(
1327                                f"Destroying old test database '{test_database_name}'..."
1328                            )
1329                        cursor.execute(
1330                            "DROP DATABASE {dbname}".format(**test_db_params)
1331                        )
1332                        self._execute_create_test_db(cursor, test_db_params)
1333                    except Exception as e:
1334                        self._log(f"Got an error recreating the test database: {e}")
1335                        sys.exit(2)
1336                else:
1337                    self._log("Tests cancelled.")
1338                    sys.exit(1)
1339
1340        return test_database_name
1341
1342    def destroy_test_db(
1343        self, old_database_name: str | None = None, verbosity: int = 1
1344    ) -> None:
1345        """
1346        Destroy a test database, prompting the user for confirmation if the
1347        database already exists.
1348        """
1349        self.close()
1350
1351        test_database_name = self.settings_dict["DATABASE"]
1352        if test_database_name is None:
1353            raise ValueError("Test POSTGRES_DATABASE must be set")
1354
1355        if verbosity >= 1:
1356            self._log(f"Destroying test database '{test_database_name}'...")
1357        self._destroy_test_db(test_database_name, verbosity)
1358
1359        # Restore the original database name
1360        if old_database_name is not None:
1361            settings.POSTGRES_DATABASE = old_database_name
1362            self.settings_dict["DATABASE"] = old_database_name
1363
1364    def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
1365        """
1366        Internal implementation - remove the test db tables.
1367        """
1368        # Remove the test database to clean up after
1369        # ourselves. Connect to the previous database (not the test database)
1370        # to do so, because it's not allowed to delete a database while being
1371        # connected to it.
1372        with self._nodb_cursor() as cursor:
1373            cursor.execute(f"DROP DATABASE {quote_name(test_database_name)}")
1374
1375    def sql_table_creation_suffix(self) -> str:
1376        """
1377        SQL to append to the end of the test table creation statements.
1378        """
1379        test_settings = self.settings_dict["TEST"]
1380        return self._get_database_create_suffix(
1381            encoding=test_settings.get("CHARSET"),
1382            template=test_settings.get("TEMPLATE"),
1383        )
1384
1385
class CursorMixin:
    """
    Mixin for psycopg cursor classes that adds a DB-API style ``callproc``.
    """

    def callproc(
        self, name: str | psycopg_sql.Identifier, args: list[Any] | None = None
    ) -> list[Any] | None:
        """
        Call the database function ``name`` as
        ``SELECT * FROM name(arg1,arg2,...)``.

        A plain string ``name`` is quoted as an identifier. Each argument is
        embedded as an SQL literal. Returns ``args`` unchanged.
        """
        func_name = (
            name
            if isinstance(name, psycopg_sql.Identifier)
            else psycopg_sql.Identifier(name)
        )

        pieces: list[psycopg_sql.Composable] = [
            psycopg_sql.SQL("SELECT * FROM "),
            func_name,
            psycopg_sql.SQL("("),
        ]
        if args:
            # Interleave literals with "," separators (no trailing comma).
            for position, value in enumerate(args):
                if position:
                    pieces.append(psycopg_sql.SQL(","))
                pieces.append(psycopg_sql.Literal(value))
        pieces.append(psycopg_sql.SQL(")"))

        self.execute(psycopg_sql.Composed(pieces))  # type: ignore[attr-defined]
        return args
1412
1413
class ServerBindingCursor(CursorMixin, Database.Cursor):
    """psycopg ``Cursor`` (server-side parameter binding) with ``callproc``."""
1416
1417
class Cursor(CursorMixin, Database.ClientCursor):
    """psycopg ``ClientCursor`` (client-side parameter binding) with ``callproc``."""
1420
1421
class CursorDebugWrapper(BaseCursorDebugWrapper):
    """Debug cursor wrapper that also routes COPY statements through debug_sql()."""

    def copy(self, statement: Any) -> Any:
        """
        Return the underlying cursor's ``copy(statement)`` result, created
        inside ``debug_sql(statement)``.
        """
        # NOTE(review): debug_sql() only spans creating the copy object; if
        # the driver runs COPY when that object is entered, the data transfer
        # falls outside this scope -- confirm this is intended.
        with self.debug_sql(statement):
            copy_handle = self.cursor.copy(statement)
            return copy_handle