1from __future__ import annotations
2
3import _thread
4import datetime
5import logging
6import os
7import signal
8import subprocess
9import sys
10import threading
11import time
12import warnings
13import zoneinfo
14from collections import deque
15from collections.abc import Generator, Sequence
16from contextlib import contextmanager
17from functools import cached_property, lru_cache
18from typing import TYPE_CHECKING, Any, LiteralString, NamedTuple, cast
19
20import psycopg as Database
21from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors
22from psycopg import sql as psycopg_sql
23from psycopg.abc import Buffer, PyFormat
24from psycopg.postgres import types as pg_types
25from psycopg.pq import Format
26from psycopg.types.datetime import TimestamptzLoader
27from psycopg.types.range import BaseRangeDumper, Range, RangeDumper
28from psycopg.types.string import TextLoader
29
30from plain.exceptions import ImproperlyConfigured
31from plain.models.db import (
32 DatabaseError,
33 DatabaseErrorWrapper,
34 NotSupportedError,
35 db_connection,
36)
37from plain.models.indexes import Index
38from plain.models.postgres import utils
39from plain.models.postgres.schema import DatabaseSchemaEditor
40from plain.models.postgres.sql import MAX_NAME_LENGTH, quote_name
41from plain.models.postgres.utils import CursorDebugWrapper as BaseCursorDebugWrapper
42from plain.models.postgres.utils import CursorWrapper, debug_transaction
43from plain.models.transaction import TransactionManagementError
44from plain.runtime import settings
45
46if TYPE_CHECKING:
47 from psycopg import Connection as PsycopgConnection
48
49 from plain.models.connections import DatabaseConfig
50 from plain.models.fields import Field
51
# Process-wide flag: the server version check runs only once per process
# (see DatabaseWrapper.init_connection_state()).
RAN_DB_VERSION_CHECK = False

logger = logging.getLogger("plain.models.postgres")

# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = "test_"
59
60
def get_migratable_models() -> Generator[Any, None, None]:
    """Return all models that should be included in migrations.

    Registries are imported inside the function (not at module import
    time) to avoid circular imports with the backend module.
    """
    from plain.models import models_registry
    from plain.packages import packages_registry

    return (
        registered_model
        for config in packages_registry.get_package_configs()
        for registered_model in models_registry.get_models(
            package_label=config.package_label
        )
    )
73
74
class TableInfo(NamedTuple):
    """One table entry as returned by ``DatabaseWrapper.get_table_list()``.

    ``name`` is the table name, ``type`` the relation kind, and ``comment``
    the table comment (or ``None`` when the table has none).
    """

    name: str
    type: str
    comment: str | None
81
82
# Type OIDs
# Resolved once at import time from psycopg's adapter/type registries; used
# to pick loaders/dumpers for timestamptz and timestamp-range values below.
TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
TSRANGE_OID = pg_types["tsrange"].oid
TSTZRANGE_OID = pg_types["tstzrange"].oid
87
88
class BaseTzLoader(TimestamptzLoader):
    """
    Load a PostgreSQL timestamptz using a specific timezone.
    The timezone can be None too, in which case it will be chopped.
    """

    # Overridden per-connection by register_tzloader()'s subclass.
    timezone: datetime.tzinfo | None = None

    def load(self, data: Buffer) -> datetime.datetime:
        parsed = super().load(data)
        # Force (or strip, when None) the configured zone onto the value.
        return parsed.replace(tzinfo=self.timezone)
100
101
def register_tzloader(tz: datetime.tzinfo | None, context: Any) -> None:
    """Register a timestamptz loader bound to *tz* on *context*'s adapters."""

    class SpecificTzLoader(BaseTzLoader):
        timezone = tz

    context.adapters.register_loader("timestamptz", SpecificTzLoader)
107
108
class PlainRangeDumper(RangeDumper):
    """A Range dumper customized for Plain."""

    def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper:
        upgraded = super().upgrade(obj, format)
        # Promote plain timestamp ranges to timestamptz ranges so aware
        # datetimes round-trip correctly.
        if upgraded is not self and upgraded.oid == TSRANGE_OID:
            upgraded.oid = TSTZRANGE_OID
        return upgraded
117
118
@lru_cache
def get_adapters_template(timezone: datetime.tzinfo | None) -> adapt.AdaptersMap:
    """Build (and memoize, per timezone) the adapters map for new connections."""
    template = adapt.AdaptersMap(adapters)
    # No-op JSON loader to avoid psycopg3 round trips.
    template.register_loader("jsonb", TextLoader)
    # Treat network types (inet/cidr) as plain text.
    for network_type in ("inet", "cidr"):
        template.register_loader(network_type, TextLoader)
    template.register_dumper(Range, PlainRangeDumper)
    register_tzloader(timezone, template)
    return template
130
131
132def _psql_settings_to_cmd_args_env(
133 settings_dict: DatabaseConfig, parameters: list[str]
134) -> tuple[list[str], dict[str, str] | None]:
135 """Build psql command-line arguments from database settings."""
136 args = ["psql"]
137 options = settings_dict.get("OPTIONS", {})
138
139 if user := settings_dict.get("USER"):
140 args += ["-U", user]
141 if host := settings_dict.get("HOST"):
142 args += ["-h", host]
143 if port := settings_dict.get("PORT"):
144 args += ["-p", str(port)]
145 args.extend(parameters)
146 args += [settings_dict.get("DATABASE") or "postgres"]
147
148 env: dict[str, str] = {}
149 if password := settings_dict.get("PASSWORD"):
150 env["PGPASSWORD"] = str(password)
151
152 # Map OPTIONS keys to their corresponding environment variables.
153 option_env_vars = {
154 "passfile": "PGPASSFILE",
155 "sslmode": "PGSSLMODE",
156 "sslrootcert": "PGSSLROOTCERT",
157 "sslcert": "PGSSLCERT",
158 "sslkey": "PGSSLKEY",
159 }
160 for option_key, env_var in option_env_vars.items():
161 if value := options.get(option_key):
162 env[env_var] = str(value)
163
164 return args, (env or None)
165
166
class DatabaseWrapper:
    """
    PostgreSQL database connection wrapper.

    This is the only database backend supported by Plain.
    """

    # Maximum number of entries kept in queries_log (deque maxlen).
    queries_limit: int = 9000
    # Name of the client shell binary (see runshell()).
    executable_name: str = "psql"

    index_default_access_method = "btree"
    # NOTE(review): unused in this chunk — presumably tables introspection
    # should skip; confirm against callers.
    ignored_tables: list[str] = []

    # PostgreSQL backend-specific attributes.
    # Class-level counter used by chunked_cursor() for unique cursor names.
    _named_cursor_idx = 0
182
    def __init__(self, settings_dict: DatabaseConfig):
        """Initialize connection, transaction, and thread-safety state.

        No database connection is opened here; see connect() /
        ensure_connection(). `settings_dict` carries keys such as DATABASE,
        USER, HOST, PORT, PASSWORD, OPTIONS, TIME_ZONE, CONN_MAX_AGE and
        CONN_HEALTH_CHECKS.
        """
        # Connection related attributes.
        # The underlying database connection (from the database library, not a wrapper).
        self.connection: PsycopgConnection[Any] | None = None
        # `settings_dict` should be a dictionary containing keys such as
        # DATABASE, USER, etc. It's called `settings_dict` instead of `settings`
        # to disambiguate it from Plain settings modules.
        self.settings_dict: DatabaseConfig = settings_dict
        # Query logging in debug mode or when explicitly enabled.
        self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
        self.force_debug_cursor: bool = False

        # Transaction related attributes.
        # Tracks if the connection is in autocommit mode. Per PEP 249, by
        # default, it isn't.
        self.autocommit: bool = False
        # Tracks if the connection is in a transaction managed by 'atomic'.
        self.in_atomic_block: bool = False
        # Increment to generate unique savepoint ids.
        self.savepoint_state: int = 0
        # List of savepoints created by 'atomic'.
        self.savepoint_ids: list[str] = []
        # Stack of active 'atomic' blocks.
        self.atomic_blocks: list[Any] = []
        # Tracks if the transaction should be rolled back to the next
        # available savepoint because of an exception in an inner block.
        self.needs_rollback: bool = False
        self.rollback_exc: Exception | None = None

        # Connection termination related attributes.
        self.close_at: float | None = None
        self.closed_in_transaction: bool = False
        self.errors_occurred: bool = False
        self.health_check_enabled: bool = False
        self.health_check_done: bool = False

        # Thread-safety related attributes.
        self._thread_sharing_lock: threading.Lock = threading.Lock()
        self._thread_sharing_count: int = 0
        self._thread_ident: int = _thread.get_ident()

        # A list of no-argument functions to run when the transaction commits.
        # Each entry is an (sids, func, robust) tuple, where sids is a set of
        # the active savepoint IDs when this function was registered and robust
        # specifies whether it's allowed for the function to fail.
        self.run_on_commit: list[tuple[set[str], Any, bool]] = []

        # Should we run the on-commit hooks the next time set_autocommit(True)
        # is called?
        self.run_commit_hooks_on_set_autocommit_on: bool = False

        # A stack of wrappers to be invoked around execute()/executemany()
        # calls. Each entry is a function taking five arguments: execute, sql,
        # params, many, and context. It's the function's responsibility to
        # call execute(sql, params, many, context).
        self.execute_wrappers: list[Any] = []
239
240 def __repr__(self) -> str:
241 return f"<{self.__class__.__qualname__} vendor='postgresql'>"
242
243 @cached_property
244 def timezone(self) -> datetime.tzinfo:
245 """
246 Return a tzinfo of the database connection time zone.
247
248 When a datetime is read from the database, it is returned in this time
249 zone. Since PostgreSQL supports time zones, it doesn't matter which
250 time zone Plain uses, as long as aware datetimes are used everywhere.
251 Other users connecting to the database can choose their own time zone.
252 """
253 if self.settings_dict["TIME_ZONE"] is None:
254 return datetime.UTC
255 return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
256
257 @cached_property
258 def timezone_name(self) -> str:
259 """
260 Name of the time zone of the database connection.
261 """
262 if self.settings_dict["TIME_ZONE"] is None:
263 return "UTC"
264 return self.settings_dict["TIME_ZONE"]
265
266 @property
267 def queries_logged(self) -> bool:
268 return self.force_debug_cursor or settings.DEBUG
269
270 @property
271 def queries(self) -> list[dict[str, Any]]:
272 if len(self.queries_log) == self.queries_log.maxlen:
273 warnings.warn(
274 f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
275 "will be returned."
276 )
277 return list(self.queries_log)
278
279 def check_database_version_supported(self) -> None:
280 """
281 Raise an error if the database version isn't supported by this
282 version of Plain. PostgreSQL 12+ is required.
283 """
284 major, minor = divmod(self.pg_version, 10000)
285 if major < 12:
286 raise NotSupportedError(
287 f"PostgreSQL 12 or later is required (found {major}.{minor})."
288 )
289
290 # ##### Connection and cursor methods #####
291
292 def get_connection_params(self) -> dict[str, Any]:
293 """Return a dict of parameters suitable for get_new_connection."""
294 settings_dict = self.settings_dict
295 options = settings_dict.get("OPTIONS", {})
296 db_name = settings_dict.get("DATABASE")
297 if db_name == "":
298 raise ImproperlyConfigured(
299 "PostgreSQL database is not configured. "
300 "Set DATABASE_URL or the POSTGRES_DATABASE setting."
301 )
302 if len(db_name or "") > MAX_NAME_LENGTH:
303 raise ImproperlyConfigured(
304 "The database name '%s' (%d characters) is longer than " # noqa: UP031
305 "PostgreSQL's limit of %d characters. Supply a shorter "
306 "POSTGRES_DATABASE setting."
307 % (
308 db_name,
309 len(db_name or ""),
310 MAX_NAME_LENGTH,
311 )
312 )
313 if db_name is None:
314 # None is used to connect to the default 'postgres' db.
315 db_name = "postgres"
316 conn_params: dict[str, Any] = {
317 "dbname": db_name,
318 **options,
319 }
320
321 conn_params.pop("assume_role", None)
322 conn_params.pop("isolation_level", None)
323 conn_params.pop("server_side_binding", None)
324 if settings_dict["USER"]:
325 conn_params["user"] = settings_dict["USER"]
326 if settings_dict["PASSWORD"]:
327 conn_params["password"] = settings_dict["PASSWORD"]
328 if settings_dict["HOST"]:
329 conn_params["host"] = settings_dict["HOST"]
330 if settings_dict["PORT"]:
331 conn_params["port"] = settings_dict["PORT"]
332 conn_params["context"] = get_adapters_template(self.timezone)
333 # Disable prepared statements by default to keep connection poolers
334 # working. Can be reenabled via OPTIONS in the settings dict.
335 conn_params["prepare_threshold"] = conn_params.pop("prepare_threshold", None)
336 return conn_params
337
    def get_new_connection(self, conn_params: dict[str, Any]) -> PsycopgConnection[Any]:
        """Open a connection to the database.

        Raises ImproperlyConfigured when OPTIONS["isolation_level"] is not a
        valid psycopg.IsolationLevel value.
        """
        # self.isolation_level must be set:
        # - after connecting to the database in order to obtain the database's
        #   default when no value is explicitly specified in options.
        # - before calling _set_autocommit() because if autocommit is on, that
        #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
        options = self.settings_dict.get("OPTIONS", {})
        set_isolation_level = False
        try:
            isolation_level_value = options["isolation_level"]
        except KeyError:
            # No explicit level: remember READ COMMITTED but keep the server
            # default (set_isolation_level stays False on purpose).
            self.isolation_level = IsolationLevel.READ_COMMITTED
        else:
            # Set the isolation level to the value from OPTIONS.
            try:
                self.isolation_level = IsolationLevel(isolation_level_value)
                set_isolation_level = True
            except ValueError:
                raise ImproperlyConfigured(
                    f"Invalid transaction isolation level {isolation_level_value} "
                    f"specified. Use one of the psycopg.IsolationLevel values."
                )
        connection = Database.connect(**conn_params)
        if set_isolation_level:
            connection.isolation_level = self.isolation_level
        # Use server-side binding cursor if requested, otherwise standard cursor
        connection.cursor_factory = (
            ServerBindingCursor
            if options.get("server_side_binding") is True
            else Cursor
        )
        return connection
371
372 def ensure_timezone(self) -> bool:
373 """
374 Ensure the connection's timezone is set to `self.timezone_name` and
375 return whether it changed or not.
376 """
377 if self.connection is None:
378 return False
379 conn_timezone_name = self.connection.info.parameter_status("TimeZone")
380 timezone_name = self.timezone_name
381 if timezone_name and conn_timezone_name != timezone_name:
382 with self.connection.cursor() as cursor:
383 cursor.execute(
384 "SELECT set_config('TimeZone', %s, false)", [timezone_name]
385 )
386 return True
387 return False
388
389 def ensure_role(self) -> bool:
390 if self.connection is None:
391 return False
392 if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
393 with self.connection.cursor() as cursor:
394 sql_str = self.compose_sql("SET ROLE %s", [new_role])
395 cursor.execute(sql_str) # type: ignore[arg-type]
396 return True
397 return False
398
    def init_connection_state(self) -> None:
        """Initialize the database connection settings.

        Runs the once-per-process version check, then per-connection
        setup (time zone, optional role).
        """
        global RAN_DB_VERSION_CHECK
        # Only validate the server version once per process.
        if not RAN_DB_VERSION_CHECK:
            self.check_database_version_supported()
            RAN_DB_VERSION_CHECK = True

        self.ensure_timezone()
        # Set the role on the connection. This is useful if the credential used
        # to login is not the same as the role that owns database resources. As
        # can be the case when using temporary or ephemeral credentials.
        self.ensure_role()
411
    def create_cursor(self, name: str | None = None) -> Any:
        """Create a cursor. Assume that a connection is established.

        A non-None `name` produces a named (server-side) cursor.
        """
        assert self.connection is not None
        if name:
            # In autocommit mode, the cursor will be used outside of a
            # transaction, hence use a holdable cursor.
            cursor = self.connection.cursor(
                name, scrollable=False, withhold=self.connection.autocommit
            )
        else:
            cursor = self.connection.cursor()

        # Register the cursor timezone only if the connection disagrees, to avoid copying the adapter map.
        tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
        if self.timezone != tzloader.timezone:  # type: ignore[union-attr]
            register_tzloader(self.timezone, cursor)
        return cursor
429
430 def chunked_cursor(self) -> utils.CursorWrapper:
431 """
432 Return a server-side cursor that avoids caching results in memory.
433 """
434 self._named_cursor_idx += 1
435 # Get the current async task
436 # Note that right now this is behind @async_unsafe, so this is
437 # unreachable, but in future we'll start loosening this restriction.
438 # For now, it's here so that every use of "threading" is
439 # also async-compatible.
440 task_ident = "sync"
441 # Use that and the thread ident to get a unique name
442 return self._cursor(
443 name="_plain_curs_%d_%s_%d" # noqa: UP031
444 % (
445 # Avoid reusing name in other threads / tasks
446 threading.current_thread().ident,
447 task_ident,
448 self._named_cursor_idx,
449 )
450 )
451
    def _set_autocommit(self, autocommit: bool) -> None:
        """Backend-specific implementation to enable or disable autocommit.

        wrap_database_errors converts psycopg exceptions into Plain's
        common database error hierarchy.
        """
        assert self.connection is not None
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit
457
458 def check_constraints(self, table_names: list[str] | None = None) -> None:
459 """
460 Check constraints by setting them to immediate. Return them to deferred
461 afterward.
462 """
463 with self.cursor() as cursor:
464 cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
465 cursor.execute("SET CONSTRAINTS ALL DEFERRED")
466
467 def is_usable(self) -> bool:
468 """
469 Test if the database connection is usable.
470
471 This method may assume that self.connection is not None.
472
473 Actual implementations should take care not to raise exceptions
474 as that may prevent Plain from recycling unusable connections.
475 """
476 assert self.connection is not None
477 try:
478 # Use a psycopg cursor directly, bypassing Plain's utilities.
479 with self.connection.cursor() as cursor:
480 cursor.execute("SELECT 1")
481 except Database.Error:
482 return False
483 else:
484 return True
485
    @contextmanager
    def _nodb_cursor(self) -> Generator[utils.CursorWrapper, None, None]:
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.
        """
        cursor = None
        try:
            # DATABASE=None connects to the default 'postgres' database.
            conn = self.__class__({**self.settings_dict, "DATABASE": None})
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
        except (Database.DatabaseError, DatabaseError):
            # If a cursor was obtained, the failure came from the caller's
            # block, not from connecting -- re-raise it unchanged.
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Plain will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Plain was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            # Fall back to the configured application database.
            conn = self.__class__(
                {
                    **self.settings_dict,
                    "DATABASE": db_connection.settings_dict["DATABASE"],
                },
            )
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
525
    @cached_property
    def pg_version(self) -> int:
        """The server version as an int, e.g. 140002 for 14.2 (cached)."""
        with self.temporary_connection():
            assert self.connection is not None
            return self.connection.info.server_version
531
532 def make_debug_cursor(self, cursor: Any) -> CursorDebugWrapper:
533 return CursorDebugWrapper(cursor, self)
534
535 # ##### Connection lifecycle #####
536
    def connect(self) -> None:
        """Connect to the database. Assume that the connection is closed."""
        # In case the previous connection was closed while in an atomic block
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.atomic_blocks = []
        self.needs_rollback = False
        # Reset parameters defining when to close/health-check the connection.
        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
        max_age = self.settings_dict["CONN_MAX_AGE"]
        # None means "keep the connection open indefinitely".
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # New connections are healthy.
        self.health_check_done = True
        # Establish the connection
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        # Every connection starts in autocommit; atomic() turns it off.
        self.set_autocommit(True)
        self.init_connection_state()

        # Hooks registered against a previous connection don't carry over.
        self.run_on_commit = []
559
560 def ensure_connection(self) -> None:
561 """Guarantee that a connection to the database is established."""
562 if self.connection is None:
563 with self.wrap_database_errors:
564 self.connect()
565
566 # ##### PEP-249 connection method wrappers #####
567
568 def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
569 """
570 Validate the connection is usable and perform database cursor wrapping.
571 """
572 self.validate_thread_sharing()
573 if self.queries_logged:
574 wrapped_cursor = self.make_debug_cursor(cursor)
575 else:
576 wrapped_cursor = self.make_cursor(cursor)
577 return wrapped_cursor
578
579 def _cursor(self, name: str | None = None) -> utils.CursorWrapper:
580 self.close_if_health_check_failed()
581 self.ensure_connection()
582 with self.wrap_database_errors:
583 return self._prepare_cursor(self.create_cursor(name))
584
585 def _commit(self) -> None:
586 if self.connection is not None:
587 with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
588 return self.connection.commit()
589
590 def _rollback(self) -> None:
591 if self.connection is not None:
592 with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
593 return self.connection.rollback()
594
595 def _close(self) -> None:
596 if self.connection is not None:
597 with self.wrap_database_errors:
598 return self.connection.close()
599
600 # ##### Generic wrappers for PEP-249 connection methods #####
601
    def cursor(self) -> utils.CursorWrapper:
        """Create a cursor, opening a connection if necessary.

        Public entry point; see _cursor() for the wrapping details.
        """
        return self._cursor()
605
606 def commit(self) -> None:
607 """Commit a transaction and reset the dirty flag."""
608 self.validate_thread_sharing()
609 self.validate_no_atomic_block()
610 self._commit()
611 # A successful commit means that the database connection works.
612 self.errors_occurred = False
613 self.run_commit_hooks_on_set_autocommit_on = True
614
615 def rollback(self) -> None:
616 """Roll back a transaction and reset the dirty flag."""
617 self.validate_thread_sharing()
618 self.validate_no_atomic_block()
619 self._rollback()
620 # A successful rollback means that the database connection works.
621 self.errors_occurred = False
622 self.needs_rollback = False
623 self.run_on_commit = []
624
    def close(self) -> None:
        """Close the connection to the database."""
        self.validate_thread_sharing()
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            # Even if _close() raised, record the resulting state: inside an
            # atomic block the wrapper must remember the broken transaction;
            # otherwise the connection slot is simply freed.
            if self.in_atomic_block:
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None
643
644 # ##### Savepoint management #####
645
646 def _savepoint(self, sid: str) -> None:
647 with self.cursor() as cursor:
648 cursor.execute(f"SAVEPOINT {quote_name(sid)}")
649
650 def _savepoint_rollback(self, sid: str) -> None:
651 with self.cursor() as cursor:
652 cursor.execute(f"ROLLBACK TO SAVEPOINT {quote_name(sid)}")
653
654 def _savepoint_commit(self, sid: str) -> None:
655 with self.cursor() as cursor:
656 cursor.execute(f"RELEASE SAVEPOINT {quote_name(sid)}")
657
658 # ##### Generic savepoint management methods #####
659
660 def savepoint(self) -> str | None:
661 """
662 Create a savepoint inside the current transaction. Return an
663 identifier for the savepoint that will be used for the subsequent
664 rollback or commit. Return None if in autocommit mode (no transaction).
665 """
666 if self.get_autocommit():
667 return None
668
669 thread_ident = _thread.get_ident()
670 tid = str(thread_ident).replace("-", "")
671
672 self.savepoint_state += 1
673 sid = "s%s_x%d" % (tid, self.savepoint_state) # noqa: UP031
674
675 self.validate_thread_sharing()
676 self._savepoint(sid)
677
678 return sid
679
680 def savepoint_rollback(self, sid: str) -> None:
681 """
682 Roll back to a savepoint. Do nothing if in autocommit mode.
683 """
684 if self.get_autocommit():
685 return
686
687 self.validate_thread_sharing()
688 self._savepoint_rollback(sid)
689
690 # Remove any callbacks registered while this savepoint was active.
691 self.run_on_commit = [
692 (sids, func, robust)
693 for (sids, func, robust) in self.run_on_commit
694 if sid not in sids
695 ]
696
697 def savepoint_commit(self, sid: str) -> None:
698 """
699 Release a savepoint. Do nothing if in autocommit mode.
700 """
701 if self.get_autocommit():
702 return
703
704 self.validate_thread_sharing()
705 self._savepoint_commit(sid)
706
707 def clean_savepoints(self) -> None:
708 """
709 Reset the counter used to generate unique savepoint ids in this thread.
710 """
711 self.savepoint_state = 0
712
713 # ##### Generic transaction management methods #####
714
715 def get_autocommit(self) -> bool:
716 """Get the autocommit state."""
717 self.ensure_connection()
718 return self.autocommit
719
    def set_autocommit(self, autocommit: bool) -> None:
        """
        Enable or disable autocommit.

        Used internally by atomic() to manage transactions. Don't call this
        directly — use atomic() instead.
        """
        self.validate_no_atomic_block()
        self.close_if_health_check_failed()
        self.ensure_connection()

        if autocommit:
            self._set_autocommit(autocommit)
        else:
            # Turning autocommit off implicitly starts a transaction; record
            # it for query logging purposes.
            with debug_transaction(self, "BEGIN"):
                self._set_autocommit(autocommit)
        self.autocommit = autocommit

        # Restoring autocommit ends the managed section: flush any hooks
        # deferred by commit().
        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False
741
    def get_rollback(self) -> bool:
        """Get the "needs rollback" flag -- for *advanced use* only.

        Raises TransactionManagementError outside an 'atomic' block.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        return self.needs_rollback
749
    def set_rollback(self, rollback: bool) -> None:
        """
        Set or unset the "needs rollback" flag -- for *advanced use* only.

        Raises TransactionManagementError outside an 'atomic' block.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        self.needs_rollback = rollback
759
    def validate_no_atomic_block(self) -> None:
        """Raise TransactionManagementError if an atomic block is active."""
        if self.in_atomic_block:
            raise TransactionManagementError(
                "This is forbidden when an 'atomic' block is active."
            )
766
    def validate_no_broken_transaction(self) -> None:
        """Refuse to run queries after an error inside the current atomic block.

        Chains the original failure (rollback_exc) onto the raised error.
        """
        if self.needs_rollback:
            raise TransactionManagementError(
                "An error occurred in the current transaction. You can't "
                "execute queries until the end of the 'atomic' block."
            ) from self.rollback_exc
773
774 # ##### Connection termination handling #####
775
776 def close_if_health_check_failed(self) -> None:
777 """Close existing connection if it fails a health check."""
778 if (
779 self.connection is None
780 or not self.health_check_enabled
781 or self.health_check_done
782 ):
783 return
784
785 if not self.is_usable():
786 self.close()
787 self.health_check_done = True
788
    def close_if_unusable_or_obsolete(self) -> None:
        """
        Close the current connection if unrecoverable errors have occurred
        or if it outlived its maximum age.
        """
        if self.connection is not None:
            # A new request/connection cycle: require a fresh health check.
            self.health_check_done = False
            # If autocommit was not restored (e.g. a transaction was not
            # properly closed), don't take chances, drop the connection.
            if not self.get_autocommit():
                self.close()
                return

            # If an exception other than DataError or IntegrityError occurred
            # since the last commit / rollback, check if the connection works.
            if self.errors_occurred:
                if self.is_usable():
                    self.errors_occurred = False
                    self.health_check_done = True
                else:
                    self.close()
                    return

            # Finally, retire connections that outlived CONN_MAX_AGE.
            if self.close_at is not None and time.monotonic() >= self.close_at:
                self.close()
                return
815
816 # ##### Thread safety handling #####
817
818 @property
819 def allow_thread_sharing(self) -> bool:
820 with self._thread_sharing_lock:
821 return self._thread_sharing_count > 0
822
823 def validate_thread_sharing(self) -> None:
824 """
825 Validate that the connection isn't accessed by another thread than the
826 one which originally created it. Raise an exception if the validation
827 fails.
828 """
829 if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
830 raise DatabaseError(
831 "DatabaseWrapper objects created in a "
832 "thread can only be used in that same thread. The connection "
833 f"was created in thread id {self._thread_ident} and this is "
834 f"thread id {_thread.get_ident()}."
835 )
836
837 # ##### Miscellaneous #####
838
    @cached_property
    def wrap_database_errors(self) -> DatabaseErrorWrapper:
        """
        Context manager and decorator that re-throws backend-specific database
        exceptions using Plain's common wrappers.

        Cached: the same wrapper instance is reused for this connection.
        """
        return DatabaseErrorWrapper(self)
846
847 def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
848 """Create a cursor without debug logging."""
849 return utils.CursorWrapper(cursor, self)
850
851 @contextmanager
852 def temporary_connection(self) -> Generator[utils.CursorWrapper, None, None]:
853 """
854 Context manager that ensures that a connection is established, and
855 if it opened one, closes it to avoid leaving a dangling connection.
856 This is useful for operations outside of the request-response cycle.
857
858 Provide a cursor: with self.temporary_connection() as cursor: ...
859 """
860 must_close = self.connection is None
861 try:
862 with self.cursor() as cursor:
863 yield cursor
864 finally:
865 if must_close:
866 self.close()
867
    def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
        """Return a new instance of the schema editor.

        Extra arguments are forwarded to DatabaseSchemaEditor unchanged.
        """
        return DatabaseSchemaEditor(self, *args, **kwargs)
871
872 def runshell(self, parameters: list[str]) -> None:
873 """Run an interactive psql shell."""
874 args, env = _psql_settings_to_cmd_args_env(self.settings_dict, parameters)
875 env = {**os.environ, **env} if env else None
876 sigint_handler = signal.getsignal(signal.SIGINT)
877 try:
878 # Allow SIGINT to pass to psql to abort queries.
879 signal.signal(signal.SIGINT, signal.SIG_IGN)
880 subprocess.run(args, env=env, check=True)
881 finally:
882 # Restore the original SIGINT handler.
883 signal.signal(signal.SIGINT, sigint_handler)
884
885 def on_commit(self, func: Any, robust: bool = False) -> None:
886 if not callable(func):
887 raise TypeError("on_commit()'s callback must be a callable.")
888 if self.in_atomic_block:
889 # Transaction in progress; save for execution on commit.
890 self.run_on_commit.append((set(self.savepoint_ids), func, robust))
891 else:
892 # No transaction in progress; execute immediately.
893 if robust:
894 try:
895 func()
896 except Exception as e:
897 logger.error(
898 f"Error calling {func.__qualname__} in on_commit() (%s).",
899 e,
900 exc_info=True,
901 )
902 else:
903 func()
904
905 def run_and_clear_commit_hooks(self) -> None:
906 self.validate_no_atomic_block()
907 current_run_on_commit = self.run_on_commit
908 self.run_on_commit = []
909 while current_run_on_commit:
910 _, func, robust = current_run_on_commit.pop(0)
911 if robust:
912 try:
913 func()
914 except Exception as e:
915 logger.error(
916 f"Error calling {func.__qualname__} in on_commit() during "
917 f"transaction (%s).",
918 e,
919 exc_info=True,
920 )
921 else:
922 func()
923
    @contextmanager
    def execute_wrapper(self, wrapper: Any) -> Generator[None, None, None]:
        """
        Return a context manager under which the wrapper is applied to suitable
        database query executions.

        The wrapper is pushed on entry and always popped on exit, even when
        the body raises, keeping the stack balanced.
        """
        self.execute_wrappers.append(wrapper)
        try:
            yield
        finally:
            self.execute_wrappers.pop()
935
936 # ##### SQL generation methods that require connection state #####
937
938 def compose_sql(self, query: str, params: Any) -> str:
939 """
940 Compose a SQL query with parameters using psycopg's mogrify.
941
942 This requires an active connection because it uses the connection's
943 cursor to properly format parameters.
944 """
945 assert self.connection is not None
946 return ClientCursor(self.connection).mogrify(
947 psycopg_sql.SQL(cast(LiteralString, query)), params
948 )
949
950 def last_executed_query(
951 self,
952 cursor: utils.CursorWrapper,
953 sql: str,
954 params: Any,
955 ) -> str | None:
956 """
957 Return a string of the query last executed by the given cursor, with
958 placeholders replaced with actual values.
959 """
960 try:
961 return self.compose_sql(sql, params)
962 except errors.DataError:
963 return None
964
965 def unification_cast_sql(self, output_field: Field) -> str:
966 """
967 Given a field instance, return the SQL that casts the result of a union
968 to that type. The resulting string should contain a '%s' placeholder
969 for the expression being cast.
970 """
971 internal_type = output_field.get_internal_type()
972 if internal_type in (
973 "GenericIPAddressField",
974 "TimeField",
975 "UUIDField",
976 ):
977 # PostgreSQL will resolve a union as type 'text' if input types are
978 # 'unknown'.
979 # https://www.postgresql.org/docs/current/typeconv-union-case.html
980 # These fields cannot be implicitly cast back in the default
981 # PostgreSQL configuration so we need to explicitly cast them.
982 # We must also remove components of the type within brackets:
983 # varchar(255) -> varchar.
984 db_type = output_field.db_type()
985 if db_type:
986 return "CAST(%s AS {})".format(db_type.split("(")[0])
987 return "%s"
988
989 # ##### Introspection methods #####
990
991 def table_names(
992 self, cursor: CursorWrapper | None = None, include_views: bool = False
993 ) -> list[str]:
994 """
995 Return a list of names of all tables that exist in the database.
996 Sort the returned table list by Python's default sorting. Do NOT use
997 the database's ORDER BY here to avoid subtle differences in sorting
998 order between databases.
999 """
1000
1001 def get_names(cursor: CursorWrapper) -> list[str]:
1002 return sorted(
1003 ti.name
1004 for ti in self.get_table_list(cursor)
1005 if include_views or ti.type == "t"
1006 )
1007
1008 if cursor is None:
1009 with self.cursor() as cursor:
1010 return get_names(cursor)
1011 return get_names(cursor)
1012
1013 def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
1014 """
1015 Return an unsorted list of TableInfo named tuples of all tables and
1016 views that exist in the database.
1017 """
1018 cursor.execute(
1019 """
1020 SELECT
1021 c.relname,
1022 CASE
1023 WHEN c.relispartition THEN 'p'
1024 WHEN c.relkind IN ('m', 'v') THEN 'v'
1025 ELSE 't'
1026 END,
1027 obj_description(c.oid, 'pg_class')
1028 FROM pg_catalog.pg_class c
1029 LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
1030 WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
1031 AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
1032 AND pg_catalog.pg_table_is_visible(c.oid)
1033 """
1034 )
1035 return [
1036 TableInfo(*row)
1037 for row in cursor.fetchall()
1038 if row[0] not in self.ignored_tables
1039 ]
1040
1041 def plain_table_names(
1042 self, only_existing: bool = False, include_views: bool = True
1043 ) -> list[str]:
1044 """
1045 Return a list of all table names that have associated Plain models and
1046 are in INSTALLED_PACKAGES.
1047
1048 If only_existing is True, include only the tables in the database.
1049 """
1050 tables = set()
1051 for model in get_migratable_models():
1052 tables.add(model.model_options.db_table)
1053 tables.update(
1054 f.m2m_db_table() for f in model._model_meta.local_many_to_many
1055 )
1056 tables = list(tables)
1057 if only_existing:
1058 existing_tables = set(self.table_names(include_views=include_views))
1059 tables = [t for t in tables if t in existing_tables]
1060 return tables
1061
1062 def get_sequences(
1063 self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
1064 ) -> list[dict[str, Any]]:
1065 """
1066 Return a list of introspected sequences for table_name. Each sequence
1067 is a dict: {'table': <table_name>, 'column': <column_name>, 'name': <sequence_name>}.
1068 """
1069 cursor.execute(
1070 """
1071 SELECT
1072 s.relname AS sequence_name,
1073 a.attname AS colname
1074 FROM
1075 pg_class s
1076 JOIN pg_depend d ON d.objid = s.oid
1077 AND d.classid = 'pg_class'::regclass
1078 AND d.refclassid = 'pg_class'::regclass
1079 JOIN pg_attribute a ON d.refobjid = a.attrelid
1080 AND d.refobjsubid = a.attnum
1081 JOIN pg_class tbl ON tbl.oid = d.refobjid
1082 AND tbl.relname = %s
1083 AND pg_catalog.pg_table_is_visible(tbl.oid)
1084 WHERE
1085 s.relkind = 'S';
1086 """,
1087 [table_name],
1088 )
1089 return [
1090 {"name": row[0], "table": table_name, "column": row[1]}
1091 for row in cursor.fetchall()
1092 ]
1093
    def get_constraints(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, dict[str, Any]]:
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns. Also retrieve the definition of expression-based
        indexes.

        Returns a mapping of constraint/index name to a dict with keys
        ``columns``, ``primary_key``, ``unique``, ``foreign_key``, ``check``,
        ``index``, ``definition``, ``options`` (index entries additionally
        carry ``orders`` and ``type``).
        """
        constraints: dict[str, dict[str, Any]] = {}
        # Loop over the key table, collecting things as constraints. The column
        # array must return column names in the same order in which they were
        # created.
        cursor.execute(
            """
            SELECT
                c.conname,
                array(
                    SELECT attname
                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
                    WHERE ca.attrelid = c.conrelid
                    ORDER BY cols.arridx
                ),
                c.contype,
                (SELECT fkc.relname || '.' || fka.attname
                FROM pg_attribute AS fka
                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
                cl.reloptions
            FROM pg_constraint AS c
            JOIN pg_class AS cl ON c.conrelid = cl.oid
            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
            """,
            [table_name],
        )
        # kind is pg_constraint.contype: 'p' primary key, 'u' unique,
        # 'f' foreign key, 'c' check.
        for constraint, columns, kind, used_cols, options in cursor.fetchall():
            constraints[constraint] = {
                "columns": columns,
                "primary_key": kind == "p",
                "unique": kind in ["p", "u"],
                # used_cols is "reftable.refcolumn" for foreign keys.
                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
                "check": kind == "c",
                "index": False,
                "definition": None,
                "options": options,
            }
        # Now get indexes. The inner query unnests indkey/indoption so each
        # indexed column appears as its own row with its position (arridx)
        # and per-column option bits preserved.
        cursor.execute(
            """
            SELECT
                indexname,
                array_agg(attname ORDER BY arridx),
                indisunique,
                indisprimary,
                array_agg(ordering ORDER BY arridx),
                amname,
                exprdef,
                s2.attoptions
            FROM (
                SELECT
                    c2.relname as indexname, idx.*, attr.attname, am.amname,
                    CASE
                        WHEN idx.indexprs IS NOT NULL THEN
                            pg_get_indexdef(idx.indexrelid)
                    END AS exprdef,
                    CASE am.amname
                        WHEN %s THEN
                            CASE (option & 1)
                                WHEN 1 THEN 'DESC' ELSE 'ASC'
                            END
                    END as ordering,
                    c2.reloptions as attoptions
                FROM (
                    SELECT *
                    FROM
                        pg_index i,
                        unnest(i.indkey, i.indoption)
                            WITH ORDINALITY koi(key, option, arridx)
                ) idx
                LEFT JOIN pg_class c ON idx.indrelid = c.oid
                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
                LEFT JOIN pg_am am ON c2.relam = am.oid
                LEFT JOIN
                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
            ) s2
            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
            """,
            [self.index_default_access_method, table_name],
        )
        for (
            index,
            columns,
            unique,
            primary,
            orders,
            type_,
            definition,
            options,
        ) in cursor.fetchall():
            # Constraints above may be backed by an index with the same
            # name; only record indexes not already known as constraints.
            if index not in constraints:
                # A "basic" index uses the default access method with no
                # storage options and is reported with the generic suffix.
                basic_index = (
                    type_ == self.index_default_access_method and options is None
                )
                constraints[index] = {
                    # Expression indexes yield [None] column arrays.
                    "columns": columns if columns != [None] else [],
                    "orders": orders if orders != [None] else [],
                    "primary_key": primary,
                    "unique": unique,
                    "foreign_key": None,
                    "check": False,
                    "index": True,
                    "type": Index.suffix if basic_index else type_,
                    "definition": definition,
                    "options": options,
                }
        return constraints
1211
1212 # ##### Test database creation methods (merged from DatabaseCreation) #####
1213
1214 def _log(self, msg: str) -> None:
1215 sys.stderr.write(msg + os.linesep)
1216
    def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
        """
        Create a test database, prompting the user for confirmation if the
        database already exists. Return the name of the test database created.

        If prefix is provided, it will be prepended to the database name
        to isolate it from other test databases.

        Side effects: closes the current connection, repoints
        settings.POSTGRES_DATABASE and self.settings_dict["DATABASE"] at the
        test database, and applies all migrations to it.
        """
        # Deferred import: the migrations CLI is only needed here.
        from plain.models.cli.migrations import apply

        test_database_name = self._get_test_db_name(prefix)

        if verbosity >= 1:
            self._log(f"Creating test database '{test_database_name}'...")

        # autoclobber=True: an existing test database is dropped and
        # recreated without prompting.
        self._create_test_db(
            test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
        )

        # Repoint this connection (and global settings) at the new database
        # before running migrations against it.
        self.close()
        settings.POSTGRES_DATABASE = test_database_name
        self.settings_dict["DATABASE"] = test_database_name

        apply.callback(
            package_label=None,
            migration_name=None,
            fake=False,
            plan=False,
            check_unapplied=False,
            backup=False,
            no_input=True,
            atomic_batch=False,  # No need for atomic batch when creating test database
            quiet=verbosity < 2,  # Show migration output when verbosity is 2+
        )

        # Ensure a connection for the side effect of initializing the test database.
        self.ensure_connection()

        return test_database_name
1256
1257 def _get_test_db_name(self, prefix: str = "") -> str:
1258 """
1259 Internal implementation - return the name of the test DB that will be
1260 created. Only useful when called from create_test_db() and
1261 _create_test_db() and when no external munging is done with the 'DATABASE'
1262 settings.
1263
1264 If prefix is provided, it will be prepended to the database name.
1265 """
1266 # Determine the base name: explicit TEST.DATABASE overrides base DATABASE.
1267 base_name = (
1268 self.settings_dict["TEST"]["DATABASE"] or self.settings_dict["DATABASE"]
1269 )
1270 if prefix:
1271 return f"{prefix}_{base_name}"
1272 if self.settings_dict["TEST"]["DATABASE"]:
1273 return self.settings_dict["TEST"]["DATABASE"]
1274 name = self.settings_dict["DATABASE"]
1275 if name is None:
1276 raise ValueError("POSTGRES_DATABASE must be set")
1277 return TEST_DATABASE_PREFIX + name
1278
1279 def _get_database_create_suffix(
1280 self, encoding: str | None = None, template: str | None = None
1281 ) -> str:
1282 """Return PostgreSQL-specific CREATE DATABASE suffix."""
1283 suffix = ""
1284 if encoding:
1285 suffix += f" ENCODING '{encoding}'"
1286 if template:
1287 suffix += f" TEMPLATE {quote_name(template)}"
1288 return suffix and "WITH" + suffix
1289
1290 def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
1291 try:
1292 cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
1293 except Exception as e:
1294 cause = e.__cause__
1295 if cause and not isinstance(cause, errors.DuplicateDatabase):
1296 # All errors except "database already exists" cancel tests.
1297 self._log(f"Got an error creating the test database: {e}")
1298 sys.exit(2)
1299 else:
1300 raise
1301
    def _create_test_db(
        self, *, test_database_name: str, verbosity: int, autoclobber: bool
    ) -> str:
        """
        Internal implementation - create the test db tables.

        Tries CREATE DATABASE; on failure (typically because the database
        already exists) it either prompts the user or, with autoclobber,
        silently drops and recreates it. Exits the process if recreation
        fails or the user cancels.
        """
        test_db_params = {
            "dbname": quote_name(test_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        # Create the test database and connect to it.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params)
            except Exception as e:
                # _execute_create_test_db exits for unexpected errors, so
                # reaching here usually means the database already exists.
                self._log(f"Got an error creating the test database: {e}")
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        f"database '{test_database_name}', or 'no' to cancel: "
                    )
                if autoclobber or confirm == "yes":
                    try:
                        if verbosity >= 1:
                            self._log(
                                f"Destroying old test database '{test_database_name}'..."
                            )
                        cursor.execute(
                            "DROP DATABASE {dbname}".format(**test_db_params)
                        )
                        self._execute_create_test_db(cursor, test_db_params)
                    except Exception as e:
                        self._log(f"Got an error recreating the test database: {e}")
                        sys.exit(2)
                else:
                    self._log("Tests cancelled.")
                    sys.exit(1)

        return test_database_name
1341
1342 def destroy_test_db(
1343 self, old_database_name: str | None = None, verbosity: int = 1
1344 ) -> None:
1345 """
1346 Destroy a test database, prompting the user for confirmation if the
1347 database already exists.
1348 """
1349 self.close()
1350
1351 test_database_name = self.settings_dict["DATABASE"]
1352 if test_database_name is None:
1353 raise ValueError("Test POSTGRES_DATABASE must be set")
1354
1355 if verbosity >= 1:
1356 self._log(f"Destroying test database '{test_database_name}'...")
1357 self._destroy_test_db(test_database_name, verbosity)
1358
1359 # Restore the original database name
1360 if old_database_name is not None:
1361 settings.POSTGRES_DATABASE = old_database_name
1362 self.settings_dict["DATABASE"] = old_database_name
1363
1364 def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
1365 """
1366 Internal implementation - remove the test db tables.
1367 """
1368 # Remove the test database to clean up after
1369 # ourselves. Connect to the previous database (not the test database)
1370 # to do so, because it's not allowed to delete a database while being
1371 # connected to it.
1372 with self._nodb_cursor() as cursor:
1373 cursor.execute(f"DROP DATABASE {quote_name(test_database_name)}")
1374
1375 def sql_table_creation_suffix(self) -> str:
1376 """
1377 SQL to append to the end of the test table creation statements.
1378 """
1379 test_settings = self.settings_dict["TEST"]
1380 return self._get_database_create_suffix(
1381 encoding=test_settings.get("CHARSET"),
1382 template=test_settings.get("TEMPLATE"),
1383 )
1384
1385
class CursorMixin:
    """
    A psycopg cursor mixin implementing callproc.
    """

    def callproc(
        self, name: str | psycopg_sql.Identifier, args: list[Any] | None = None
    ) -> list[Any] | None:
        """
        Invoke the stored procedure/function ``name`` as
        ``SELECT * FROM name(arg, ...)``.

        Returns ``args`` unchanged, following the DB-API callproc contract.
        """
        if not isinstance(name, psycopg_sql.Identifier):
            name = psycopg_sql.Identifier(name)

        # Build: SELECT * FROM <name>(<literal>,<literal>,...)
        parts: list[psycopg_sql.Composable] = [
            psycopg_sql.SQL("SELECT * FROM "),
            name,
            psycopg_sql.SQL("("),
        ]
        for position, value in enumerate(args or ()):
            if position:
                parts.append(psycopg_sql.SQL(","))
            parts.append(psycopg_sql.Literal(value))
        parts.append(psycopg_sql.SQL(")"))

        self.execute(psycopg_sql.Composed(parts))  # type: ignore[attr-defined]
        return args
1412
1413
class ServerBindingCursor(CursorMixin, Database.Cursor):
    """psycopg's default (server-side binding) cursor with callproc() support."""

    pass
1416
1417
class Cursor(CursorMixin, Database.ClientCursor):
    """psycopg's client-side binding cursor with callproc() support."""

    pass
1420
1421
class CursorDebugWrapper(BaseCursorDebugWrapper):
    # Extends the base debug wrapper so COPY statements are also recorded.
    def copy(self, statement: Any) -> Any:
        """Run cursor.copy() while logging the statement via debug_sql()."""
        with self.debug_sql(statement):
            return self.cursor.copy(statement)