1from __future__ import annotations
2
3import _thread
4import warnings
5from collections import deque
6from collections.abc import Generator, Sequence
7from contextlib import contextmanager
8from typing import TYPE_CHECKING, Any, LiteralString, NamedTuple, cast
9
10import psycopg
11from psycopg import errors
12from psycopg import sql as psycopg_sql
13
14from plain.logs import get_framework_logger
15from plain.postgres import utils
16from plain.postgres.dialect import quote_name
17from plain.postgres.fields import GenericIPAddressField, TimeField, UUIDField
18from plain.postgres.schema import DatabaseSchemaEditor
19from plain.postgres.sources import ConnectionSource
20from plain.postgres.transaction import TransactionManagementError
21from plain.postgres.utils import CursorDebugWrapper as BaseCursorDebugWrapper
22from plain.postgres.utils import CursorWrapper, debug_transaction
23from plain.runtime import settings
24
25if TYPE_CHECKING:
26 from psycopg import Connection as PsycopgConnection
27
28 from plain.postgres.database_url import DatabaseConfig
29 from plain.postgres.fields import Field
30
31logger = get_framework_logger()
32
33
def get_migratable_models() -> Generator[Any]:
    """
    Return a generator over every model that migrations should include.

    Package configs are fetched from the package registry right away; for
    each config, the models registered under its ``package_label`` are looked
    up from the models registry lazily, as the generator is consumed.
    """
    from plain.packages import packages_registry
    from plain.postgres import models_registry

    # Evaluated eagerly, exactly like the outermost iterable of a genexp.
    package_configs = packages_registry.get_package_configs()
    return (
        model
        for config in package_configs
        for model in models_registry.get_models(package_label=config.package_label)
    )
46
47
class TableInfo(NamedTuple):
    """
    One row of introspected table metadata.

    Instances are produced by ``DatabaseConnection.get_table_list()``.
    """

    # Relation name as stored in pg_class.relname.
    name: str
    # Relation kind code as emitted by get_table_list(): 'p' for partitions,
    # 'v' for views and materialized views, 't' for everything else.
    type: str
    # obj_description() of the relation, or None when no comment is set.
    comment: str | None
54
55
56class DatabaseConnection:
57 """
58 PostgreSQL database connection.
59
60 This is the only database backend supported by Plain.
61 """
62
63 queries_limit: int = 9000
64
65 ignored_tables: list[str] = []
66
67 def __init__(self, source: ConnectionSource):
68 # Lazy — acquired on first use via self._source.
69 self.connection: PsycopgConnection[Any] | None = None
70 self._source: ConnectionSource = source
71 # Query logging in debug mode or when explicitly enabled.
72 self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
73 self.force_debug_cursor: bool = False
74
75 # Transaction related attributes.
76 # Tracks if the connection is in autocommit mode. Per PEP 249, by
77 # default, it isn't.
78 self.autocommit: bool = False
79 # Tracks if the connection is in a transaction managed by 'atomic'.
80 self.in_atomic_block: bool = False
81 # Increment to generate unique savepoint ids.
82 self.savepoint_state: int = 0
83 # List of savepoints created by 'atomic'.
84 self.savepoint_ids: list[str | None] = []
85 # Stack of active 'atomic' blocks.
86 self.atomic_blocks: list[Any] = []
87 # Tracks if the transaction should be rolled back to the next
88 # available savepoint because of an exception in an inner block.
89 self.needs_rollback: bool = False
90 self.rollback_exc: Exception | None = None
91
92 # A list of no-argument functions to run when the transaction commits.
93 # Each entry is an (sids, func, robust) tuple, where sids is a set of
94 # the active savepoint IDs when this function was registered and robust
95 # specifies whether it's allowed for the function to fail.
96 self.run_on_commit: list[tuple[set[str | None], Any, bool]] = []
97
98 # Should we run the on-commit hooks the next time set_autocommit(True)
99 # is called?
100 self.run_commit_hooks_on_set_autocommit_on: bool = False
101
102 # A stack of wrappers to be invoked around execute()/executemany()
103 # calls. Each entry is a function taking five arguments: execute, sql,
104 # params, many, and context. It's the function's responsibility to
105 # call execute(sql, params, many, context).
106 self.execute_wrappers: list[Any] = []
107
108 def __repr__(self) -> str:
109 return f"<{self.__class__.__qualname__} vendor='postgresql'>"
110
111 def __del__(self) -> None:
112 # Safety net for wrappers GC'd without an explicit close() —
113 # e.g. inside a short-lived `asyncio.to_thread` context copy.
114 # Returns the pooled connection to the pool. Guards handle
115 # interpreter shutdown, when attrs may already be cleared.
116 conn = getattr(self, "connection", None)
117 if conn is None:
118 return
119 source = getattr(self, "_source", None)
120 if source is None:
121 return
122 try:
123 source.release(conn)
124 except Exception:
125 pass
126
127 @property
128 def settings_dict(self) -> DatabaseConfig:
129 """Config of the server this wrapper talks to. For pool-backed
130 wrappers this always reflects the live `POSTGRES_URL`."""
131 return self._source.config
132
133 @property
134 def queries_logged(self) -> bool:
135 return self.force_debug_cursor or settings.DEBUG
136
137 @property
138 def queries(self) -> list[dict[str, Any]]:
139 if len(self.queries_log) == self.queries_log.maxlen:
140 warnings.warn(
141 f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
142 "will be returned."
143 )
144 return list(self.queries_log)
145
146 # ##### Connection and cursor methods #####
147
148 def _set_autocommit(self, autocommit: bool) -> None:
149 """Backend-specific implementation to enable or disable autocommit."""
150 assert self.connection is not None
151 self.connection.autocommit = autocommit
152
153 def check_constraints(self, table_names: list[str] | None = None) -> None:
154 """
155 Check constraints by setting them to immediate. Return them to deferred
156 afterward.
157 """
158 with self.cursor() as cursor:
159 cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
160 cursor.execute("SET CONSTRAINTS ALL DEFERRED")
161
162 def make_debug_cursor(self, cursor: psycopg.Cursor[Any]) -> CursorDebugWrapper:
163 return CursorDebugWrapper(cursor, self)
164
165 # ##### Connection lifecycle #####
166
167 def connect(self) -> None:
168 """Connect to the database. Assume that the connection is closed."""
169 self.connection = self._source.acquire()
170 self.set_autocommit(True)
171
172 def ensure_connection(self) -> None:
173 """Guarantee that a connection to the database is established."""
174 if self.connection is None:
175 self.connect()
176
177 # ##### PEP-249 connection method wrappers #####
178
179 def _prepare_cursor(self, cursor: psycopg.Cursor[Any]) -> utils.CursorWrapper:
180 """
181 Validate the connection is usable and perform database cursor wrapping.
182 """
183 if self.queries_logged:
184 wrapped_cursor = self.make_debug_cursor(cursor)
185 else:
186 wrapped_cursor = self.make_cursor(cursor)
187 return wrapped_cursor
188
189 def _cursor(self) -> utils.CursorWrapper:
190 self.ensure_connection()
191 assert self.connection is not None
192 return self._prepare_cursor(self.connection.cursor())
193
194 def _commit(self) -> None:
195 if self.connection is not None:
196 with debug_transaction(self, "COMMIT"):
197 return self.connection.commit()
198
199 def _rollback(self) -> None:
200 if self.connection is not None:
201 with debug_transaction(self, "ROLLBACK"):
202 return self.connection.rollback()
203
204 # ##### Generic wrappers for PEP-249 connection methods #####
205
206 def cursor(self) -> utils.CursorWrapper:
207 """Create a cursor, opening a connection if necessary."""
208 return self._cursor()
209
210 def commit(self) -> None:
211 """Commit a transaction and reset the dirty flag."""
212 self.validate_no_atomic_block()
213 self._commit()
214 self.run_commit_hooks_on_set_autocommit_on = True
215
216 def rollback(self) -> None:
217 """Roll back a transaction and reset the dirty flag."""
218 self.validate_no_atomic_block()
219 self._rollback()
220 self.needs_rollback = False
221 self.run_on_commit = []
222
223 def close(self) -> None:
224 """Close the connection to the database."""
225 # Closing mid-atomic would reopen a fresh autocommit connection on
226 # the next cursor() and silently run the rest of the block outside
227 # its transaction. Callers that drop a connection during error
228 # recovery (see Atomic.__exit__) unwind the atomic state first.
229 self.validate_no_atomic_block()
230
231 self.run_on_commit = []
232 if self.connection is None:
233 return
234 try:
235 self._source.release(self.connection)
236 finally:
237 # Null the reference so __del__ (and ensure_connection) can't
238 # touch an already-released psycopg connection.
239 self.connection = None
240
241 # ##### Savepoint management #####
242
243 def _savepoint(self, sid: str) -> None:
244 with self.cursor() as cursor:
245 cursor.execute(f"SAVEPOINT {quote_name(sid)}")
246
247 def _savepoint_rollback(self, sid: str) -> None:
248 with self.cursor() as cursor:
249 cursor.execute(f"ROLLBACK TO SAVEPOINT {quote_name(sid)}")
250
251 def _savepoint_commit(self, sid: str) -> None:
252 with self.cursor() as cursor:
253 cursor.execute(f"RELEASE SAVEPOINT {quote_name(sid)}")
254
255 # ##### Generic savepoint management methods #####
256
257 def savepoint(self) -> str | None:
258 """
259 Create a savepoint inside the current transaction. Return an
260 identifier for the savepoint that will be used for the subsequent
261 rollback or commit. Return None if in autocommit mode (no transaction).
262 """
263 if self.get_autocommit():
264 return None
265
266 thread_ident = _thread.get_ident()
267 tid = str(thread_ident).replace("-", "")
268
269 self.savepoint_state += 1
270 sid = "s%s_x%d" % (tid, self.savepoint_state) # noqa: UP031
271
272 self._savepoint(sid)
273
274 return sid
275
276 def savepoint_rollback(self, sid: str) -> None:
277 """
278 Roll back to a savepoint. Do nothing if in autocommit mode.
279 """
280 if self.get_autocommit():
281 return
282
283 self._savepoint_rollback(sid)
284
285 # Remove any callbacks registered while this savepoint was active.
286 self.run_on_commit = [
287 (sids, func, robust)
288 for (sids, func, robust) in self.run_on_commit
289 if sid not in sids
290 ]
291
292 def savepoint_commit(self, sid: str) -> None:
293 """
294 Release a savepoint. Do nothing if in autocommit mode.
295 """
296 if self.get_autocommit():
297 return
298
299 self._savepoint_commit(sid)
300
301 def clean_savepoints(self) -> None:
302 """
303 Reset the counter used to generate unique savepoint ids in this thread.
304 """
305 self.savepoint_state = 0
306
307 # ##### Generic transaction management methods #####
308
309 def get_autocommit(self) -> bool:
310 """Get the autocommit state."""
311 self.ensure_connection()
312 return self.autocommit
313
314 def set_autocommit(self, autocommit: bool) -> None:
315 """
316 Enable or disable autocommit.
317
318 Used internally by atomic() to manage transactions. Don't call this
319 directly — use atomic() instead.
320 """
321 self.validate_no_atomic_block()
322 self.ensure_connection()
323
324 if autocommit:
325 self._set_autocommit(autocommit)
326 else:
327 with debug_transaction(self, "BEGIN"):
328 self._set_autocommit(autocommit)
329 self.autocommit = autocommit
330
331 if autocommit and self.run_commit_hooks_on_set_autocommit_on:
332 self.run_and_clear_commit_hooks()
333 self.run_commit_hooks_on_set_autocommit_on = False
334
335 def get_rollback(self) -> bool:
336 """Get the "needs rollback" flag -- for *advanced use* only."""
337 if not self.in_atomic_block:
338 raise TransactionManagementError(
339 "The rollback flag doesn't work outside of an 'atomic' block."
340 )
341 return self.needs_rollback
342
343 def set_rollback(self, rollback: bool) -> None:
344 """
345 Set or unset the "needs rollback" flag -- for *advanced use* only.
346 """
347 if not self.in_atomic_block:
348 raise TransactionManagementError(
349 "The rollback flag doesn't work outside of an 'atomic' block."
350 )
351 self.needs_rollback = rollback
352
353 def validate_no_atomic_block(self) -> None:
354 """Raise an error if an atomic block is active."""
355 if self.in_atomic_block:
356 raise TransactionManagementError(
357 "This is forbidden when an 'atomic' block is active."
358 )
359
360 def validate_no_broken_transaction(self) -> None:
361 if self.needs_rollback:
362 raise TransactionManagementError(
363 "An error occurred in the current transaction. You can't "
364 "execute queries until the end of the 'atomic' block."
365 ) from self.rollback_exc
366
367 # ##### Miscellaneous #####
368
369 def make_cursor(self, cursor: psycopg.Cursor[Any]) -> utils.CursorWrapper:
370 """Create a cursor without debug logging."""
371 return utils.CursorWrapper(cursor, self)
372
373 def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
374 """Return a new instance of the schema editor."""
375 return DatabaseSchemaEditor(self, *args, **kwargs)
376
377 def on_commit(self, func: Any, robust: bool = False) -> None:
378 if not callable(func):
379 raise TypeError("on_commit()'s callback must be a callable.")
380 if self.in_atomic_block:
381 # Transaction in progress; save for execution on commit.
382 self.run_on_commit.append((set(self.savepoint_ids), func, robust))
383 else:
384 # No transaction in progress; execute immediately.
385 if robust:
386 try:
387 func()
388 except Exception as e:
389 logger.error(
390 "Error calling on_commit() handler",
391 exc_info=True,
392 extra={"handler": func.__qualname__, "error": str(e)},
393 )
394 else:
395 func()
396
397 def run_and_clear_commit_hooks(self) -> None:
398 self.validate_no_atomic_block()
399 current_run_on_commit = self.run_on_commit
400 self.run_on_commit = []
401 while current_run_on_commit:
402 _, func, robust = current_run_on_commit.pop(0)
403 if robust:
404 try:
405 func()
406 except Exception as e:
407 logger.error(
408 "Error calling on_commit() handler during transaction",
409 exc_info=True,
410 extra={"handler": func.__qualname__, "error": str(e)},
411 )
412 else:
413 func()
414
415 @contextmanager
416 def execute_wrapper(self, wrapper: Any) -> Generator[None]:
417 """
418 Return a context manager under which the wrapper is applied to suitable
419 database query executions.
420 """
421 self.execute_wrappers.append(wrapper)
422 try:
423 yield
424 finally:
425 self.execute_wrappers.pop()
426
427 # ##### SQL generation methods that require connection state #####
428
429 def compose_sql(self, query: str, params: Any) -> str:
430 """
431 Compose a SQL query with parameters using psycopg's mogrify.
432
433 This requires an active connection because it uses the connection's
434 cursor to properly format parameters.
435 """
436 assert self.connection is not None
437 return psycopg.ClientCursor(self.connection).mogrify(
438 psycopg_sql.SQL(cast(LiteralString, query)), params
439 )
440
441 def last_executed_query(
442 self,
443 cursor: utils.CursorWrapper,
444 sql: str,
445 params: Any,
446 ) -> str | None:
447 """
448 Return a string of the query last executed by the given cursor, with
449 placeholders replaced with actual values.
450 """
451 try:
452 return self.compose_sql(sql, params)
453 except errors.DataError:
454 return None
455
456 def unification_cast_sql(self, output_field: Field) -> str:
457 """
458 Given a field instance, return the SQL that casts the result of a union
459 to that type. The resulting string should contain a '%s' placeholder
460 for the expression being cast.
461 """
462 if isinstance(output_field, GenericIPAddressField | TimeField | UUIDField):
463 # PostgreSQL will resolve a union as type 'text' if input types are
464 # 'unknown'.
465 # https://www.postgresql.org/docs/current/typeconv-union-case.html
466 # These fields cannot be implicitly cast back in the default
467 # PostgreSQL configuration so we need to explicitly cast them.
468 # We must also remove components of the type within brackets:
469 # varchar(255) -> varchar.
470 db_type = output_field.db_type()
471 if db_type:
472 return "CAST(%s AS {})".format(db_type.split("(")[0])
473 return "%s"
474
475 # ##### Introspection methods #####
476
477 def table_names(
478 self, cursor: CursorWrapper | None = None, include_views: bool = False
479 ) -> list[str]:
480 """
481 Return a list of names of all tables that exist in the database.
482 Sort the returned table list by Python's default sorting. Do NOT use
483 the database's ORDER BY here to avoid subtle differences in sorting
484 order between databases.
485 """
486
487 def get_names(cursor: CursorWrapper) -> list[str]:
488 return sorted(
489 ti.name
490 for ti in self.get_table_list(cursor)
491 if include_views or ti.type == "t"
492 )
493
494 if cursor is None:
495 with self.cursor() as cursor:
496 return get_names(cursor)
497 return get_names(cursor)
498
499 def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
500 """
501 Return an unsorted list of TableInfo named tuples of all tables and
502 views that exist in the database.
503 """
504 cursor.execute(
505 """
506 SELECT
507 c.relname,
508 CASE
509 WHEN c.relispartition THEN 'p'
510 WHEN c.relkind IN ('m', 'v') THEN 'v'
511 ELSE 't'
512 END,
513 obj_description(c.oid, 'pg_class')
514 FROM pg_catalog.pg_class c
515 LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
516 WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
517 AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
518 AND pg_catalog.pg_table_is_visible(c.oid)
519 """
520 )
521 return [
522 TableInfo(*row)
523 for row in cursor.fetchall()
524 if row[0] not in self.ignored_tables
525 ]
526
527 def plain_table_names(
528 self, only_existing: bool = False, include_views: bool = True
529 ) -> list[str]:
530 """
531 Return a list of all table names that have associated Plain models and
532 are in INSTALLED_PACKAGES.
533
534 If only_existing is True, include only the tables in the database.
535 """
536 tables = set()
537 for model in get_migratable_models():
538 tables.add(model.model_options.db_table)
539 tables.update(
540 f.m2m_db_table() for f in model._model_meta.local_many_to_many
541 )
542 tables = list(tables)
543 if only_existing:
544 existing_tables = set(self.table_names(include_views=include_views))
545 tables = [t for t in tables if t in existing_tables]
546 return tables
547
548 def get_sequences(
549 self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
550 ) -> list[dict[str, Any]]:
551 """
552 Return a list of introspected sequences for table_name. Each sequence
553 is a dict: {'table': <table_name>, 'column': <column_name>, 'name': <sequence_name>}.
554 """
555 cursor.execute(
556 """
557 SELECT
558 s.relname AS sequence_name,
559 a.attname AS colname
560 FROM
561 pg_class s
562 JOIN pg_depend d ON d.objid = s.oid
563 AND d.classid = 'pg_class'::regclass
564 AND d.refclassid = 'pg_class'::regclass
565 JOIN pg_attribute a ON d.refobjid = a.attrelid
566 AND d.refobjsubid = a.attnum
567 JOIN pg_class tbl ON tbl.oid = d.refobjid
568 AND tbl.relname = %s
569 AND pg_catalog.pg_table_is_visible(tbl.oid)
570 WHERE
571 s.relkind = 'S';
572 """,
573 [table_name],
574 )
575 return [
576 {"name": row[0], "table": table_name, "column": row[1]}
577 for row in cursor.fetchall()
578 ]
579
580 def get_constraints(
581 self, cursor: CursorWrapper, table_name: str
582 ) -> dict[str, dict[str, Any]]:
583 """
584 Retrieve any constraints or keys (unique, pk, fk, check, index) across
585 one or more columns. Also retrieve the definition of expression-based
586 indexes.
587 """
588 constraints: dict[str, dict[str, Any]] = {}
589 # Loop over the key table, collecting things as constraints. The column
590 # array must return column names in the same order in which they were
591 # created.
592 cursor.execute(
593 """
594 SELECT
595 c.conname,
596 array(
597 SELECT attname
598 FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
599 JOIN pg_attribute AS ca ON cols.colid = ca.attnum
600 WHERE ca.attrelid = c.conrelid
601 ORDER BY cols.arridx
602 ),
603 c.contype,
604 (SELECT fkc.relname || '.' || fka.attname
605 FROM pg_attribute AS fka
606 JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
607 WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
608 c.convalidated,
609 pg_get_constraintdef(c.oid),
610 c.confdeltype
611 FROM pg_constraint AS c
612 JOIN pg_class AS cl ON c.conrelid = cl.oid
613 WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
614 """,
615 [table_name],
616 )
617 for (
618 constraint,
619 columns,
620 kind,
621 used_cols,
622 validated,
623 constraintdef,
624 confdeltype,
625 ) in cursor.fetchall():
626 constraints[constraint] = {
627 "columns": columns,
628 "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
629 "contype": kind,
630 "index": False,
631 "definition": constraintdef,
632 "validated": validated,
633 "on_delete_action": confdeltype if kind == "f" else None,
634 }
635 # Now get indexes. Sort order, opclasses, INCLUDE, and predicates all
636 # ride along inside `pg_get_indexdef` and are compared via the
637 # canonical-tail round-trip in convergence — no need to introspect
638 # them here as separate columns.
639 cursor.execute(
640 """
641 SELECT
642 indexname,
643 array_agg(attname ORDER BY arridx),
644 indisunique,
645 amname,
646 exprdef,
647 indisvalid
648 FROM (
649 SELECT
650 c2.relname as indexname, idx.*, attr.attname, am.amname,
651 pg_get_indexdef(idx.indexrelid) AS exprdef
652 FROM (
653 SELECT *
654 FROM
655 pg_index i,
656 unnest(i.indkey)
657 WITH ORDINALITY koi(key, arridx)
658 ) idx
659 LEFT JOIN pg_class c ON idx.indrelid = c.oid
660 LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
661 LEFT JOIN pg_am am ON c2.relam = am.oid
662 LEFT JOIN
663 pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
664 WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
665 ) s2
666 GROUP BY
667 indexname, indisunique, amname, exprdef, indisvalid;
668 """,
669 [table_name],
670 )
671 for (
672 index,
673 columns,
674 unique,
675 type_,
676 definition,
677 valid,
678 ) in cursor.fetchall():
679 if index not in constraints:
680 constraints[index] = {
681 "columns": columns if columns != [None] else [],
682 "unique": unique,
683 "index": True,
684 "type": type_,
685 "definition": definition,
686 "valid": valid,
687 }
688 return constraints
689
690
class CursorDebugWrapper(BaseCursorDebugWrapper):
    """Debug cursor wrapper that also logs ``COPY`` statements."""

    @contextmanager
    def copy(self, statement: Any) -> Generator[Any]:
        """
        Run ``statement`` through psycopg's ``Cursor.copy()`` under
        ``debug_sql``.

        psycopg's ``copy()`` is itself a context manager. The COPY starts only
        when it is entered, and data moves inside the caller's ``with`` block.
        Wrapping just the call would close ``debug_sql`` before anything
        reached the server. This version keeps ``debug_sql`` open for the
        whole copy block instead (presumably so the logged duration reflects
        the real transfer — depends on what debug_sql records).

        Usage is unchanged::

            with cursor.copy("COPY t FROM STDIN") as copy:
                copy.write_row(...)
        """
        with self.debug_sql(statement):
            with self.cursor.copy(statement) as copy_obj:
                yield copy_obj