1from __future__ import annotations
   2
   3import _thread
   4import datetime
   5import logging
   6import os
   7import signal
   8import subprocess
   9import sys
  10import time
  11import warnings
  12import zoneinfo
  13from collections import deque
  14from collections.abc import Generator, Sequence
  15from contextlib import contextmanager
  16from functools import cached_property, lru_cache
  17from typing import TYPE_CHECKING, Any, LiteralString, NamedTuple, cast
  18
  19import psycopg as Database
  20from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors
  21from psycopg import sql as psycopg_sql
  22from psycopg.abc import Buffer, PyFormat
  23from psycopg.postgres import types as pg_types
  24from psycopg.pq import Format
  25from psycopg.types.datetime import TimestamptzLoader
  26from psycopg.types.range import BaseRangeDumper, Range, RangeDumper
  27from psycopg.types.string import TextLoader
  28
  29from plain.exceptions import ImproperlyConfigured
  30from plain.models.db import (
  31    DatabaseError,
  32    DatabaseErrorWrapper,
  33)
  34from plain.models.indexes import Index
  35from plain.models.postgres import utils
  36from plain.models.postgres.schema import DatabaseSchemaEditor
  37from plain.models.postgres.sql import MAX_NAME_LENGTH, quote_name
  38from plain.models.postgres.utils import CursorDebugWrapper as BaseCursorDebugWrapper
  39from plain.models.postgres.utils import CursorWrapper, debug_transaction
  40from plain.models.transaction import TransactionManagementError
  41from plain.runtime import settings
  42
  43if TYPE_CHECKING:
  44    from psycopg import Connection as PsycopgConnection
  45
  46    from plain.models.connections import DatabaseConfig
  47    from plain.models.fields import Field
  48
# Module-level logger; the name is spelled out explicitly rather than derived
# from __name__.
logger = logging.getLogger("plain.models.postgres")

# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = "test_"
  54
  55
  56def get_migratable_models() -> Generator[Any, None, None]:
  57    """Return all models that should be included in migrations."""
  58    from plain.models import models_registry
  59    from plain.packages import packages_registry
  60
  61    return (
  62        model
  63        for package_config in packages_registry.get_package_configs()
  64        for model in models_registry.get_models(
  65            package_label=package_config.package_label
  66        )
  67    )
  68
  69
  70class TableInfo(NamedTuple):
  71    """Structure returned by DatabaseConnection.get_table_list()."""
  72
  73    name: str
  74    type: str
  75    comment: str | None
  76
  77
# Type OIDs
# Looked up by PostgreSQL type name at import time. TIMESTAMPTZ_OID is used to
# fetch the timestamptz loader registered on a connection; the two range OIDs
# are used by PlainRangeDumper to retag tsrange dumpers as tstzrange.
TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
TSRANGE_OID = pg_types["tsrange"].oid
TSTZRANGE_OID = pg_types["tstzrange"].oid
  82
  83
  84class BaseTzLoader(TimestamptzLoader):
  85    """
  86    Load a PostgreSQL timestamptz using a specific timezone.
  87    The timezone can be None too, in which case it will be chopped.
  88    """
  89
  90    timezone: datetime.tzinfo | None = None
  91
  92    def load(self, data: Buffer) -> datetime.datetime:
  93        res = super().load(data)
  94        return res.replace(tzinfo=self.timezone)
  95
  96
  97def register_tzloader(tz: datetime.tzinfo | None, context: Any) -> None:
  98    class SpecificTzLoader(BaseTzLoader):
  99        timezone = tz
 100
 101    context.adapters.register_loader("timestamptz", SpecificTzLoader)
 102
 103
 104class PlainRangeDumper(RangeDumper):
 105    """A Range dumper customized for Plain."""
 106
 107    def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper:
 108        dumper = super().upgrade(obj, format)
 109        if dumper is not self and dumper.oid == TSRANGE_OID:
 110            dumper.oid = TSTZRANGE_OID
 111        return dumper
 112
 113
 114@lru_cache
 115def get_adapters_template(timezone: datetime.tzinfo | None) -> adapt.AdaptersMap:
 116    ctx = adapt.AdaptersMap(adapters)
 117    # No-op JSON loader to avoid psycopg3 round trips
 118    ctx.register_loader("jsonb", TextLoader)
 119    # Treat inet/cidr as text
 120    ctx.register_loader("inet", TextLoader)
 121    ctx.register_loader("cidr", TextLoader)
 122    ctx.register_dumper(Range, PlainRangeDumper)
 123    register_tzloader(timezone, ctx)
 124    return ctx
 125
 126
 127def _psql_settings_to_cmd_args_env(
 128    settings_dict: DatabaseConfig, parameters: list[str]
 129) -> tuple[list[str], dict[str, str] | None]:
 130    """Build psql command-line arguments from database settings."""
 131    args = ["psql"]
 132    options = settings_dict.get("OPTIONS", {})
 133
 134    if user := settings_dict.get("USER"):
 135        args += ["-U", user]
 136    if host := settings_dict.get("HOST"):
 137        args += ["-h", host]
 138    if port := settings_dict.get("PORT"):
 139        args += ["-p", str(port)]
 140    args.extend(parameters)
 141    args += [settings_dict.get("DATABASE") or "postgres"]
 142
 143    env: dict[str, str] = {}
 144    if password := settings_dict.get("PASSWORD"):
 145        env["PGPASSWORD"] = str(password)
 146
 147    # Map OPTIONS keys to their corresponding environment variables.
 148    option_env_vars = {
 149        "passfile": "PGPASSFILE",
 150        "sslmode": "PGSSLMODE",
 151        "sslrootcert": "PGSSLROOTCERT",
 152        "sslcert": "PGSSLCERT",
 153        "sslkey": "PGSSLKEY",
 154    }
 155    for option_key, env_var in option_env_vars.items():
 156        if value := options.get(option_key):
 157            env[env_var] = str(value)
 158
 159    return args, (env or None)
 160
 161
 162class DatabaseConnection:
 163    """
 164    PostgreSQL database connection.
 165
 166    This is the only database backend supported by Plain.
 167    """
 168
 169    queries_limit: int = 9000
 170    executable_name: str = "psql"
 171
 172    index_default_access_method = "btree"
 173    ignored_tables: list[str] = []
 174
 175    def __init__(self, settings_dict: DatabaseConfig):
 176        # Connection related attributes.
 177        # The underlying database connection (from the database library, not a wrapper).
 178        self.connection: PsycopgConnection[Any] | None = None
 179        # `settings_dict` should be a dictionary containing keys such as
 180        # DATABASE, USER, etc. It's called `settings_dict` instead of `settings`
 181        # to disambiguate it from Plain settings modules.
 182        self.settings_dict: DatabaseConfig = settings_dict
 183        # Query logging in debug mode or when explicitly enabled.
 184        self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
 185        self.force_debug_cursor: bool = False
 186
 187        # Transaction related attributes.
 188        # Tracks if the connection is in autocommit mode. Per PEP 249, by
 189        # default, it isn't.
 190        self.autocommit: bool = False
 191        # Tracks if the connection is in a transaction managed by 'atomic'.
 192        self.in_atomic_block: bool = False
 193        # Increment to generate unique savepoint ids.
 194        self.savepoint_state: int = 0
 195        # List of savepoints created by 'atomic'.
 196        self.savepoint_ids: list[str | None] = []
 197        # Stack of active 'atomic' blocks.
 198        self.atomic_blocks: list[Any] = []
 199        # Tracks if the transaction should be rolled back to the next
 200        # available savepoint because of an exception in an inner block.
 201        self.needs_rollback: bool = False
 202        self.rollback_exc: Exception | None = None
 203
 204        # Connection termination related attributes.
 205        self.close_at: float | None = None
 206        self.closed_in_transaction: bool = False
 207        self.errors_occurred: bool = False
 208        self.health_check_enabled: bool = False
 209        self.health_check_done: bool = False
 210
 211        # A list of no-argument functions to run when the transaction commits.
 212        # Each entry is an (sids, func, robust) tuple, where sids is a set of
 213        # the active savepoint IDs when this function was registered and robust
 214        # specifies whether it's allowed for the function to fail.
 215        self.run_on_commit: list[tuple[set[str | None], Any, bool]] = []
 216
 217        # Should we run the on-commit hooks the next time set_autocommit(True)
 218        # is called?
 219        self.run_commit_hooks_on_set_autocommit_on: bool = False
 220
 221        # A stack of wrappers to be invoked around execute()/executemany()
 222        # calls. Each entry is a function taking five arguments: execute, sql,
 223        # params, many, and context. It's the function's responsibility to
 224        # call execute(sql, params, many, context).
 225        self.execute_wrappers: list[Any] = []
 226
 227    def __repr__(self) -> str:
 228        return f"<{self.__class__.__qualname__} vendor='postgresql'>"
 229
 230    @cached_property
 231    def timezone(self) -> datetime.tzinfo:
 232        """
 233        Return a tzinfo of the database connection time zone.
 234
 235        When a datetime is read from the database, it is returned in this time
 236        zone. Since PostgreSQL supports time zones, it doesn't matter which
 237        time zone Plain uses, as long as aware datetimes are used everywhere.
 238        Other users connecting to the database can choose their own time zone.
 239        """
 240        if self.settings_dict["TIME_ZONE"] is None:
 241            return datetime.UTC
 242        return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
 243
 244    @cached_property
 245    def timezone_name(self) -> str:
 246        """
 247        Name of the time zone of the database connection.
 248        """
 249        if self.settings_dict["TIME_ZONE"] is None:
 250            return "UTC"
 251        return self.settings_dict["TIME_ZONE"]
 252
 253    @property
 254    def queries_logged(self) -> bool:
 255        return self.force_debug_cursor or settings.DEBUG
 256
 257    @property
 258    def queries(self) -> list[dict[str, Any]]:
 259        if len(self.queries_log) == self.queries_log.maxlen:
 260            warnings.warn(
 261                f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
 262                "will be returned."
 263            )
 264        return list(self.queries_log)
 265
 266    # ##### Connection and cursor methods #####
 267
 268    def get_connection_params(self) -> dict[str, Any]:
 269        """Return a dict of parameters suitable for get_new_connection."""
 270        settings_dict = self.settings_dict
 271        options = settings_dict.get("OPTIONS", {})
 272        db_name = settings_dict.get("DATABASE")
 273        if db_name == "":
 274            raise ImproperlyConfigured(
 275                "PostgreSQL database is not configured. "
 276                "Set DATABASE_URL or the POSTGRES_DATABASE setting."
 277            )
 278        if len(db_name or "") > MAX_NAME_LENGTH:
 279            raise ImproperlyConfigured(
 280                "The database name '%s' (%d characters) is longer than "  # noqa: UP031
 281                "PostgreSQL's limit of %d characters. Supply a shorter "
 282                "POSTGRES_DATABASE setting."
 283                % (
 284                    db_name,
 285                    len(db_name or ""),
 286                    MAX_NAME_LENGTH,
 287                )
 288            )
 289        if db_name is None:
 290            # None is used to connect to the default 'postgres' db.
 291            db_name = "postgres"
 292        conn_params: dict[str, Any] = {
 293            "dbname": db_name,
 294            **options,
 295        }
 296
 297        conn_params.pop("assume_role", None)
 298        conn_params.pop("isolation_level", None)
 299        conn_params.pop("server_side_binding", None)
 300        if settings_dict["USER"]:
 301            conn_params["user"] = settings_dict["USER"]
 302        if settings_dict["PASSWORD"]:
 303            conn_params["password"] = settings_dict["PASSWORD"]
 304        if settings_dict["HOST"]:
 305            conn_params["host"] = settings_dict["HOST"]
 306        if settings_dict["PORT"]:
 307            conn_params["port"] = settings_dict["PORT"]
 308        conn_params["context"] = get_adapters_template(self.timezone)
 309        # Disable prepared statements by default to keep connection poolers
 310        # working. Can be reenabled via OPTIONS in the settings dict.
 311        conn_params["prepare_threshold"] = conn_params.pop("prepare_threshold", None)
 312        return conn_params
 313
 314    def get_new_connection(self, conn_params: dict[str, Any]) -> PsycopgConnection[Any]:
 315        """Open a connection to the database."""
 316        # self.isolation_level must be set:
 317        # - after connecting to the database in order to obtain the database's
 318        #   default when no value is explicitly specified in options.
 319        # - before calling _set_autocommit() because if autocommit is on, that
 320        #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
 321        options = self.settings_dict.get("OPTIONS", {})
 322        set_isolation_level = False
 323        try:
 324            isolation_level_value = options["isolation_level"]
 325        except KeyError:
 326            self.isolation_level = IsolationLevel.READ_COMMITTED
 327        else:
 328            # Set the isolation level to the value from OPTIONS.
 329            try:
 330                self.isolation_level = IsolationLevel(isolation_level_value)
 331                set_isolation_level = True
 332            except ValueError:
 333                raise ImproperlyConfigured(
 334                    f"Invalid transaction isolation level {isolation_level_value} "
 335                    f"specified. Use one of the psycopg.IsolationLevel values."
 336                )
 337        connection = Database.connect(**conn_params)
 338        if set_isolation_level:
 339            connection.isolation_level = self.isolation_level
 340        # Use server-side binding cursor if requested, otherwise standard cursor
 341        connection.cursor_factory = (
 342            ServerBindingCursor
 343            if options.get("server_side_binding") is True
 344            else Cursor
 345        )
 346        return connection
 347
 348    def ensure_timezone(self) -> bool:
 349        """
 350        Ensure the connection's timezone is set to `self.timezone_name` and
 351        return whether it changed or not.
 352        """
 353        if self.connection is None:
 354            return False
 355        conn_timezone_name = self.connection.info.parameter_status("TimeZone")
 356        timezone_name = self.timezone_name
 357        if timezone_name and conn_timezone_name != timezone_name:
 358            self.connection.execute(
 359                "SELECT set_config('TimeZone', %s, false)", [timezone_name]
 360            )
 361            return True
 362        return False
 363
 364    def ensure_role(self) -> bool:
 365        if self.connection is None:
 366            return False
 367        if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
 368            sql_str = self.compose_sql("SET ROLE %s", [new_role])
 369            self.connection.execute(sql_str)  # type: ignore[arg-type]
 370            return True
 371        return False
 372
 373    def init_connection_state(self) -> None:
 374        """Initialize the database connection settings."""
 375        self.ensure_timezone()
 376        # Set the role on the connection. This is useful if the credential used
 377        # to login is not the same as the role that owns database resources. As
 378        # can be the case when using temporary or ephemeral credentials.
 379        self.ensure_role()
 380
 381    def create_cursor(self) -> Any:
 382        """Create a cursor. Assume that a connection is established."""
 383        assert self.connection is not None
 384        cursor = self.connection.cursor()
 385
 386        # Register the cursor timezone only if the connection disagrees, to avoid copying the adapter map.
 387        tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
 388        if self.timezone != tzloader.timezone:  # type: ignore[union-attr]
 389            register_tzloader(self.timezone, cursor)
 390        return cursor
 391
 392    def _set_autocommit(self, autocommit: bool) -> None:
 393        """Backend-specific implementation to enable or disable autocommit."""
 394        assert self.connection is not None
 395        with self.wrap_database_errors:
 396            self.connection.autocommit = autocommit
 397
 398    def check_constraints(self, table_names: list[str] | None = None) -> None:
 399        """
 400        Check constraints by setting them to immediate. Return them to deferred
 401        afterward.
 402        """
 403        with self.cursor() as cursor:
 404            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
 405            cursor.execute("SET CONSTRAINTS ALL DEFERRED")
 406
 407    def is_usable(self) -> bool:
 408        """
 409        Test if the database connection is usable.
 410
 411        This method may assume that self.connection is not None.
 412
 413        Actual implementations should take care not to raise exceptions
 414        as that may prevent Plain from recycling unusable connections.
 415        """
 416        assert self.connection is not None
 417        try:
 418            # Use psycopg directly, bypassing Plain's utilities.
 419            self.connection.execute("SELECT 1")
 420        except Database.Error:
 421            return False
 422        else:
 423            return True
 424
    @contextmanager
    def _nodb_cursor(self) -> Generator[utils.CursorWrapper, None, None]:
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.

        The alternative connection is a fresh instance of this class built
        from the same settings with DATABASE set to None. If a database error
        occurs before a cursor could be obtained, a RuntimeWarning is issued
        and a second fresh instance using the unmodified settings is tried
        instead. Whichever connection is used is closed on exit.
        """
        # Stays None until a cursor has actually been obtained; the handler
        # below uses it to tell "could not connect" apart from "a statement
        # run through the cursor failed".
        cursor = None
        try:
            conn = self.__class__({**self.settings_dict, "DATABASE": None})
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
        except (Database.DatabaseError, DatabaseError):
            if cursor is not None:
                # The error happened after the cursor was handed out, so it is
                # propagated rather than retried.
                raise
            warnings.warn(
                "Normally Plain will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Plain was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            # Fallback: same class, unmodified settings.
            conn = self.__class__(self.settings_dict)
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
 459
 460    @cached_property
 461    def pg_version(self) -> int:
 462        with self.temporary_connection():
 463            assert self.connection is not None
 464            return self.connection.info.server_version
 465
 466    def make_debug_cursor(self, cursor: Any) -> CursorDebugWrapper:
 467        return CursorDebugWrapper(cursor, self)
 468
 469    # ##### Connection lifecycle #####
 470
 471    def connect(self) -> None:
 472        """Connect to the database. Assume that the connection is closed."""
 473        # In case the previous connection was closed while in an atomic block
 474        self.in_atomic_block = False
 475        self.savepoint_ids = []
 476        self.atomic_blocks = []
 477        self.needs_rollback = False
 478        # Reset parameters defining when to close/health-check the connection.
 479        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
 480        max_age = self.settings_dict["CONN_MAX_AGE"]
 481        self.close_at = None if max_age is None else time.monotonic() + max_age
 482        self.closed_in_transaction = False
 483        self.errors_occurred = False
 484        # New connections are healthy.
 485        self.health_check_done = True
 486        # Establish the connection
 487        conn_params = self.get_connection_params()
 488        self.connection = self.get_new_connection(conn_params)
 489        self.set_autocommit(True)
 490        self.init_connection_state()
 491
 492        self.run_on_commit = []
 493
 494    def ensure_connection(self) -> None:
 495        """Guarantee that a connection to the database is established."""
 496        if self.connection is None:
 497            with self.wrap_database_errors:
 498                self.connect()
 499
 500    # ##### PEP-249 connection method wrappers #####
 501
 502    def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
 503        """
 504        Validate the connection is usable and perform database cursor wrapping.
 505        """
 506        if self.queries_logged:
 507            wrapped_cursor = self.make_debug_cursor(cursor)
 508        else:
 509            wrapped_cursor = self.make_cursor(cursor)
 510        return wrapped_cursor
 511
 512    def _cursor(self) -> utils.CursorWrapper:
 513        self.close_if_health_check_failed()
 514        self.ensure_connection()
 515        with self.wrap_database_errors:
 516            return self._prepare_cursor(self.create_cursor())
 517
 518    def _commit(self) -> None:
 519        if self.connection is not None:
 520            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
 521                return self.connection.commit()
 522
 523    def _rollback(self) -> None:
 524        if self.connection is not None:
 525            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
 526                return self.connection.rollback()
 527
 528    def _close(self) -> None:
 529        if self.connection is not None:
 530            with self.wrap_database_errors:
 531                return self.connection.close()
 532
 533    # ##### Generic wrappers for PEP-249 connection methods #####
 534
 535    def cursor(self) -> utils.CursorWrapper:
 536        """Create a cursor, opening a connection if necessary."""
 537        return self._cursor()
 538
 539    def commit(self) -> None:
 540        """Commit a transaction and reset the dirty flag."""
 541        self.validate_no_atomic_block()
 542        self._commit()
 543        # A successful commit means that the database connection works.
 544        self.errors_occurred = False
 545        self.run_commit_hooks_on_set_autocommit_on = True
 546
 547    def rollback(self) -> None:
 548        """Roll back a transaction and reset the dirty flag."""
 549        self.validate_no_atomic_block()
 550        self._rollback()
 551        # A successful rollback means that the database connection works.
 552        self.errors_occurred = False
 553        self.needs_rollback = False
 554        self.run_on_commit = []
 555
 556    def close(self) -> None:
 557        """Close the connection to the database."""
 558        self.run_on_commit = []
 559
 560        # Don't call validate_no_atomic_block() to avoid making it difficult
 561        # to get rid of a connection in an invalid state. The next connect()
 562        # will reset the transaction state anyway.
 563        if self.closed_in_transaction or self.connection is None:
 564            return
 565        try:
 566            self._close()
 567        finally:
 568            if self.in_atomic_block:
 569                self.closed_in_transaction = True
 570                self.needs_rollback = True
 571            else:
 572                self.connection = None
 573
 574    # ##### Savepoint management #####
 575
 576    def _savepoint(self, sid: str) -> None:
 577        with self.cursor() as cursor:
 578            cursor.execute(f"SAVEPOINT {quote_name(sid)}")
 579
 580    def _savepoint_rollback(self, sid: str) -> None:
 581        with self.cursor() as cursor:
 582            cursor.execute(f"ROLLBACK TO SAVEPOINT {quote_name(sid)}")
 583
 584    def _savepoint_commit(self, sid: str) -> None:
 585        with self.cursor() as cursor:
 586            cursor.execute(f"RELEASE SAVEPOINT {quote_name(sid)}")
 587
 588    # ##### Generic savepoint management methods #####
 589
 590    def savepoint(self) -> str | None:
 591        """
 592        Create a savepoint inside the current transaction. Return an
 593        identifier for the savepoint that will be used for the subsequent
 594        rollback or commit. Return None if in autocommit mode (no transaction).
 595        """
 596        if self.get_autocommit():
 597            return None
 598
 599        thread_ident = _thread.get_ident()
 600        tid = str(thread_ident).replace("-", "")
 601
 602        self.savepoint_state += 1
 603        sid = "s%s_x%d" % (tid, self.savepoint_state)  # noqa: UP031
 604
 605        self._savepoint(sid)
 606
 607        return sid
 608
 609    def savepoint_rollback(self, sid: str) -> None:
 610        """
 611        Roll back to a savepoint. Do nothing if in autocommit mode.
 612        """
 613        if self.get_autocommit():
 614            return
 615
 616        self._savepoint_rollback(sid)
 617
 618        # Remove any callbacks registered while this savepoint was active.
 619        self.run_on_commit = [
 620            (sids, func, robust)
 621            for (sids, func, robust) in self.run_on_commit
 622            if sid not in sids
 623        ]
 624
 625    def savepoint_commit(self, sid: str) -> None:
 626        """
 627        Release a savepoint. Do nothing if in autocommit mode.
 628        """
 629        if self.get_autocommit():
 630            return
 631
 632        self._savepoint_commit(sid)
 633
 634    def clean_savepoints(self) -> None:
 635        """
 636        Reset the counter used to generate unique savepoint ids in this thread.
 637        """
 638        self.savepoint_state = 0
 639
 640    # ##### Generic transaction management methods #####
 641
 642    def get_autocommit(self) -> bool:
 643        """Get the autocommit state."""
 644        self.ensure_connection()
 645        return self.autocommit
 646
 647    def set_autocommit(self, autocommit: bool) -> None:
 648        """
 649        Enable or disable autocommit.
 650
 651        Used internally by atomic() to manage transactions. Don't call this
 652        directly โ€” use atomic() instead.
 653        """
 654        self.validate_no_atomic_block()
 655        self.close_if_health_check_failed()
 656        self.ensure_connection()
 657
 658        if autocommit:
 659            self._set_autocommit(autocommit)
 660        else:
 661            with debug_transaction(self, "BEGIN"):
 662                self._set_autocommit(autocommit)
 663        self.autocommit = autocommit
 664
 665        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
 666            self.run_and_clear_commit_hooks()
 667            self.run_commit_hooks_on_set_autocommit_on = False
 668
 669    def get_rollback(self) -> bool:
 670        """Get the "needs rollback" flag -- for *advanced use* only."""
 671        if not self.in_atomic_block:
 672            raise TransactionManagementError(
 673                "The rollback flag doesn't work outside of an 'atomic' block."
 674            )
 675        return self.needs_rollback
 676
 677    def set_rollback(self, rollback: bool) -> None:
 678        """
 679        Set or unset the "needs rollback" flag -- for *advanced use* only.
 680        """
 681        if not self.in_atomic_block:
 682            raise TransactionManagementError(
 683                "The rollback flag doesn't work outside of an 'atomic' block."
 684            )
 685        self.needs_rollback = rollback
 686
 687    def validate_no_atomic_block(self) -> None:
 688        """Raise an error if an atomic block is active."""
 689        if self.in_atomic_block:
 690            raise TransactionManagementError(
 691                "This is forbidden when an 'atomic' block is active."
 692            )
 693
 694    def validate_no_broken_transaction(self) -> None:
 695        if self.needs_rollback:
 696            raise TransactionManagementError(
 697                "An error occurred in the current transaction. You can't "
 698                "execute queries until the end of the 'atomic' block."
 699            ) from self.rollback_exc
 700
 701    # ##### Connection termination handling #####
 702
 703    def close_if_health_check_failed(self) -> None:
 704        """Close existing connection if it fails a health check."""
 705        if (
 706            self.connection is None
 707            or not self.health_check_enabled
 708            or self.health_check_done
 709        ):
 710            return
 711
 712        if not self.is_usable():
 713            self.close()
 714        self.health_check_done = True
 715
 716    def close_if_unusable_or_obsolete(self) -> None:
 717        """
 718        Close the current connection if unrecoverable errors have occurred
 719        or if it outlived its maximum age.
 720        """
 721        if self.connection is not None:
 722            self.health_check_done = False
 723            # If autocommit was not restored (e.g. a transaction was not
 724            # properly closed), don't take chances, drop the connection.
 725            if not self.get_autocommit():
 726                self.close()
 727                return
 728
 729            # If an exception other than DataError or IntegrityError occurred
 730            # since the last commit / rollback, check if the connection works.
 731            if self.errors_occurred:
 732                if self.is_usable():
 733                    self.errors_occurred = False
 734                    self.health_check_done = True
 735                else:
 736                    self.close()
 737                    return
 738
 739            if self.close_at is not None and time.monotonic() >= self.close_at:
 740                self.close()
 741                return
 742
 743    # ##### Miscellaneous #####
 744
 745    @cached_property
 746    def wrap_database_errors(self) -> DatabaseErrorWrapper:
 747        """
 748        Context manager and decorator that re-throws backend-specific database
 749        exceptions using Plain's common wrappers.
 750        """
 751        return DatabaseErrorWrapper(self)
 752
 753    def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
 754        """Create a cursor without debug logging."""
 755        return utils.CursorWrapper(cursor, self)
 756
 757    @contextmanager
 758    def temporary_connection(self) -> Generator[utils.CursorWrapper, None, None]:
 759        """
 760        Context manager that ensures that a connection is established, and
 761        if it opened one, closes it to avoid leaving a dangling connection.
 762        This is useful for operations outside of the request-response cycle.
 763
 764        Provide a cursor: with self.temporary_connection() as cursor: ...
 765        """
 766        must_close = self.connection is None
 767        try:
 768            with self.cursor() as cursor:
 769                yield cursor
 770        finally:
 771            if must_close:
 772                self.close()
 773
 774    def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
 775        """Return a new instance of the schema editor."""
 776        return DatabaseSchemaEditor(self, *args, **kwargs)
 777
 778    def runshell(self, parameters: list[str]) -> None:
 779        """Run an interactive psql shell."""
 780        args, env = _psql_settings_to_cmd_args_env(self.settings_dict, parameters)
 781        env = {**os.environ, **env} if env else None
 782        sigint_handler = signal.getsignal(signal.SIGINT)
 783        try:
 784            # Allow SIGINT to pass to psql to abort queries.
 785            signal.signal(signal.SIGINT, signal.SIG_IGN)
 786            subprocess.run(args, env=env, check=True)
 787        finally:
 788            # Restore the original SIGINT handler.
 789            signal.signal(signal.SIGINT, sigint_handler)
 790
 791    def on_commit(self, func: Any, robust: bool = False) -> None:
 792        if not callable(func):
 793            raise TypeError("on_commit()'s callback must be a callable.")
 794        if self.in_atomic_block:
 795            # Transaction in progress; save for execution on commit.
 796            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
 797        else:
 798            # No transaction in progress; execute immediately.
 799            if robust:
 800                try:
 801                    func()
 802                except Exception as e:
 803                    logger.error(
 804                        f"Error calling {func.__qualname__} in on_commit() (%s).",
 805                        e,
 806                        exc_info=True,
 807                    )
 808            else:
 809                func()
 810
 811    def run_and_clear_commit_hooks(self) -> None:
 812        self.validate_no_atomic_block()
 813        current_run_on_commit = self.run_on_commit
 814        self.run_on_commit = []
 815        while current_run_on_commit:
 816            _, func, robust = current_run_on_commit.pop(0)
 817            if robust:
 818                try:
 819                    func()
 820                except Exception as e:
 821                    logger.error(
 822                        f"Error calling {func.__qualname__} in on_commit() during "
 823                        f"transaction (%s).",
 824                        e,
 825                        exc_info=True,
 826                    )
 827            else:
 828                func()
 829
 830    @contextmanager
 831    def execute_wrapper(self, wrapper: Any) -> Generator[None, None, None]:
 832        """
 833        Return a context manager under which the wrapper is applied to suitable
 834        database query executions.
 835        """
 836        self.execute_wrappers.append(wrapper)
 837        try:
 838            yield
 839        finally:
 840            self.execute_wrappers.pop()
 841
 842    # ##### SQL generation methods that require connection state #####
 843
 844    def compose_sql(self, query: str, params: Any) -> str:
 845        """
 846        Compose a SQL query with parameters using psycopg's mogrify.
 847
 848        This requires an active connection because it uses the connection's
 849        cursor to properly format parameters.
 850        """
 851        assert self.connection is not None
 852        return ClientCursor(self.connection).mogrify(
 853            psycopg_sql.SQL(cast(LiteralString, query)), params
 854        )
 855
 856    def last_executed_query(
 857        self,
 858        cursor: utils.CursorWrapper,
 859        sql: str,
 860        params: Any,
 861    ) -> str | None:
 862        """
 863        Return a string of the query last executed by the given cursor, with
 864        placeholders replaced with actual values.
 865        """
 866        try:
 867            return self.compose_sql(sql, params)
 868        except errors.DataError:
 869            return None
 870
 871    def unification_cast_sql(self, output_field: Field) -> str:
 872        """
 873        Given a field instance, return the SQL that casts the result of a union
 874        to that type. The resulting string should contain a '%s' placeholder
 875        for the expression being cast.
 876        """
 877        internal_type = output_field.get_internal_type()
 878        if internal_type in (
 879            "GenericIPAddressField",
 880            "TimeField",
 881            "UUIDField",
 882        ):
 883            # PostgreSQL will resolve a union as type 'text' if input types are
 884            # 'unknown'.
 885            # https://www.postgresql.org/docs/current/typeconv-union-case.html
 886            # These fields cannot be implicitly cast back in the default
 887            # PostgreSQL configuration so we need to explicitly cast them.
 888            # We must also remove components of the type within brackets:
 889            # varchar(255) -> varchar.
 890            db_type = output_field.db_type()
 891            if db_type:
 892                return "CAST(%s AS {})".format(db_type.split("(")[0])
 893        return "%s"
 894
 895    # ##### Introspection methods #####
 896
 897    def table_names(
 898        self, cursor: CursorWrapper | None = None, include_views: bool = False
 899    ) -> list[str]:
 900        """
 901        Return a list of names of all tables that exist in the database.
 902        Sort the returned table list by Python's default sorting. Do NOT use
 903        the database's ORDER BY here to avoid subtle differences in sorting
 904        order between databases.
 905        """
 906
 907        def get_names(cursor: CursorWrapper) -> list[str]:
 908            return sorted(
 909                ti.name
 910                for ti in self.get_table_list(cursor)
 911                if include_views or ti.type == "t"
 912            )
 913
 914        if cursor is None:
 915            with self.cursor() as cursor:
 916                return get_names(cursor)
 917        return get_names(cursor)
 918
 919    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
 920        """
 921        Return an unsorted list of TableInfo named tuples of all tables and
 922        views that exist in the database.
 923        """
 924        cursor.execute(
 925            """
 926            SELECT
 927                c.relname,
 928                CASE
 929                    WHEN c.relispartition THEN 'p'
 930                    WHEN c.relkind IN ('m', 'v') THEN 'v'
 931                    ELSE 't'
 932                END,
 933                obj_description(c.oid, 'pg_class')
 934            FROM pg_catalog.pg_class c
 935            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
 936            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
 937                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
 938                AND pg_catalog.pg_table_is_visible(c.oid)
 939        """
 940        )
 941        return [
 942            TableInfo(*row)
 943            for row in cursor.fetchall()
 944            if row[0] not in self.ignored_tables
 945        ]
 946
 947    def plain_table_names(
 948        self, only_existing: bool = False, include_views: bool = True
 949    ) -> list[str]:
 950        """
 951        Return a list of all table names that have associated Plain models and
 952        are in INSTALLED_PACKAGES.
 953
 954        If only_existing is True, include only the tables in the database.
 955        """
 956        tables = set()
 957        for model in get_migratable_models():
 958            tables.add(model.model_options.db_table)
 959            tables.update(
 960                f.m2m_db_table() for f in model._model_meta.local_many_to_many
 961            )
 962        tables = list(tables)
 963        if only_existing:
 964            existing_tables = set(self.table_names(include_views=include_views))
 965            tables = [t for t in tables if t in existing_tables]
 966        return tables
 967
 968    def get_sequences(
 969        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
 970    ) -> list[dict[str, Any]]:
 971        """
 972        Return a list of introspected sequences for table_name. Each sequence
 973        is a dict: {'table': <table_name>, 'column': <column_name>, 'name': <sequence_name>}.
 974        """
 975        cursor.execute(
 976            """
 977            SELECT
 978                s.relname AS sequence_name,
 979                a.attname AS colname
 980            FROM
 981                pg_class s
 982                JOIN pg_depend d ON d.objid = s.oid
 983                    AND d.classid = 'pg_class'::regclass
 984                    AND d.refclassid = 'pg_class'::regclass
 985                JOIN pg_attribute a ON d.refobjid = a.attrelid
 986                    AND d.refobjsubid = a.attnum
 987                JOIN pg_class tbl ON tbl.oid = d.refobjid
 988                    AND tbl.relname = %s
 989                    AND pg_catalog.pg_table_is_visible(tbl.oid)
 990            WHERE
 991                s.relkind = 'S';
 992        """,
 993            [table_name],
 994        )
 995        return [
 996            {"name": row[0], "table": table_name, "column": row[1]}
 997            for row in cursor.fetchall()
 998        ]
 999
1000    def get_constraints(
1001        self, cursor: CursorWrapper, table_name: str
1002    ) -> dict[str, dict[str, Any]]:
1003        """
1004        Retrieve any constraints or keys (unique, pk, fk, check, index) across
1005        one or more columns. Also retrieve the definition of expression-based
1006        indexes.
1007        """
1008        constraints: dict[str, dict[str, Any]] = {}
1009        # Loop over the key table, collecting things as constraints. The column
1010        # array must return column names in the same order in which they were
1011        # created.
1012        cursor.execute(
1013            """
1014            SELECT
1015                c.conname,
1016                array(
1017                    SELECT attname
1018                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
1019                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
1020                    WHERE ca.attrelid = c.conrelid
1021                    ORDER BY cols.arridx
1022                ),
1023                c.contype,
1024                (SELECT fkc.relname || '.' || fka.attname
1025                FROM pg_attribute AS fka
1026                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
1027                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
1028                cl.reloptions
1029            FROM pg_constraint AS c
1030            JOIN pg_class AS cl ON c.conrelid = cl.oid
1031            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
1032        """,
1033            [table_name],
1034        )
1035        for constraint, columns, kind, used_cols, options in cursor.fetchall():
1036            constraints[constraint] = {
1037                "columns": columns,
1038                "primary_key": kind == "p",
1039                "unique": kind in ["p", "u"],
1040                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
1041                "check": kind == "c",
1042                "index": False,
1043                "definition": None,
1044                "options": options,
1045            }
1046        # Now get indexes
1047        cursor.execute(
1048            """
1049            SELECT
1050                indexname,
1051                array_agg(attname ORDER BY arridx),
1052                indisunique,
1053                indisprimary,
1054                array_agg(ordering ORDER BY arridx),
1055                amname,
1056                exprdef,
1057                s2.attoptions
1058            FROM (
1059                SELECT
1060                    c2.relname as indexname, idx.*, attr.attname, am.amname,
1061                    CASE
1062                        WHEN idx.indexprs IS NOT NULL THEN
1063                            pg_get_indexdef(idx.indexrelid)
1064                    END AS exprdef,
1065                    CASE am.amname
1066                        WHEN %s THEN
1067                            CASE (option & 1)
1068                                WHEN 1 THEN 'DESC' ELSE 'ASC'
1069                            END
1070                    END as ordering,
1071                    c2.reloptions as attoptions
1072                FROM (
1073                    SELECT *
1074                    FROM
1075                        pg_index i,
1076                        unnest(i.indkey, i.indoption)
1077                            WITH ORDINALITY koi(key, option, arridx)
1078                ) idx
1079                LEFT JOIN pg_class c ON idx.indrelid = c.oid
1080                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
1081                LEFT JOIN pg_am am ON c2.relam = am.oid
1082                LEFT JOIN
1083                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
1084                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
1085            ) s2
1086            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
1087        """,
1088            [self.index_default_access_method, table_name],
1089        )
1090        for (
1091            index,
1092            columns,
1093            unique,
1094            primary,
1095            orders,
1096            type_,
1097            definition,
1098            options,
1099        ) in cursor.fetchall():
1100            if index not in constraints:
1101                basic_index = (
1102                    type_ == self.index_default_access_method and options is None
1103                )
1104                constraints[index] = {
1105                    "columns": columns if columns != [None] else [],
1106                    "orders": orders if orders != [None] else [],
1107                    "primary_key": primary,
1108                    "unique": unique,
1109                    "foreign_key": None,
1110                    "check": False,
1111                    "index": True,
1112                    "type": Index.suffix if basic_index else type_,
1113                    "definition": definition,
1114                    "options": options,
1115                }
1116        return constraints
1117
1118    # ##### Test database creation methods (merged from DatabaseCreation) #####
1119
1120    def _log(self, msg: str) -> None:
1121        sys.stderr.write(msg + os.linesep)
1122
1123    def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
1124        """
1125        Create a test database, prompting the user for confirmation if the
1126        database already exists. Return the name of the test database created.
1127
1128        If prefix is provided, it will be prepended to the database name
1129        to isolate it from other test databases.
1130        """
1131        from plain.models.cli.migrations import apply
1132
1133        test_database_name = self._get_test_db_name(prefix)
1134
1135        if verbosity >= 1:
1136            self._log(f"Creating test database '{test_database_name}'...")
1137
1138        self._create_test_db(
1139            test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
1140        )
1141
1142        self.close()
1143        settings.POSTGRES_DATABASE = test_database_name
1144        self.settings_dict["DATABASE"] = test_database_name
1145
1146        apply.callback(
1147            package_label=None,
1148            migration_name=None,
1149            fake=False,
1150            plan=False,
1151            check_unapplied=False,
1152            backup=False,
1153            no_input=True,
1154            atomic_batch=False,  # No need for atomic batch when creating test database
1155            quiet=verbosity < 2,  # Show migration output when verbosity is 2+
1156        )
1157
1158        # Ensure a connection for the side effect of initializing the test database.
1159        self.ensure_connection()
1160
1161        return test_database_name
1162
1163    def _get_test_db_name(self, prefix: str = "") -> str:
1164        """
1165        Internal implementation - return the name of the test DB that will be
1166        created. Only useful when called from create_test_db() and
1167        _create_test_db() and when no external munging is done with the 'DATABASE'
1168        settings.
1169
1170        If prefix is provided, it will be prepended to the database name.
1171        """
1172        # Determine the base name: explicit TEST.DATABASE overrides base DATABASE.
1173        base_name = (
1174            self.settings_dict["TEST"]["DATABASE"] or self.settings_dict["DATABASE"]
1175        )
1176        if prefix:
1177            return f"{prefix}_{base_name}"
1178        if self.settings_dict["TEST"]["DATABASE"]:
1179            return self.settings_dict["TEST"]["DATABASE"]
1180        name = self.settings_dict["DATABASE"]
1181        if name is None:
1182            raise ValueError("POSTGRES_DATABASE must be set")
1183        return TEST_DATABASE_PREFIX + name
1184
1185    def _get_database_create_suffix(
1186        self, encoding: str | None = None, template: str | None = None
1187    ) -> str:
1188        """Return PostgreSQL-specific CREATE DATABASE suffix."""
1189        suffix = ""
1190        if encoding:
1191            suffix += f" ENCODING '{encoding}'"
1192        if template:
1193            suffix += f" TEMPLATE {quote_name(template)}"
1194        return suffix and "WITH" + suffix
1195
1196    def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
1197        try:
1198            cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
1199        except Exception as e:
1200            cause = e.__cause__
1201            if cause and not isinstance(cause, errors.DuplicateDatabase):
1202                # All errors except "database already exists" cancel tests.
1203                self._log(f"Got an error creating the test database: {e}")
1204                sys.exit(2)
1205            else:
1206                raise
1207
1208    def _create_test_db(
1209        self, *, test_database_name: str, verbosity: int, autoclobber: bool
1210    ) -> str:
1211        """
1212        Internal implementation - create the test db tables.
1213        """
1214        test_db_params = {
1215            "dbname": quote_name(test_database_name),
1216            "suffix": self.sql_table_creation_suffix(),
1217        }
1218        # Create the test database and connect to it.
1219        with self._nodb_cursor() as cursor:
1220            try:
1221                self._execute_create_test_db(cursor, test_db_params)
1222            except Exception as e:
1223                self._log(f"Got an error creating the test database: {e}")
1224                if not autoclobber:
1225                    confirm = input(
1226                        "Type 'yes' if you would like to try deleting the test "
1227                        f"database '{test_database_name}', or 'no' to cancel: "
1228                    )
1229                if autoclobber or confirm == "yes":
1230                    try:
1231                        if verbosity >= 1:
1232                            self._log(
1233                                f"Destroying old test database '{test_database_name}'..."
1234                            )
1235                        cursor.execute(
1236                            "DROP DATABASE {dbname}".format(**test_db_params)
1237                        )
1238                        self._execute_create_test_db(cursor, test_db_params)
1239                    except Exception as e:
1240                        self._log(f"Got an error recreating the test database: {e}")
1241                        sys.exit(2)
1242                else:
1243                    self._log("Tests cancelled.")
1244                    sys.exit(1)
1245
1246        return test_database_name
1247
1248    def destroy_test_db(
1249        self, old_database_name: str | None = None, verbosity: int = 1
1250    ) -> None:
1251        """
1252        Destroy a test database, prompting the user for confirmation if the
1253        database already exists.
1254        """
1255        self.close()
1256
1257        test_database_name = self.settings_dict["DATABASE"]
1258        if test_database_name is None:
1259            raise ValueError("Test POSTGRES_DATABASE must be set")
1260
1261        if verbosity >= 1:
1262            self._log(f"Destroying test database '{test_database_name}'...")
1263        self._destroy_test_db(test_database_name, verbosity)
1264
1265        # Restore the original database name
1266        if old_database_name is not None:
1267            settings.POSTGRES_DATABASE = old_database_name
1268            self.settings_dict["DATABASE"] = old_database_name
1269
1270    def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
1271        """
1272        Internal implementation - remove the test db tables.
1273        """
1274        # Remove the test database to clean up after
1275        # ourselves. Connect to the previous database (not the test database)
1276        # to do so, because it's not allowed to delete a database while being
1277        # connected to it.
1278        with self._nodb_cursor() as cursor:
1279            cursor.execute(f"DROP DATABASE {quote_name(test_database_name)}")
1280
1281    def sql_table_creation_suffix(self) -> str:
1282        """
1283        SQL to append to the end of the test table creation statements.
1284        """
1285        test_settings = self.settings_dict["TEST"]
1286        return self._get_database_create_suffix(
1287            encoding=test_settings.get("CHARSET"),
1288            template=test_settings.get("TEMPLATE"),
1289        )
1290
1291
class CursorMixin:
    """
    Mixin for psycopg cursor classes that adds callproc().
    """

    def callproc(
        self, name: str | psycopg_sql.Identifier, args: list[Any] | None = None
    ) -> list[Any] | None:
        """
        Call the stored procedure `name` by executing
        ``SELECT * FROM name(arg1,arg2,...)``.

        A plain string name is wrapped in a psycopg Identifier and every
        argument is embedded as a psycopg Literal. Return `args` unchanged.
        """
        procedure = (
            name
            if isinstance(name, psycopg_sql.Identifier)
            else psycopg_sql.Identifier(name)
        )

        pieces: list[psycopg_sql.Composable] = [
            psycopg_sql.SQL("SELECT * FROM "),
            procedure,
            psycopg_sql.SQL("("),
        ]
        for position, value in enumerate(args or ()):
            # A comma goes between arguments, never before the first one.
            if position:
                pieces.append(psycopg_sql.SQL(","))
            pieces.append(psycopg_sql.Literal(value))
        pieces.append(psycopg_sql.SQL(")"))

        self.execute(psycopg_sql.Composed(pieces))  # type: ignore[attr-defined]
        return args
1318
1319
class ServerBindingCursor(CursorMixin, Database.Cursor):
    """psycopg ``Cursor`` extended with the callproc() from CursorMixin."""

    pass
1322
1323
class Cursor(CursorMixin, Database.ClientCursor):
    """psycopg ``ClientCursor`` extended with the callproc() from CursorMixin."""

    pass
1326
1327
class CursorDebugWrapper(BaseCursorDebugWrapper):
    """Debug cursor wrapper that also runs copy() under debug_sql()."""

    def copy(self, statement: Any) -> Any:
        """
        Call copy() on the wrapped cursor for `statement` inside the
        self.debug_sql(statement) context and return its result.
        """
        with self.debug_sql(statement):
            return self.cursor.copy(statement)