1from __future__ import annotations
2
3import _thread
4import datetime
5import logging
6import os
7import signal
8import subprocess
9import sys
10import time
11import warnings
12import zoneinfo
13from collections import deque
14from collections.abc import Generator, Sequence
15from contextlib import contextmanager
16from functools import cached_property, lru_cache
17from typing import TYPE_CHECKING, Any, LiteralString, NamedTuple, cast
18
19import psycopg as Database
20from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors
21from psycopg import sql as psycopg_sql
22from psycopg.abc import Buffer, PyFormat
23from psycopg.postgres import types as pg_types
24from psycopg.pq import Format
25from psycopg.types.datetime import TimestamptzLoader
26from psycopg.types.range import BaseRangeDumper, Range, RangeDumper
27from psycopg.types.string import TextLoader
28
29from plain.exceptions import ImproperlyConfigured
30from plain.models.db import (
31 DatabaseError,
32 DatabaseErrorWrapper,
33)
34from plain.models.indexes import Index
35from plain.models.postgres import utils
36from plain.models.postgres.schema import DatabaseSchemaEditor
37from plain.models.postgres.sql import MAX_NAME_LENGTH, quote_name
38from plain.models.postgres.utils import CursorDebugWrapper as BaseCursorDebugWrapper
39from plain.models.postgres.utils import CursorWrapper, debug_transaction
40from plain.models.transaction import TransactionManagementError
41from plain.runtime import settings
42
43if TYPE_CHECKING:
44 from psycopg import Connection as PsycopgConnection
45
46 from plain.models.connections import DatabaseConfig
47 from plain.models.fields import Field
48
# Module-level logger for connection and transaction diagnostics.
logger = logging.getLogger("plain.models.postgres")

# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = "test_"
54
55
def get_migratable_models() -> Generator[Any, None, None]:
    """Return a generator over all models that should be included in migrations.

    Walks every installed package config and yields each model registered
    under that package's label.
    """
    from plain.models import models_registry
    from plain.packages import packages_registry

    def _iter_models() -> Generator[Any, None, None]:
        for package_config in packages_registry.get_package_configs():
            yield from models_registry.get_models(
                package_label=package_config.package_label
            )

    return _iter_models()
68
69
class TableInfo(NamedTuple):
    """Structure returned by DatabaseConnection.get_table_list()."""

    # Relation name as reported by pg_class.relname.
    name: str
    # Kind flag produced by get_table_list(): 'p' partition, 'v' view or
    # materialized view, 't' ordinary table.
    type: str
    # Comment attached to the relation (COMMENT ON ...), or None.
    comment: str | None
76
77
# Type OIDs
# Cached PostgreSQL type OIDs used when registering the custom psycopg
# loaders/dumpers defined below (BaseTzLoader, PlainRangeDumper).
TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
TSRANGE_OID = pg_types["tsrange"].oid
TSTZRANGE_OID = pg_types["tstzrange"].oid
82
83
class BaseTzLoader(TimestamptzLoader):
    """
    Load a PostgreSQL timestamptz using a specific timezone.
    The timezone can be None too, in which case it will be chopped.
    """

    # Subclasses (see register_tzloader) pin this to a concrete tzinfo.
    timezone: datetime.tzinfo | None = None

    def load(self, data: Buffer) -> datetime.datetime:
        """Parse the raw value, then force the configured tzinfo onto it."""
        return super().load(data).replace(tzinfo=self.timezone)
95
96
def register_tzloader(tz: datetime.tzinfo | None, context: Any) -> None:
    """Register a timestamptz loader bound to *tz* on the given adapter context."""
    loader = type("SpecificTzLoader", (BaseTzLoader,), {"timezone": tz})
    context.adapters.register_loader("timestamptz", loader)
102
103
class PlainRangeDumper(RangeDumper):
    """A Range dumper customized for Plain."""

    def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper:
        """Promote upgraded tsrange dumpers to the tstzrange OID."""
        upgraded = super().upgrade(obj, format)
        if upgraded is self or upgraded.oid != TSRANGE_OID:
            return upgraded
        upgraded.oid = TSTZRANGE_OID
        return upgraded
112
113
@lru_cache
def get_adapters_template(timezone: datetime.tzinfo | None) -> adapt.AdaptersMap:
    """
    Build the psycopg adapters map used for new connections.

    Memoized per timezone via lru_cache so every connection with the same
    timezone shares one map.
    """
    ctx = adapt.AdaptersMap(adapters)
    # No-op JSON loader to avoid psycopg3 round trips
    ctx.register_loader("jsonb", TextLoader)
    # Treat inet/cidr as text
    ctx.register_loader("inet", TextLoader)
    ctx.register_loader("cidr", TextLoader)
    # Dump Python Range objects through Plain's tstzrange-aware dumper.
    ctx.register_dumper(Range, PlainRangeDumper)
    # Make timestamptz values load in the configured timezone.
    register_tzloader(timezone, ctx)
    return ctx
125
126
127def _psql_settings_to_cmd_args_env(
128 settings_dict: DatabaseConfig, parameters: list[str]
129) -> tuple[list[str], dict[str, str] | None]:
130 """Build psql command-line arguments from database settings."""
131 args = ["psql"]
132 options = settings_dict.get("OPTIONS", {})
133
134 if user := settings_dict.get("USER"):
135 args += ["-U", user]
136 if host := settings_dict.get("HOST"):
137 args += ["-h", host]
138 if port := settings_dict.get("PORT"):
139 args += ["-p", str(port)]
140 args.extend(parameters)
141 args += [settings_dict.get("DATABASE") or "postgres"]
142
143 env: dict[str, str] = {}
144 if password := settings_dict.get("PASSWORD"):
145 env["PGPASSWORD"] = str(password)
146
147 # Map OPTIONS keys to their corresponding environment variables.
148 option_env_vars = {
149 "passfile": "PGPASSFILE",
150 "sslmode": "PGSSLMODE",
151 "sslrootcert": "PGSSLROOTCERT",
152 "sslcert": "PGSSLCERT",
153 "sslkey": "PGSSLKEY",
154 }
155 for option_key, env_var in option_env_vars.items():
156 if value := options.get(option_key):
157 env[env_var] = str(value)
158
159 return args, (env or None)
160
161
class DatabaseConnection:
    """
    PostgreSQL database connection.

    This is the only database backend supported by Plain.
    """

    # Maximum number of entries retained in `queries_log`.
    queries_limit: int = 9000
    # Command-line client binary used by runshell().
    executable_name: str = "psql"

    # Default index access method assumed when none is specified.
    index_default_access_method = "btree"
    # Table names that introspection (get_table_list) should skip.
    ignored_tables: list[str] = []
174
    def __init__(self, settings_dict: DatabaseConfig):
        """Store settings and initialize per-connection state (no I/O happens here)."""
        # Connection related attributes.
        # The underlying database connection (from the database library, not a wrapper).
        self.connection: PsycopgConnection[Any] | None = None
        # `settings_dict` should be a dictionary containing keys such as
        # DATABASE, USER, etc. It's called `settings_dict` instead of `settings`
        # to disambiguate it from Plain settings modules.
        self.settings_dict: DatabaseConfig = settings_dict
        # Query logging in debug mode or when explicitly enabled.
        self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
        self.force_debug_cursor: bool = False

        # Transaction related attributes.
        # Tracks if the connection is in autocommit mode. Per PEP 249, by
        # default, it isn't.
        self.autocommit: bool = False
        # Tracks if the connection is in a transaction managed by 'atomic'.
        self.in_atomic_block: bool = False
        # Increment to generate unique savepoint ids.
        self.savepoint_state: int = 0
        # List of savepoints created by 'atomic'.
        self.savepoint_ids: list[str | None] = []
        # Stack of active 'atomic' blocks.
        self.atomic_blocks: list[Any] = []
        # Tracks if the transaction should be rolled back to the next
        # available savepoint because of an exception in an inner block.
        self.needs_rollback: bool = False
        self.rollback_exc: Exception | None = None

        # Connection termination related attributes.
        # Monotonic deadline after which the connection is considered obsolete
        # (derived from CONN_MAX_AGE in connect()).
        self.close_at: float | None = None
        self.closed_in_transaction: bool = False
        self.errors_occurred: bool = False
        self.health_check_enabled: bool = False
        self.health_check_done: bool = False

        # A list of no-argument functions to run when the transaction commits.
        # Each entry is an (sids, func, robust) tuple, where sids is a set of
        # the active savepoint IDs when this function was registered and robust
        # specifies whether it's allowed for the function to fail.
        self.run_on_commit: list[tuple[set[str | None], Any, bool]] = []

        # Should we run the on-commit hooks the next time set_autocommit(True)
        # is called?
        self.run_commit_hooks_on_set_autocommit_on: bool = False

        # A stack of wrappers to be invoked around execute()/executemany()
        # calls. Each entry is a function taking five arguments: execute, sql,
        # params, many, and context. It's the function's responsibility to
        # call execute(sql, params, many, context).
        self.execute_wrappers: list[Any] = []
226
227 def __repr__(self) -> str:
228 return f"<{self.__class__.__qualname__} vendor='postgresql'>"
229
230 @cached_property
231 def timezone(self) -> datetime.tzinfo:
232 """
233 Return a tzinfo of the database connection time zone.
234
235 When a datetime is read from the database, it is returned in this time
236 zone. Since PostgreSQL supports time zones, it doesn't matter which
237 time zone Plain uses, as long as aware datetimes are used everywhere.
238 Other users connecting to the database can choose their own time zone.
239 """
240 if self.settings_dict["TIME_ZONE"] is None:
241 return datetime.UTC
242 return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
243
244 @cached_property
245 def timezone_name(self) -> str:
246 """
247 Name of the time zone of the database connection.
248 """
249 if self.settings_dict["TIME_ZONE"] is None:
250 return "UTC"
251 return self.settings_dict["TIME_ZONE"]
252
253 @property
254 def queries_logged(self) -> bool:
255 return self.force_debug_cursor or settings.DEBUG
256
257 @property
258 def queries(self) -> list[dict[str, Any]]:
259 if len(self.queries_log) == self.queries_log.maxlen:
260 warnings.warn(
261 f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
262 "will be returned."
263 )
264 return list(self.queries_log)
265
266 # ##### Connection and cursor methods #####
267
268 def get_connection_params(self) -> dict[str, Any]:
269 """Return a dict of parameters suitable for get_new_connection."""
270 settings_dict = self.settings_dict
271 options = settings_dict.get("OPTIONS", {})
272 db_name = settings_dict.get("DATABASE")
273 if db_name == "":
274 raise ImproperlyConfigured(
275 "PostgreSQL database is not configured. "
276 "Set DATABASE_URL or the POSTGRES_DATABASE setting."
277 )
278 if len(db_name or "") > MAX_NAME_LENGTH:
279 raise ImproperlyConfigured(
280 "The database name '%s' (%d characters) is longer than " # noqa: UP031
281 "PostgreSQL's limit of %d characters. Supply a shorter "
282 "POSTGRES_DATABASE setting."
283 % (
284 db_name,
285 len(db_name or ""),
286 MAX_NAME_LENGTH,
287 )
288 )
289 if db_name is None:
290 # None is used to connect to the default 'postgres' db.
291 db_name = "postgres"
292 conn_params: dict[str, Any] = {
293 "dbname": db_name,
294 **options,
295 }
296
297 conn_params.pop("assume_role", None)
298 conn_params.pop("isolation_level", None)
299 conn_params.pop("server_side_binding", None)
300 if settings_dict["USER"]:
301 conn_params["user"] = settings_dict["USER"]
302 if settings_dict["PASSWORD"]:
303 conn_params["password"] = settings_dict["PASSWORD"]
304 if settings_dict["HOST"]:
305 conn_params["host"] = settings_dict["HOST"]
306 if settings_dict["PORT"]:
307 conn_params["port"] = settings_dict["PORT"]
308 conn_params["context"] = get_adapters_template(self.timezone)
309 # Disable prepared statements by default to keep connection poolers
310 # working. Can be reenabled via OPTIONS in the settings dict.
311 conn_params["prepare_threshold"] = conn_params.pop("prepare_threshold", None)
312 return conn_params
313
    def get_new_connection(self, conn_params: dict[str, Any]) -> PsycopgConnection[Any]:
        """Open a connection to the database."""
        # self.isolation_level must be set:
        # - after connecting to the database in order to obtain the database's
        #   default when no value is explicitly specified in options.
        # - before calling _set_autocommit() because if autocommit is on, that
        #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
        options = self.settings_dict.get("OPTIONS", {})
        set_isolation_level = False
        try:
            isolation_level_value = options["isolation_level"]
        except KeyError:
            # No explicit level requested: record the assumed default but do
            # not push it onto the connection (server default wins).
            self.isolation_level = IsolationLevel.READ_COMMITTED
        else:
            # Set the isolation level to the value from OPTIONS.
            try:
                self.isolation_level = IsolationLevel(isolation_level_value)
                set_isolation_level = True
            except ValueError:
                raise ImproperlyConfigured(
                    f"Invalid transaction isolation level {isolation_level_value} "
                    f"specified. Use one of the psycopg.IsolationLevel values."
                )
        connection = Database.connect(**conn_params)
        if set_isolation_level:
            connection.isolation_level = self.isolation_level
        # Use server-side binding cursor if requested, otherwise standard cursor
        connection.cursor_factory = (
            ServerBindingCursor
            if options.get("server_side_binding") is True
            else Cursor
        )
        return connection
347
348 def ensure_timezone(self) -> bool:
349 """
350 Ensure the connection's timezone is set to `self.timezone_name` and
351 return whether it changed or not.
352 """
353 if self.connection is None:
354 return False
355 conn_timezone_name = self.connection.info.parameter_status("TimeZone")
356 timezone_name = self.timezone_name
357 if timezone_name and conn_timezone_name != timezone_name:
358 self.connection.execute(
359 "SELECT set_config('TimeZone', %s, false)", [timezone_name]
360 )
361 return True
362 return False
363
364 def ensure_role(self) -> bool:
365 if self.connection is None:
366 return False
367 if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
368 sql_str = self.compose_sql("SET ROLE %s", [new_role])
369 self.connection.execute(sql_str) # type: ignore[arg-type]
370 return True
371 return False
372
    def init_connection_state(self) -> None:
        """Initialize the database connection settings."""
        # Align the session's TimeZone with this connection's configured zone.
        self.ensure_timezone()
        # Set the role on the connection. This is useful if the credential used
        # to login is not the same as the role that owns database resources. As
        # can be the case when using temporary or ephemeral credentials.
        self.ensure_role()
380
    def create_cursor(self) -> Any:
        """Create a cursor. Assume that a connection is established."""
        assert self.connection is not None
        cursor = self.connection.cursor()

        # Register the cursor timezone only if the connection disagrees, to avoid copying the adapter map.
        tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
        if self.timezone != tzloader.timezone:  # type: ignore[union-attr]
            # Cursor-level loader registration overrides the connection's.
            register_tzloader(self.timezone, cursor)
        return cursor
391
    def _set_autocommit(self, autocommit: bool) -> None:
        """Backend-specific implementation to enable or disable autocommit."""
        assert self.connection is not None
        # Route any psycopg errors through Plain's common exception wrappers.
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit
397
398 def check_constraints(self, table_names: list[str] | None = None) -> None:
399 """
400 Check constraints by setting them to immediate. Return them to deferred
401 afterward.
402 """
403 with self.cursor() as cursor:
404 cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
405 cursor.execute("SET CONSTRAINTS ALL DEFERRED")
406
407 def is_usable(self) -> bool:
408 """
409 Test if the database connection is usable.
410
411 This method may assume that self.connection is not None.
412
413 Actual implementations should take care not to raise exceptions
414 as that may prevent Plain from recycling unusable connections.
415 """
416 assert self.connection is not None
417 try:
418 # Use psycopg directly, bypassing Plain's utilities.
419 self.connection.execute("SELECT 1")
420 except Database.Error:
421 return False
422 else:
423 return True
424
    @contextmanager
    def _nodb_cursor(self) -> Generator[utils.CursorWrapper, None, None]:
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.
        """
        cursor = None
        try:
            # DATABASE=None makes get_connection_params() target the default
            # 'postgres' maintenance database.
            conn = self.__class__({**self.settings_dict, "DATABASE": None})
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
        except (Database.DatabaseError, DatabaseError):
            # A non-None cursor means the failure came from the caller's
            # block, not from connecting -- propagate it unchanged.
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Plain will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Plain was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            # Fall back to the configured database.
            conn = self.__class__(self.settings_dict)
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
459
    @cached_property
    def pg_version(self) -> int:
        """PostgreSQL server version as an integer (e.g. 150002 for 15.2).

        Opens a temporary connection on first access; cached afterwards.
        """
        with self.temporary_connection():
            assert self.connection is not None
            return self.connection.info.server_version
465
    def make_debug_cursor(self, cursor: Any) -> CursorDebugWrapper:
        """Wrap a raw cursor so executed queries are recorded in queries_log."""
        return CursorDebugWrapper(cursor, self)
468
469 # ##### Connection lifecycle #####
470
    def connect(self) -> None:
        """Connect to the database. Assume that the connection is closed."""
        # In case the previous connection was closed while in an atomic block
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.atomic_blocks = []
        self.needs_rollback = False
        # Reset parameters defining when to close/health-check the connection.
        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
        max_age = self.settings_dict["CONN_MAX_AGE"]
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # New connections are healthy.
        self.health_check_done = True
        # Establish the connection
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        # Sessions start in autocommit; atomic() disables it as needed.
        self.set_autocommit(True)
        self.init_connection_state()

        # Any pending commit hooks belonged to the previous connection.
        self.run_on_commit = []
493
494 def ensure_connection(self) -> None:
495 """Guarantee that a connection to the database is established."""
496 if self.connection is None:
497 with self.wrap_database_errors:
498 self.connect()
499
500 # ##### PEP-249 connection method wrappers #####
501
502 def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
503 """
504 Validate the connection is usable and perform database cursor wrapping.
505 """
506 if self.queries_logged:
507 wrapped_cursor = self.make_debug_cursor(cursor)
508 else:
509 wrapped_cursor = self.make_cursor(cursor)
510 return wrapped_cursor
511
    def _cursor(self) -> utils.CursorWrapper:
        # Recycle a failed connection first (when health checks are enabled),
        # make sure a connection exists, then create and wrap a new cursor.
        self.close_if_health_check_failed()
        self.ensure_connection()
        with self.wrap_database_errors:
            return self._prepare_cursor(self.create_cursor())
517
    def _commit(self) -> None:
        # Low-level commit; a no-op when no connection is open.
        if self.connection is not None:
            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
                return self.connection.commit()
522
    def _rollback(self) -> None:
        # Low-level rollback; a no-op when no connection is open.
        if self.connection is not None:
            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
                return self.connection.rollback()
527
    def _close(self) -> None:
        # Low-level close; a no-op when no connection is open.
        if self.connection is not None:
            with self.wrap_database_errors:
                return self.connection.close()
532
533 # ##### Generic wrappers for PEP-249 connection methods #####
534
    def cursor(self) -> utils.CursorWrapper:
        """Create a cursor, opening a connection if necessary.

        The cursor is wrapped (with debug recording when query logging is
        enabled) before being returned.
        """
        return self._cursor()
538
    def commit(self) -> None:
        """Commit a transaction and reset the dirty flag."""
        self.validate_no_atomic_block()
        self._commit()
        # A successful commit means that the database connection works.
        self.errors_occurred = False
        # Trigger registered on_commit hooks when autocommit is re-enabled.
        self.run_commit_hooks_on_set_autocommit_on = True
546
    def rollback(self) -> None:
        """Roll back a transaction and reset the dirty flag."""
        self.validate_no_atomic_block()
        self._rollback()
        # A successful rollback means that the database connection works.
        self.errors_occurred = False
        self.needs_rollback = False
        # Commit hooks from the rolled-back transaction must never run.
        self.run_on_commit = []
555
    def close(self) -> None:
        """Close the connection to the database."""
        # Pending commit hooks can never fire once the connection is gone.
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            if self.in_atomic_block:
                # Record that the connection died mid-transaction so the
                # atomic machinery rolls back instead of reusing it.
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None
573
574 # ##### Savepoint management #####
575
576 def _savepoint(self, sid: str) -> None:
577 with self.cursor() as cursor:
578 cursor.execute(f"SAVEPOINT {quote_name(sid)}")
579
580 def _savepoint_rollback(self, sid: str) -> None:
581 with self.cursor() as cursor:
582 cursor.execute(f"ROLLBACK TO SAVEPOINT {quote_name(sid)}")
583
584 def _savepoint_commit(self, sid: str) -> None:
585 with self.cursor() as cursor:
586 cursor.execute(f"RELEASE SAVEPOINT {quote_name(sid)}")
587
588 # ##### Generic savepoint management methods #####
589
590 def savepoint(self) -> str | None:
591 """
592 Create a savepoint inside the current transaction. Return an
593 identifier for the savepoint that will be used for the subsequent
594 rollback or commit. Return None if in autocommit mode (no transaction).
595 """
596 if self.get_autocommit():
597 return None
598
599 thread_ident = _thread.get_ident()
600 tid = str(thread_ident).replace("-", "")
601
602 self.savepoint_state += 1
603 sid = "s%s_x%d" % (tid, self.savepoint_state) # noqa: UP031
604
605 self._savepoint(sid)
606
607 return sid
608
609 def savepoint_rollback(self, sid: str) -> None:
610 """
611 Roll back to a savepoint. Do nothing if in autocommit mode.
612 """
613 if self.get_autocommit():
614 return
615
616 self._savepoint_rollback(sid)
617
618 # Remove any callbacks registered while this savepoint was active.
619 self.run_on_commit = [
620 (sids, func, robust)
621 for (sids, func, robust) in self.run_on_commit
622 if sid not in sids
623 ]
624
625 def savepoint_commit(self, sid: str) -> None:
626 """
627 Release a savepoint. Do nothing if in autocommit mode.
628 """
629 if self.get_autocommit():
630 return
631
632 self._savepoint_commit(sid)
633
    def clean_savepoints(self) -> None:
        """
        Reset the counter used to generate unique savepoint ids in this thread.
        """
        # Savepoint ids only need to be unique within a single transaction.
        self.savepoint_state = 0
639
640 # ##### Generic transaction management methods #####
641
    def get_autocommit(self) -> bool:
        """Get the autocommit state."""
        # Connect first so the reported state reflects a live connection.
        self.ensure_connection()
        return self.autocommit
646
    def set_autocommit(self, autocommit: bool) -> None:
        """
        Enable or disable autocommit.

        Used internally by atomic() to manage transactions. Don't call this
        directly — use atomic() instead.
        """
        self.validate_no_atomic_block()
        self.close_if_health_check_failed()
        self.ensure_connection()

        if autocommit:
            self._set_autocommit(autocommit)
        else:
            # Disabling autocommit implicitly starts a transaction; surface
            # it in the query debug log as a BEGIN.
            with debug_transaction(self, "BEGIN"):
                self._set_autocommit(autocommit)
        self.autocommit = autocommit

        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False
668
669 def get_rollback(self) -> bool:
670 """Get the "needs rollback" flag -- for *advanced use* only."""
671 if not self.in_atomic_block:
672 raise TransactionManagementError(
673 "The rollback flag doesn't work outside of an 'atomic' block."
674 )
675 return self.needs_rollback
676
677 def set_rollback(self, rollback: bool) -> None:
678 """
679 Set or unset the "needs rollback" flag -- for *advanced use* only.
680 """
681 if not self.in_atomic_block:
682 raise TransactionManagementError(
683 "The rollback flag doesn't work outside of an 'atomic' block."
684 )
685 self.needs_rollback = rollback
686
687 def validate_no_atomic_block(self) -> None:
688 """Raise an error if an atomic block is active."""
689 if self.in_atomic_block:
690 raise TransactionManagementError(
691 "This is forbidden when an 'atomic' block is active."
692 )
693
694 def validate_no_broken_transaction(self) -> None:
695 if self.needs_rollback:
696 raise TransactionManagementError(
697 "An error occurred in the current transaction. You can't "
698 "execute queries until the end of the 'atomic' block."
699 ) from self.rollback_exc
700
701 # ##### Connection termination handling #####
702
703 def close_if_health_check_failed(self) -> None:
704 """Close existing connection if it fails a health check."""
705 if (
706 self.connection is None
707 or not self.health_check_enabled
708 or self.health_check_done
709 ):
710 return
711
712 if not self.is_usable():
713 self.close()
714 self.health_check_done = True
715
    def close_if_unusable_or_obsolete(self) -> None:
        """
        Close the current connection if unrecoverable errors have occurred
        or if it outlived its maximum age.
        """
        if self.connection is not None:
            # A new request/cycle begins: allow another health check.
            self.health_check_done = False
            # If autocommit was not restored (e.g. a transaction was not
            # properly closed), don't take chances, drop the connection.
            if not self.get_autocommit():
                self.close()
                return

            # If an exception other than DataError or IntegrityError occurred
            # since the last commit / rollback, check if the connection works.
            if self.errors_occurred:
                if self.is_usable():
                    self.errors_occurred = False
                    self.health_check_done = True
                else:
                    self.close()
                    return

            # Enforce CONN_MAX_AGE via the deadline computed in connect().
            if self.close_at is not None and time.monotonic() >= self.close_at:
                self.close()
                return
742
743 # ##### Miscellaneous #####
744
    @cached_property
    def wrap_database_errors(self) -> DatabaseErrorWrapper:
        """
        Context manager and decorator that re-throws backend-specific database
        exceptions using Plain's common wrappers.
        """
        # cached_property: one wrapper instance is reused for the lifetime of
        # this DatabaseConnection object.
        return DatabaseErrorWrapper(self)
752
    def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
        """Create a cursor without debug logging."""
        return utils.CursorWrapper(cursor, self)
756
757 @contextmanager
758 def temporary_connection(self) -> Generator[utils.CursorWrapper, None, None]:
759 """
760 Context manager that ensures that a connection is established, and
761 if it opened one, closes it to avoid leaving a dangling connection.
762 This is useful for operations outside of the request-response cycle.
763
764 Provide a cursor: with self.temporary_connection() as cursor: ...
765 """
766 must_close = self.connection is None
767 try:
768 with self.cursor() as cursor:
769 yield cursor
770 finally:
771 if must_close:
772 self.close()
773
    def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
        """Return a new instance of the schema editor.

        Extra positional/keyword arguments are forwarded to
        DatabaseSchemaEditor unchanged.
        """
        return DatabaseSchemaEditor(self, *args, **kwargs)
777
    def runshell(self, parameters: list[str]) -> None:
        """Run an interactive psql shell."""
        args, env = _psql_settings_to_cmd_args_env(self.settings_dict, parameters)
        # Overlay the computed PG* variables on the inherited environment;
        # None means "inherit unchanged".
        env = {**os.environ, **env} if env else None
        sigint_handler = signal.getsignal(signal.SIGINT)
        try:
            # Allow SIGINT to pass to psql to abort queries.
            signal.signal(signal.SIGINT, signal.SIG_IGN)
            subprocess.run(args, env=env, check=True)
        finally:
            # Restore the original SIGINT handler.
            signal.signal(signal.SIGINT, sigint_handler)
790
791 def on_commit(self, func: Any, robust: bool = False) -> None:
792 if not callable(func):
793 raise TypeError("on_commit()'s callback must be a callable.")
794 if self.in_atomic_block:
795 # Transaction in progress; save for execution on commit.
796 self.run_on_commit.append((set(self.savepoint_ids), func, robust))
797 else:
798 # No transaction in progress; execute immediately.
799 if robust:
800 try:
801 func()
802 except Exception as e:
803 logger.error(
804 f"Error calling {func.__qualname__} in on_commit() (%s).",
805 e,
806 exc_info=True,
807 )
808 else:
809 func()
810
811 def run_and_clear_commit_hooks(self) -> None:
812 self.validate_no_atomic_block()
813 current_run_on_commit = self.run_on_commit
814 self.run_on_commit = []
815 while current_run_on_commit:
816 _, func, robust = current_run_on_commit.pop(0)
817 if robust:
818 try:
819 func()
820 except Exception as e:
821 logger.error(
822 f"Error calling {func.__qualname__} in on_commit() during "
823 f"transaction (%s).",
824 e,
825 exc_info=True,
826 )
827 else:
828 func()
829
    @contextmanager
    def execute_wrapper(self, wrapper: Any) -> Generator[None, None, None]:
        """
        Return a context manager under which the wrapper is applied to suitable
        database query executions.

        The wrapper is removed again when the context exits, even on error.
        """
        self.execute_wrappers.append(wrapper)
        try:
            yield
        finally:
            self.execute_wrappers.pop()
841
842 # ##### SQL generation methods that require connection state #####
843
    def compose_sql(self, query: str, params: Any) -> str:
        """
        Compose a SQL query with parameters using psycopg's mogrify.

        This requires an active connection because it uses the connection's
        cursor to properly format parameters.
        """
        assert self.connection is not None
        # ClientCursor formats parameters client-side, producing the final
        # SQL string without executing anything.
        return ClientCursor(self.connection).mogrify(
            psycopg_sql.SQL(cast(LiteralString, query)), params
        )
855
    def last_executed_query(
        self,
        cursor: utils.CursorWrapper,
        sql: str,
        params: Any,
    ) -> str | None:
        """
        Return a string of the query last executed by the given cursor, with
        placeholders replaced with actual values.
        """
        try:
            return self.compose_sql(sql, params)
        except errors.DataError:
            # Parameters psycopg can't adapt make composition fail; this is a
            # debugging aid, so return None rather than raise.
            return None
870
871 def unification_cast_sql(self, output_field: Field) -> str:
872 """
873 Given a field instance, return the SQL that casts the result of a union
874 to that type. The resulting string should contain a '%s' placeholder
875 for the expression being cast.
876 """
877 internal_type = output_field.get_internal_type()
878 if internal_type in (
879 "GenericIPAddressField",
880 "TimeField",
881 "UUIDField",
882 ):
883 # PostgreSQL will resolve a union as type 'text' if input types are
884 # 'unknown'.
885 # https://www.postgresql.org/docs/current/typeconv-union-case.html
886 # These fields cannot be implicitly cast back in the default
887 # PostgreSQL configuration so we need to explicitly cast them.
888 # We must also remove components of the type within brackets:
889 # varchar(255) -> varchar.
890 db_type = output_field.db_type()
891 if db_type:
892 return "CAST(%s AS {})".format(db_type.split("(")[0])
893 return "%s"
894
895 # ##### Introspection methods #####
896
897 def table_names(
898 self, cursor: CursorWrapper | None = None, include_views: bool = False
899 ) -> list[str]:
900 """
901 Return a list of names of all tables that exist in the database.
902 Sort the returned table list by Python's default sorting. Do NOT use
903 the database's ORDER BY here to avoid subtle differences in sorting
904 order between databases.
905 """
906
907 def get_names(cursor: CursorWrapper) -> list[str]:
908 return sorted(
909 ti.name
910 for ti in self.get_table_list(cursor)
911 if include_views or ti.type == "t"
912 )
913
914 if cursor is None:
915 with self.cursor() as cursor:
916 return get_names(cursor)
917 return get_names(cursor)
918
    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
        """
        Return an unsorted list of TableInfo named tuples of all tables and
        views that exist in the database.
        """
        # relkind values queried: f=foreign table, m=materialized view,
        # p=partitioned table, r=ordinary table, v=view. The CASE collapses
        # them to 'p' (partition), 'v' (view/matview) or 't' (table).
        cursor.execute(
            """
            SELECT
                c.relname,
                CASE
                    WHEN c.relispartition THEN 'p'
                    WHEN c.relkind IN ('m', 'v') THEN 'v'
                    ELSE 't'
                END,
                obj_description(c.oid, 'pg_class')
            FROM pg_catalog.pg_class c
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
        """
        )
        # Filter out backend-internal tables listed in ignored_tables.
        return [
            TableInfo(*row)
            for row in cursor.fetchall()
            if row[0] not in self.ignored_tables
        ]
946
947 def plain_table_names(
948 self, only_existing: bool = False, include_views: bool = True
949 ) -> list[str]:
950 """
951 Return a list of all table names that have associated Plain models and
952 are in INSTALLED_PACKAGES.
953
954 If only_existing is True, include only the tables in the database.
955 """
956 tables = set()
957 for model in get_migratable_models():
958 tables.add(model.model_options.db_table)
959 tables.update(
960 f.m2m_db_table() for f in model._model_meta.local_many_to_many
961 )
962 tables = list(tables)
963 if only_existing:
964 existing_tables = set(self.table_names(include_views=include_views))
965 tables = [t for t in tables if t in existing_tables]
966 return tables
967
968 def get_sequences(
969 self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
970 ) -> list[dict[str, Any]]:
971 """
972 Return a list of introspected sequences for table_name. Each sequence
973 is a dict: {'table': <table_name>, 'column': <column_name>, 'name': <sequence_name>}.
974 """
975 cursor.execute(
976 """
977 SELECT
978 s.relname AS sequence_name,
979 a.attname AS colname
980 FROM
981 pg_class s
982 JOIN pg_depend d ON d.objid = s.oid
983 AND d.classid = 'pg_class'::regclass
984 AND d.refclassid = 'pg_class'::regclass
985 JOIN pg_attribute a ON d.refobjid = a.attrelid
986 AND d.refobjsubid = a.attnum
987 JOIN pg_class tbl ON tbl.oid = d.refobjid
988 AND tbl.relname = %s
989 AND pg_catalog.pg_table_is_visible(tbl.oid)
990 WHERE
991 s.relkind = 'S';
992 """,
993 [table_name],
994 )
995 return [
996 {"name": row[0], "table": table_name, "column": row[1]}
997 for row in cursor.fetchall()
998 ]
999
1000 def get_constraints(
1001 self, cursor: CursorWrapper, table_name: str
1002 ) -> dict[str, dict[str, Any]]:
1003 """
1004 Retrieve any constraints or keys (unique, pk, fk, check, index) across
1005 one or more columns. Also retrieve the definition of expression-based
1006 indexes.
1007 """
1008 constraints: dict[str, dict[str, Any]] = {}
1009 # Loop over the key table, collecting things as constraints. The column
1010 # array must return column names in the same order in which they were
1011 # created.
1012 cursor.execute(
1013 """
1014 SELECT
1015 c.conname,
1016 array(
1017 SELECT attname
1018 FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
1019 JOIN pg_attribute AS ca ON cols.colid = ca.attnum
1020 WHERE ca.attrelid = c.conrelid
1021 ORDER BY cols.arridx
1022 ),
1023 c.contype,
1024 (SELECT fkc.relname || '.' || fka.attname
1025 FROM pg_attribute AS fka
1026 JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
1027 WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
1028 cl.reloptions
1029 FROM pg_constraint AS c
1030 JOIN pg_class AS cl ON c.conrelid = cl.oid
1031 WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
1032 """,
1033 [table_name],
1034 )
1035 for constraint, columns, kind, used_cols, options in cursor.fetchall():
1036 constraints[constraint] = {
1037 "columns": columns,
1038 "primary_key": kind == "p",
1039 "unique": kind in ["p", "u"],
1040 "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
1041 "check": kind == "c",
1042 "index": False,
1043 "definition": None,
1044 "options": options,
1045 }
1046 # Now get indexes
1047 cursor.execute(
1048 """
1049 SELECT
1050 indexname,
1051 array_agg(attname ORDER BY arridx),
1052 indisunique,
1053 indisprimary,
1054 array_agg(ordering ORDER BY arridx),
1055 amname,
1056 exprdef,
1057 s2.attoptions
1058 FROM (
1059 SELECT
1060 c2.relname as indexname, idx.*, attr.attname, am.amname,
1061 CASE
1062 WHEN idx.indexprs IS NOT NULL THEN
1063 pg_get_indexdef(idx.indexrelid)
1064 END AS exprdef,
1065 CASE am.amname
1066 WHEN %s THEN
1067 CASE (option & 1)
1068 WHEN 1 THEN 'DESC' ELSE 'ASC'
1069 END
1070 END as ordering,
1071 c2.reloptions as attoptions
1072 FROM (
1073 SELECT *
1074 FROM
1075 pg_index i,
1076 unnest(i.indkey, i.indoption)
1077 WITH ORDINALITY koi(key, option, arridx)
1078 ) idx
1079 LEFT JOIN pg_class c ON idx.indrelid = c.oid
1080 LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
1081 LEFT JOIN pg_am am ON c2.relam = am.oid
1082 LEFT JOIN
1083 pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
1084 WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
1085 ) s2
1086 GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
1087 """,
1088 [self.index_default_access_method, table_name],
1089 )
1090 for (
1091 index,
1092 columns,
1093 unique,
1094 primary,
1095 orders,
1096 type_,
1097 definition,
1098 options,
1099 ) in cursor.fetchall():
1100 if index not in constraints:
1101 basic_index = (
1102 type_ == self.index_default_access_method and options is None
1103 )
1104 constraints[index] = {
1105 "columns": columns if columns != [None] else [],
1106 "orders": orders if orders != [None] else [],
1107 "primary_key": primary,
1108 "unique": unique,
1109 "foreign_key": None,
1110 "check": False,
1111 "index": True,
1112 "type": Index.suffix if basic_index else type_,
1113 "definition": definition,
1114 "options": options,
1115 }
1116 return constraints
1117
1118 # ##### Test database creation methods (merged from DatabaseCreation) #####
1119
1120 def _log(self, msg: str) -> None:
1121 sys.stderr.write(msg + os.linesep)
1122
    def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
        """
        Create a test database, prompting the user for confirmation if the
        database already exists. Return the name of the test database created.

        If prefix is provided, it will be prepended to the database name
        to isolate it from other test databases.
        """
        # Function-scope import; presumably avoids an import cycle with the
        # CLI package at module load time — TODO confirm.
        from plain.models.cli.migrations import apply

        test_database_name = self._get_test_db_name(prefix)

        if verbosity >= 1:
            self._log(f"Creating test database '{test_database_name}'...")

        # autoclobber=True: any pre-existing database with this name is
        # dropped and recreated without prompting.
        self._create_test_db(
            test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
        )

        # Close the current connection and repoint both the global settings
        # and this backend's settings at the new test database, so every
        # subsequent connection (including the migrations below) targets it.
        self.close()
        settings.POSTGRES_DATABASE = test_database_name
        self.settings_dict["DATABASE"] = test_database_name

        # Apply all migrations to the freshly created (empty) database.
        apply.callback(
            package_label=None,
            migration_name=None,
            fake=False,
            plan=False,
            check_unapplied=False,
            backup=False,
            no_input=True,
            atomic_batch=False,  # No need for atomic batch when creating test database
            quiet=verbosity < 2,  # Show migration output when verbosity is 2+
        )

        # Ensure a connection for the side effect of initializing the test database.
        self.ensure_connection()

        return test_database_name
1162
1163 def _get_test_db_name(self, prefix: str = "") -> str:
1164 """
1165 Internal implementation - return the name of the test DB that will be
1166 created. Only useful when called from create_test_db() and
1167 _create_test_db() and when no external munging is done with the 'DATABASE'
1168 settings.
1169
1170 If prefix is provided, it will be prepended to the database name.
1171 """
1172 # Determine the base name: explicit TEST.DATABASE overrides base DATABASE.
1173 base_name = (
1174 self.settings_dict["TEST"]["DATABASE"] or self.settings_dict["DATABASE"]
1175 )
1176 if prefix:
1177 return f"{prefix}_{base_name}"
1178 if self.settings_dict["TEST"]["DATABASE"]:
1179 return self.settings_dict["TEST"]["DATABASE"]
1180 name = self.settings_dict["DATABASE"]
1181 if name is None:
1182 raise ValueError("POSTGRES_DATABASE must be set")
1183 return TEST_DATABASE_PREFIX + name
1184
1185 def _get_database_create_suffix(
1186 self, encoding: str | None = None, template: str | None = None
1187 ) -> str:
1188 """Return PostgreSQL-specific CREATE DATABASE suffix."""
1189 suffix = ""
1190 if encoding:
1191 suffix += f" ENCODING '{encoding}'"
1192 if template:
1193 suffix += f" TEMPLATE {quote_name(template)}"
1194 return suffix and "WITH" + suffix
1195
1196 def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
1197 try:
1198 cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
1199 except Exception as e:
1200 cause = e.__cause__
1201 if cause and not isinstance(cause, errors.DuplicateDatabase):
1202 # All errors except "database already exists" cancel tests.
1203 self._log(f"Got an error creating the test database: {e}")
1204 sys.exit(2)
1205 else:
1206 raise
1207
    def _create_test_db(
        self, *, test_database_name: str, verbosity: int, autoclobber: bool
    ) -> str:
        """
        Internal implementation - create the test db tables.

        Tries CREATE DATABASE once; if that fails (which, given that
        _execute_create_test_db exits for other errors, effectively means
        the database already exists), drops and recreates it — prompting
        the user first unless autoclobber is True. Exits the process with
        status 2 on an unrecoverable error or 1 if the user cancels.
        Returns test_database_name.
        """
        test_db_params = {
            "dbname": quote_name(test_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        # Create the test database and connect to it.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params)
            except Exception as e:
                self._log(f"Got an error creating the test database: {e}")
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        f"database '{test_database_name}', or 'no' to cancel: "
                    )
                # NOTE: `confirm` is only bound when autoclobber is False;
                # the short-circuit below never reads it otherwise.
                if autoclobber or confirm == "yes":
                    try:
                        if verbosity >= 1:
                            self._log(
                                f"Destroying old test database '{test_database_name}'..."
                            )
                        cursor.execute(
                            "DROP DATABASE {dbname}".format(**test_db_params)
                        )
                        self._execute_create_test_db(cursor, test_db_params)
                    except Exception as e:
                        self._log(f"Got an error recreating the test database: {e}")
                        sys.exit(2)
                else:
                    self._log("Tests cancelled.")
                    sys.exit(1)

        return test_database_name
1247
1248 def destroy_test_db(
1249 self, old_database_name: str | None = None, verbosity: int = 1
1250 ) -> None:
1251 """
1252 Destroy a test database, prompting the user for confirmation if the
1253 database already exists.
1254 """
1255 self.close()
1256
1257 test_database_name = self.settings_dict["DATABASE"]
1258 if test_database_name is None:
1259 raise ValueError("Test POSTGRES_DATABASE must be set")
1260
1261 if verbosity >= 1:
1262 self._log(f"Destroying test database '{test_database_name}'...")
1263 self._destroy_test_db(test_database_name, verbosity)
1264
1265 # Restore the original database name
1266 if old_database_name is not None:
1267 settings.POSTGRES_DATABASE = old_database_name
1268 self.settings_dict["DATABASE"] = old_database_name
1269
1270 def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
1271 """
1272 Internal implementation - remove the test db tables.
1273 """
1274 # Remove the test database to clean up after
1275 # ourselves. Connect to the previous database (not the test database)
1276 # to do so, because it's not allowed to delete a database while being
1277 # connected to it.
1278 with self._nodb_cursor() as cursor:
1279 cursor.execute(f"DROP DATABASE {quote_name(test_database_name)}")
1280
1281 def sql_table_creation_suffix(self) -> str:
1282 """
1283 SQL to append to the end of the test table creation statements.
1284 """
1285 test_settings = self.settings_dict["TEST"]
1286 return self._get_database_create_suffix(
1287 encoding=test_settings.get("CHARSET"),
1288 template=test_settings.get("TEMPLATE"),
1289 )
1290
1291
class CursorMixin:
    """
    A subclass of psycopg cursor implementing callproc.
    """

    def callproc(
        self, name: str | psycopg_sql.Identifier, args: list[Any] | None = None
    ) -> list[Any] | None:
        """
        Emulate DB-API callproc() by executing ``SELECT * FROM name(args...)``.

        Returns args unchanged, per the DB-API convention.
        """
        if not isinstance(name, psycopg_sql.Identifier):
            name = psycopg_sql.Identifier(name)

        # Assemble: SELECT * FROM <name>(<lit>,<lit>,...)
        parts: list[psycopg_sql.Composable] = [
            psycopg_sql.SQL("SELECT * FROM "),
            name,
            psycopg_sql.SQL("("),
        ]
        for position, value in enumerate(args or []):
            if position:
                parts.append(psycopg_sql.SQL(","))
            parts.append(psycopg_sql.Literal(value))
        parts.append(psycopg_sql.SQL(")"))

        self.execute(psycopg_sql.Composed(parts))  # type: ignore[attr-defined]
        return args
1318
1319
class ServerBindingCursor(CursorMixin, Database.Cursor):
    # psycopg cursor using server-side parameter binding, extended with
    # callproc() from CursorMixin.
    pass
1322
1323
class Cursor(CursorMixin, Database.ClientCursor):
    # psycopg cursor using client-side parameter binding, extended with
    # callproc() from CursorMixin.
    pass
1326
1327
class CursorDebugWrapper(BaseCursorDebugWrapper):
    # Extends the base debug wrapper so COPY statements are also routed
    # through the SQL debug logger, like regular execute() calls.
    def copy(self, statement: Any) -> Any:
        with self.debug_sql(statement):
            return self.cursor.copy(statement)